sparknlp.annotator.cv.clip_for_zero_shot_classification#

Contains classes concerning CLIPForZeroShotClassification.

Module Contents#

Classes#

CLIPForZeroShotClassification

Zero Shot Image Classifier based on CLIP.

class CLIPForZeroShotClassification(classname='com.johnsnowlabs.nlp.annotators.cv.CLIPForZeroShotClassification', java_model=None)[source]#

Zero Shot Image Classifier based on CLIP.

CLIP (Contrastive Language-Image Pre-Training) is a neural network trained on image and text pairs. It can classify images without having been trained on any fixed set of labels, which makes it very flexible: labels can be provided at inference time. This is similar to the zero-shot capabilities of the GPT-2 and GPT-3 models.

Pretrained models can be loaded with the pretrained() method of the companion object:

imageClassifier = CLIPForZeroShotClassification.pretrained() \
    .setInputCols(["image_assembler"]) \
    .setOutputCol("label")

The default model is "zero_shot_classifier_clip_vit_base_patch32", if no name is provided.

For available pretrained models please see the Models Hub.

Models from the HuggingFace 🤗 Transformers library are also compatible with Spark NLP 🚀. To see which models are compatible and how to import them see JohnSnowLabs/spark-nlp#5669 and to see more extended examples, see CLIPForZeroShotClassificationTestSpec.

Input Annotation types: IMAGE

Output Annotation type: CATEGORY

Parameters:
batchSize
    Batch size, by default 2.
candidateLabels
    Array of labels for classification.
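The candidate labels in the example below all follow the prompt pattern "a photo of a <noun>", which tends to work well with CLIP. A minimal, pure-Python sketch of building such a list (the noun list and helper name are illustrative; any label strings can be passed to setCandidateLabels):

```python
# Build prompt-style candidate labels of the form "a photo of a <noun>",
# matching the convention used in the example below. The noun list and
# the helper function are hypothetical; labels are free-form strings.
nouns = ["cat", "dog", "hen", "ostrich"]

def make_label(noun: str) -> str:
    # Use "an" before vowel-initial nouns, "a" otherwise.
    article = "an" if noun[0] in "aeiou" else "a"
    return f"a photo of {article} {noun}"

candidate_labels = [make_label(n) for n in nouns]
# candidate_labels == ["a photo of a cat", "a photo of a dog",
#                      "a photo of a hen", "a photo of an ostrich"]
```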

Examples

>>> import sparknlp
>>> from sparknlp.base import *
>>> from sparknlp.annotator import *
>>> from pyspark.ml import Pipeline
>>> imageDF = spark.read \
...     .format("image") \
...     .option("dropInvalid", value = True) \
...     .load("src/test/resources/image/")
>>> imageAssembler = ImageAssembler() \
...     .setInputCol("image") \
...     .setOutputCol("image_assembler")
>>> candidateLabels = [
...     "a photo of a bird",
...     "a photo of a cat",
...     "a photo of a dog",
...     "a photo of a hen",
...     "a photo of a hippo",
...     "a photo of a room",
...     "a photo of a tractor",
...     "a photo of an ostrich",
...     "a photo of an ox"]
>>> imageClassifier = CLIPForZeroShotClassification \
...     .pretrained() \
...     .setInputCols(["image_assembler"]) \
...     .setOutputCol("label") \
...     .setCandidateLabels(candidateLabels)
>>> pipeline = Pipeline().setStages([imageAssembler, imageClassifier])
>>> pipelineDF = pipeline.fit(imageDF).transform(imageDF)
>>> pipelineDF \
...   .selectExpr("reverse(split(image.origin, '/'))[0] as image_name", "label.result") \
...   .show(truncate=False)
+-----------------+-----------------------+
|image_name       |result                 |
+-----------------+-----------------------+
|palace.JPEG      |[a photo of a room]    |
|egyptian_cat.jpeg|[a photo of a cat]     |
|hippopotamus.JPEG|[a photo of a hippo]   |
|hen.JPEG         |[a photo of a hen]     |
|ostrich.JPEG     |[a photo of an ostrich]|
|junco.JPEG       |[a photo of a bird]    |
|bluetick.jpg     |[a photo of a dog]     |
|chihuahua.jpg    |[a photo of a dog]     |
|tractor.JPEG     |[a photo of a tractor] |
|ox.JPEG          |[a photo of an ox]     |
+-----------------+-----------------------+
getCandidateLabels()[source]#

Returns the candidate labels set for this model.
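A guarded sketch of setting labels and reading them back via getCandidateLabels (assumes Spark NLP is installed and a Spark session can be started; the try/except keeps the snippet self-contained when it cannot):

```python
# Hedged sketch: candidate labels set with setCandidateLabels can be read
# back with getCandidateLabels. Guarded so the snippet degrades gracefully
# when Spark NLP or a JVM/Spark session is unavailable.
labels = ["a photo of a cat", "a photo of a dog"]
try:
    from sparknlp.annotator import CLIPForZeroShotClassification

    classifier = CLIPForZeroShotClassification() \
        .setCandidateLabels(labels)
    round_tripped = classifier.getCandidateLabels()
except Exception:
    # Library or Spark session unavailable; the call shape above is the point.
    round_tripped = labels
```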

static loadSavedModel(folder, spark_session)[source]#

Loads a locally saved model.

Parameters:
folder : str
    Folder of the saved model
spark_session : pyspark.sql.SparkSession
    The current SparkSession

Returns:
CLIPForZeroShotClassification
    The restored model

static pretrained(name='zero_shot_classifier_clip_vit_base_patch32', lang='en', remote_loc=None)[source]#

Downloads and loads a pretrained model.

Parameters:
name : str, optional
    Name of the pretrained model, by default "zero_shot_classifier_clip_vit_base_patch32"
lang : str, optional
    Language of the pretrained model, by default "en"
remote_loc : str, optional
    Optional remote address of the resource, by default None. Will use Spark NLP's repositories otherwise.

Returns:
CLIPForZeroShotClassification
    The downloaded and loaded model
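A guarded sketch of downloading a specific pretrained model by name and language (assumes Spark NLP is installed and network access is available for the download; otherwise the snippet falls through):

```python
# Hedged sketch: pretrained() with an explicit name and language.
# "zero_shot_classifier_clip_vit_base_patch32" is the documented default.
# Guarded so the snippet stays self-contained without Spark NLP or network.
name, lang = "zero_shot_classifier_clip_vit_base_patch32", "en"
try:
    import sparknlp
    from sparknlp.annotator import CLIPForZeroShotClassification

    spark = sparknlp.start()  # starts a local SparkSession
    classifier = CLIPForZeroShotClassification.pretrained(name, lang) \
        .setInputCols(["image_assembler"]) \
        .setOutputCol("label")
except Exception:
    # Offline or library unavailable; the call shape above is the point.
    classifier = None
```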