# Source code for sparknlp.annotator.cv.clip_for_zero_shot_classification

#  Copyright 2017-2022 John Snow Labs
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

"""Contains classes concerning CLIPForZeroShotClassification."""

from sparknlp.common import *


class CLIPForZeroShotClassification(AnnotatorModel,
                                    HasBatchedAnnotateImage,
                                    HasImageFeatureProperties,
                                    HasEngine,
                                    HasCandidateLabelsProperties,
                                    HasRescaleFactor):
    """Zero Shot Image Classifier based on CLIP.

    CLIP (Contrastive Language-Image Pre-Training) is a neural network that
    was trained on image and text pairs. It has the ability to predict images
    without training on any hard-coded labels. This makes it very flexible, as
    labels can be provided during inference. This is similar to the zero-shot
    capabilities of the GPT-2 and 3 models.

    Pretrained models can be loaded with ``pretrained`` of the companion object:

    .. code-block:: python

        imageClassifier = CLIPForZeroShotClassification.pretrained() \\
            .setInputCols(["image_assembler"]) \\
            .setOutputCol("label")

    The default model is ``"zero_shot_classifier_clip_vit_base_patch32"``, if
    no name is provided.

    For available pretrained models please see the
    `Models Hub <https://sparknlp.org/models?task=Zero-Shot+Classification>`__.

    Models from the HuggingFace 🤗 Transformers library are also compatible
    with Spark NLP 🚀. To see which models are compatible and how to import
    them see https://github.com/JohnSnowLabs/spark-nlp/discussions/5669 and to
    see more extended examples, see
    `CLIPForZeroShotClassificationTestSpec <https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/annotators/cv/CLIPForZeroShotClassificationTestSpec.scala>`__.

    ====================== ======================
    Input Annotation types Output Annotation type
    ====================== ======================
    ``IMAGE``              ``CATEGORY``
    ====================== ======================

    Parameters
    ----------
    batchSize
        Batch size, by default `2`.
    candidateLabels
        Array of labels for classification
    configProtoBytes
        ConfigProto from tensorflow, serialized into byte array

    Examples
    --------
    >>> import sparknlp
    >>> from sparknlp.base import *
    >>> from sparknlp.annotator import *
    >>> from pyspark.ml import Pipeline
    >>> imageDF = spark.read \\
    ...     .format("image") \\
    ...     .option("dropInvalid", value = True) \\
    ...     .load("src/test/resources/image/")
    >>> imageAssembler = ImageAssembler() \\
    ...     .setInputCol("image") \\
    ...     .setOutputCol("image_assembler")
    >>> candidateLabels = [
    ...     "a photo of a bird",
    ...     "a photo of a cat",
    ...     "a photo of a dog",
    ...     "a photo of a hen",
    ...     "a photo of a hippo",
    ...     "a photo of a room",
    ...     "a photo of a tractor",
    ...     "a photo of an ostrich",
    ...     "a photo of an ox"]
    >>> imageClassifier = CLIPForZeroShotClassification \\
    ...     .pretrained() \\
    ...     .setInputCols(["image_assembler"]) \\
    ...     .setOutputCol("label") \\
    ...     .setCandidateLabels(candidateLabels)
    >>> pipeline = Pipeline().setStages([imageAssembler, imageClassifier])
    >>> pipelineDF = pipeline.fit(imageDF).transform(imageDF)
    >>> pipelineDF \\
    ...     .selectExpr("reverse(split(image.origin, '/'))[0] as image_name", "label.result") \\
    ...     .show(truncate=False)
    +-----------------+-----------------------+
    |image_name       |result                 |
    +-----------------+-----------------------+
    |palace.JPEG      |[a photo of a room]    |
    |egyptian_cat.jpeg|[a photo of a cat]     |
    |hippopotamus.JPEG|[a photo of a hippo]   |
    |hen.JPEG         |[a photo of a hen]     |
    |ostrich.JPEG     |[a photo of an ostrich]|
    |junco.JPEG       |[a photo of a bird]    |
    |bluetick.jpg     |[a photo of a dog]     |
    |chihuahua.jpg    |[a photo of a dog]     |
    |tractor.JPEG     |[a photo of a tractor] |
    |ox.JPEG          |[a photo of an ox]     |
    +-----------------+-----------------------+
    """

    name = "CLIPForZeroShotClassification"

    inputAnnotatorTypes = [AnnotatorType.IMAGE]

    outputAnnotatorType = AnnotatorType.CATEGORY

    configProtoBytes = Param(Params._dummy(),
                             "configProtoBytes",
                             "ConfigProto from tensorflow, serialized into byte array. Get with "
                             "config_proto.SerializeToString()",
                             TypeConverters.toListInt)

    def setConfigProtoBytes(self, b):
        """Sets configProto from tensorflow, serialized into byte array.

        Parameters
        ----------
        b : List[int]
            ConfigProto from tensorflow, serialized into byte array. Get with
            ``config_proto.SerializeToString()``.

        Returns
        -------
        CLIPForZeroShotClassification
            This annotator, to allow method chaining.
        """
        return self._set(configProtoBytes=b)

    def getCandidateLabels(self):
        """Returns the candidate labels as reported by the underlying Java model.

        These are the labels the classifier chooses between for each image.
        As a zero-shot classifier, the labels are supplied at inference time
        (e.g. via ``setCandidateLabels``) rather than learned during training.
        """
        return self._call_java("getCandidateLabels")

    @keyword_only
    def __init__(self, classname="com.johnsnowlabs.nlp.annotators.cv.CLIPForZeroShotClassification",
                 java_model=None):
        """Creates the annotator and registers its default parameter values.

        Parameters
        ----------
        classname : str, optional
            Fully qualified name of the backing JVM class.
        java_model : optional
            An existing JVM model object to wrap, by default None.
        """
        super(CLIPForZeroShotClassification, self).__init__(
            classname=classname,
            java_model=java_model
        )
        # Image preprocessing defaults: resize to 224, scale pixel values by
        # 1/255, then normalize with the per-channel mean/std below (the
        # constants used by CLIP's image processor).
        self._setDefault(
            batchSize=2,
            doNormalize=True,
            doRescale=True,
            doResize=True,
            imageMean=[0.48145466, 0.4578275, 0.40821073],
            imageStd=[0.26862954, 0.26130258, 0.27577711],
            resample=2,
            rescaleFactor=1 / 255.0,
            size=224
        )

    @staticmethod
    def loadSavedModel(folder, spark_session):
        """Loads a locally saved model.

        Parameters
        ----------
        folder : str
            Folder of the saved model
        spark_session : pyspark.sql.SparkSession
            The current SparkSession

        Returns
        -------
        CLIPForZeroShotClassification
            The restored model
        """
        from sparknlp.internal import _CLIPForZeroShotClassification
        jModel = _CLIPForZeroShotClassification(folder, spark_session._jsparkSession)._java_obj
        return CLIPForZeroShotClassification(java_model=jModel)

    @staticmethod
    def pretrained(name="zero_shot_classifier_clip_vit_base_patch32", lang="en", remote_loc=None):
        """Downloads and loads a pretrained model.

        Parameters
        ----------
        name : str, optional
            Name of the pretrained model, by default
            "zero_shot_classifier_clip_vit_base_patch32"
        lang : str, optional
            Language of the pretrained model, by default "en"
        remote_loc : str, optional
            Optional remote address of the resource, by default None. Will use
            Spark NLPs repositories otherwise.

        Returns
        -------
        CLIPForZeroShotClassification
            The restored model
        """
        from sparknlp.pretrained import ResourceDownloader
        return ResourceDownloader.downloadModel(CLIPForZeroShotClassification, name, lang, remote_loc)