# Copyright 2017-2022 John Snow Labs
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains classes for the ImageAssembler."""
from pyspark import keyword_only
from pyspark.ml.param import TypeConverters, Params, Param
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import regexp_replace, col
from sparknlp.common import AnnotatorType
from sparknlp.internal import AnnotatorTransformer


class ImageAssembler(AnnotatorTransformer):
"""Prepares images read by Spark into a format that is processable by Spark NLP.
This component is needed to process images.
====================== ======================
Input Annotation types Output Annotation type
====================== ======================
``NONE`` ``IMAGE``
====================== ======================
Parameters
----------
inputCol
Input column name
outputCol
Output column name
Examples
--------
>>> import sparknlp
>>> from sparknlp.base import *
>>> from pyspark.ml import Pipeline
>>> data = spark.read.format("image").load("./tmp/images/").toDF("image")
>>> imageAssembler = ImageAssembler().setInputCol("image").setOutputCol("image_assembler")
>>> result = imageAssembler.transform(data)
>>> result.select("image_assembler").show()
>>> result.select("image_assembler").printSchema()
root
|-- image_assembler: array (nullable = true)
| |-- element: struct (containsNull = true)
| | |-- annotatorType: string (nullable = true)
| | |-- origin: string (nullable = true)
| | |-- height: integer (nullable = true)
| | |-- width: integer (nullable = true)
| | |-- nChannels: integer (nullable = true)
| | |-- mode: integer (nullable = true)
| | |-- result: binary (nullable = true)
| | |-- metadata: map (nullable = true)
| | | |-- key: string
| | | |-- value: string (valueContainsNull = true)
"""
    outputAnnotatorType = AnnotatorType.IMAGE

    inputCol = Param(Params._dummy(), "inputCol", "input column name", typeConverter=TypeConverters.toString)
    textCol = Param(Params._dummy(), "textCol", "text column name", typeConverter=TypeConverters.toString)
    outputCol = Param(Params._dummy(), "outputCol", "output column name", typeConverter=TypeConverters.toString)
    name = 'ImageAssembler'

    @keyword_only
    def __init__(self):
        super(ImageAssembler, self).__init__(classname="com.johnsnowlabs.nlp.ImageAssembler")
        self._setDefault(outputCol="image_assembler", inputCol='image')

    @keyword_only
    def setParams(self):
        """Sets params for this ImageAssembler."""
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def setOutputCol(self, value):
        """Sets output column name.

        Parameters
        ----------
        value : str
            Name of the output column
        """
        return self._set(outputCol=value)

    def getOutputCol(self):
        """Gets output column name of annotations."""
        return self.getOrDefault(self.outputCol)

    def setTextCol(self, value):
        """Sets an optional text column name.

        Parameters
        ----------
        value : str
            Name of an optional input text column
        """
        return self._set(textCol=value)

    @classmethod
    def loadImagesAsBytes(cls, spark: SparkSession, path: str):
        """Loads images from a given path and returns them as raw bytes,
        instead of the default OpenCV-compatible format. Supported image types
        are JPEG, PNG, GIF, and BMP.

        Multimodal inference with llama.cpp requires raw bytes as input.

        Parameters
        ----------
        spark : SparkSession
            The active SparkSession.
        path : str
            The path to the images. Supported image types are JPEG, PNG, GIF, and BMP.

        Returns
        -------
        DataFrame
            A DataFrame containing the images as raw bytes along with their metadata.
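
        Examples
        --------
        A minimal usage sketch; it assumes an active SparkSession named
        ``spark`` and a local directory ``./tmp/images/`` containing supported
        image files.

        >>> data = ImageAssembler.loadImagesAsBytes(spark, "./tmp/images/")
        >>> data.select("image.origin", "image.height", "image.width").show()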
"""

        # Replace the path separator in the `origin` field and `path` column, so that they match
        def replace_path(column_name: str):
            return regexp_replace(col(column_name), ":///", ":/")
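        # For example, "file:///tmp/images/a.jpg" becomes "file:/tmp/images/a.jpg",
        # while a path already in "file:/..." form is left unchanged, so both
        # spellings normalize to the same string for the join below.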

        # Load the images as metadata with the default Spark image format
        data = (
            spark.read.format("image")
            .option("dropInvalid", True)
            .load(path)
            .withColumn(
                "image", col("image").withField("origin", replace_path("image.origin"))
            )
        )

        # Load the images as raw binary files
        image_bytes = (
            spark.read.format("binaryFile")
            .option("pathGlobFilter", "*.{jpeg,jpg,png,gif,bmp,JPEG,JPG,PNG,GIF,BMP}")
            .option("dropInvalid", True)
            .load(path)
            .withColumn("path", replace_path("path"))
        )

        # Join the two datasets on the file path
        df_joined = data.join(
            image_bytes, data["image.origin"] == image_bytes["path"], "inner"
        )
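        # Note: this is an inner join, so only files that were successfully read
        # by both the "image" and "binaryFile" readers are kept.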

        # Replace the `data` field of the `image` column with raw bytes
        df_image_replaced = df_joined.withColumn(
            "image", df_joined["image"].withField("data", df_joined["content"])
        )
        return df_image_replaced