sparknlp.annotator.cv.clip_for_zero_shot_classification
Contains classes concerning CLIPForZeroShotClassification.
Module Contents

Classes

CLIPForZeroShotClassification: Zero Shot Image Classifier based on CLIP.
- class CLIPForZeroShotClassification(classname='com.johnsnowlabs.nlp.annotators.cv.CLIPForZeroShotClassification', java_model=None)
Zero Shot Image Classifier based on CLIP.
CLIP (Contrastive Language-Image Pre-Training) is a neural network that was trained on image and text pairs. It can classify images without being trained on any hard-coded labels, which makes it very flexible: labels can be provided at inference time. This is similar to the zero-shot capabilities of the GPT-2 and GPT-3 models.
Pretrained models can be loaded with pretrained of the companion object:

imageClassifier = CLIPForZeroShotClassification.pretrained() \
    .setInputCols(["image_assembler"]) \
    .setOutputCol("label")

The default model is "zero_shot_classifier_clip_vit_base_patch32", if no name is provided.

For available pretrained models please see the Models Hub.
Models from the HuggingFace 🤗 Transformers library are also compatible with Spark NLP 🚀. To see which models are compatible and how to import them see JohnSnowLabs/spark-nlp#5669 and to see more extended examples, see CLIPForZeroShotClassificationTestSpec.
Input Annotation types: IMAGE
Output Annotation type: CATEGORY
- Parameters:
- batchSize
Batch size, by default 2.
- candidateLabels
Array of labels for classification.
Examples
>>> import sparknlp
>>> from sparknlp.base import *
>>> from sparknlp.annotator import *
>>> from pyspark.ml import Pipeline
>>> imageDF = spark.read \
...     .format("image") \
...     .option("dropInvalid", value = True) \
...     .load("src/test/resources/image/")
>>> imageAssembler = ImageAssembler() \
...     .setInputCol("image") \
...     .setOutputCol("image_assembler")
>>> candidateLabels = [
...     "a photo of a bird",
...     "a photo of a cat",
...     "a photo of a dog",
...     "a photo of a hen",
...     "a photo of a hippo",
...     "a photo of a room",
...     "a photo of a tractor",
...     "a photo of an ostrich",
...     "a photo of an ox"]
>>> imageClassifier = CLIPForZeroShotClassification \
...     .pretrained() \
...     .setInputCols(["image_assembler"]) \
...     .setOutputCol("label") \
...     .setCandidateLabels(candidateLabels)
>>> pipeline = Pipeline().setStages([imageAssembler, imageClassifier])
>>> pipelineDF = pipeline.fit(imageDF).transform(imageDF)
>>> pipelineDF \
...     .selectExpr("reverse(split(image.origin, '/'))[0] as image_name", "label.result") \
...     .show(truncate=False)
+-----------------+-----------------------+
|image_name       |result                 |
+-----------------+-----------------------+
|palace.JPEG      |[a photo of a room]    |
|egyptian_cat.jpeg|[a photo of a cat]     |
|hippopotamus.JPEG|[a photo of a hippo]   |
|hen.JPEG         |[a photo of a hen]     |
|ostrich.JPEG     |[a photo of an ostrich]|
|junco.JPEG       |[a photo of a bird]    |
|bluetick.jpg     |[a photo of a dog]     |
|chihuahua.jpg    |[a photo of a dog]     |
|tractor.JPEG     |[a photo of a tractor] |
|ox.JPEG          |[a photo of an ox]     |
+-----------------+-----------------------+
- static loadSavedModel(folder, spark_session)
Loads a locally saved model.
- Parameters:
- folder : str
Folder of the saved model
- spark_session : pyspark.sql.SparkSession
The current SparkSession
- Returns:
- CLIPForZeroShotClassification
The restored model
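The snippet below is a minimal sketch of the typical import workflow, assuming the CLIP model has already been exported to a Spark NLP compatible folder (see the import discussion linked above for export steps); the export path and save location used here are hypothetical:

from sparknlp.annotator import CLIPForZeroShotClassification

# Hypothetical folder containing a CLIP model exported for Spark NLP
export_path = "exported_models/clip-vit-base-patch32"

# spark is an active SparkSession started with Spark NLP
imageClassifier = CLIPForZeroShotClassification.loadSavedModel(export_path, spark) \
    .setInputCols(["image_assembler"]) \
    .setOutputCol("label")

# Optionally persist the imported model so it can be reloaded later with
# CLIPForZeroShotClassification.load(...) instead of re-importing it.
imageClassifier.write().overwrite().save("models/clip_zero_shot_spark_nlp")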
- static pretrained(name='zero_shot_classifier_clip_vit_base_patch32', lang='en', remote_loc=None)
Downloads and loads a pretrained model.
- Parameters:
- name : str, optional
Name of the pretrained model, by default "zero_shot_classifier_clip_vit_base_patch32"
- lang : str, optional
Language of the pretrained model, by default “en”
- remote_loc : str, optional
Optional remote address of the resource, by default None. Will use Spark NLP's repositories otherwise.
- Returns:
- CLIPForZeroShotClassification
The restored model
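As a minimal usage sketch, the defaults above can also be passed explicitly; the candidate labels in this example are purely illustrative:

imageClassifier = CLIPForZeroShotClassification.pretrained(
        name="zero_shot_classifier_clip_vit_base_patch32",
        lang="en") \
    .setInputCols(["image_assembler"]) \
    .setOutputCol("label") \
    .setCandidateLabels(["a photo of a cat", "a photo of a dog"])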