Image Classification

 

Image classification assigns a single label, such as “dog”, “car”, or “cat”, to an entire image. The model learns patterns such as shapes, colors, and textures that distinguish one class from another. It does not locate the object within the image or separate multiple objects; it only identifies the overall category.

In practice, image classification is used to organize and tag large collections of photos (as in Google Photos or stock image sites), filter content, or power visual search systems. The model’s output usually includes a few candidate labels, each with a confidence score that shows how sure the model is about that prediction.
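To make the idea of "labels with confidence scores" concrete, here is a minimal sketch in plain Python. It assumes the classifier produces one raw score (logit) per class and turns those scores into probabilities with a softmax. The function name `top_k_labels`, the label list, and the score values are illustrative assumptions, not part of the Spark NLP API.

```python
import math

def top_k_labels(logits, labels, k=3):
    """Convert raw scores to probabilities (softmax) and return the k most likely labels."""
    # Subtract the largest score before exponentiating for numerical stability.
    max_logit = max(logits)
    exps = [math.exp(x - max_logit) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Pair each label with its probability and sort from most to least confident.
    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]

# Hypothetical raw scores for a single image.
labels = ["dog", "cat", "car", "hen"]
logits = [2.0, 1.0, 0.1, -1.0]

for label, prob in top_k_labels(logits, labels, k=3):
    print(f"{label}: {prob:.3f}")
```

The first entry is the predicted class; the remaining entries show how close the runner-up labels were.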

Picking a Model

When picking a model for image classification, think about what you are trying to achieve. For simple tasks like recognizing a few object types or when you have limited computing power, lightweight models such as MobileNet, EfficientNet-Lite, or ResNet-18 are good starting points because they are fast and easy to deploy. If you have a larger dataset and need higher accuracy, deeper architectures like ResNet-50, DenseNet, or EfficientNet-B7 generally perform better when properly fine-tuned.
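One rough way to compare candidates is by the size of their weights. The sketch below assumes approximate parameter counts for common ImageNet variants (the figures are rounded, and MobileNetV2 stands in for the MobileNet family) and estimates memory as parameters times 4 bytes for float32. The helper names are hypothetical, and weight size is only a proxy: real latency also depends on input resolution, hardware, and the runtime.

```python
# Approximate parameter counts in millions; treat these as rough figures.
APPROX_PARAMS_MILLIONS = {
    "MobileNetV2": 3.5,
    "ResNet-18": 11.7,
    "ResNet-50": 25.6,
    "EfficientNet-B7": 66.0,
}

def weight_size_mb(params_millions, bytes_per_param=4):
    """Estimate weight size in megabytes (float32 uses 4 bytes per parameter)."""
    return params_millions * 1_000_000 * bytes_per_param / 1_000_000

def models_within_budget(budget_mb, catalog=APPROX_PARAMS_MILLIONS):
    """Return the models whose float32 weights fit in budget_mb, smallest first."""
    fitting = [
        (name, weight_size_mb(params))
        for name, params in catalog.items()
        if weight_size_mb(params) <= budget_mb
    ]
    return [name for name, _ in sorted(fitting, key=lambda item: item[1])]

# Which models fit in roughly 50 MB of weights?
print(models_within_budget(50))
```

A filter like this only narrows the shortlist; accuracy on your own validation data should drive the final choice.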

If your images belong to a specific domain, consider a model that was pretrained on similar data. For example, MedNet is designed for medical imaging and GeoResNet works well for satellite imagery, while CLIP is a general-purpose option for matching images with text. Domain-pretrained models can outperform generic ones on domain-specific tasks.

To explore and select from a variety of models, visit Spark NLP Models, where you can find models tailored for different tasks and datasets.

How to Use

# Python
import sparknlp
from sparknlp.base import *
from sparknlp.annotator import *
from pyspark.ml import Pipeline

# Start a Spark session with Spark NLP loaded
spark = sparknlp.start()

imageAssembler = ImageAssembler() \
    .setInputCol("image") \
    .setOutputCol("image_assembler")

imageClassifier = ViTForImageClassification \
    .pretrained() \
    .setInputCols(["image_assembler"]) \
    .setOutputCol("class")

pipeline = Pipeline().setStages([
    imageAssembler, 
    imageClassifier
])

imageDF = spark.read \
    .format("image") \
    .option("dropInvalid", value=True) \
    .load("path/to/images/folder")

model = pipeline.fit(imageDF)
result = model.transform(imageDF)

result \
  .selectExpr("reverse(split(image.origin, '/'))[0] as image_name", "class.result") \
  .show(truncate=False)

// Scala
import com.johnsnowlabs.nlp.ImageAssembler
import com.johnsnowlabs.nlp.annotators.cv.ViTForImageClassification
import org.apache.spark.ml.Pipeline
import org.apache.spark.sql.functions._

val imageAssembler = new ImageAssembler()
  .setInputCol("image")
  .setOutputCol("image_assembler")

val imageClassifier = ViTForImageClassification
  .pretrained()
  .setInputCols("image_assembler")
  .setOutputCol("class")

val pipeline = new Pipeline().setStages(Array(
  imageAssembler,
  imageClassifier
))

val imageDF = spark.read
  .format("image")
  .option("dropInvalid", true)
  .load("path/to/images/folder")

val model = pipeline.fit(imageDF)
val result = model.transform(imageDF)

result
  .selectExpr("reverse(split(image.origin, '/'))[0] as image_name", "class.result")
  .show(false)

+-----------------+----------------------------------------------------------+
|image_name       |result                                                    |
+-----------------+----------------------------------------------------------+
|palace.JPEG      |[palace]                                                  |
|egyptian_cat.jpeg|[Egyptian cat]                                            |
|hippopotamus.JPEG|[hippopotamus, hippo, river horse, Hippopotamus amphibius]|
|hen.JPEG         |[hen]                                                     |
|ostrich.JPEG     |[ostrich, Struthio camelus]                               |
|junco.JPEG       |[junco, snowbird]                                         |
|bluetick.jpg     |[bluetick]                                                |
|chihuahua.jpg    |[Chihuahua]                                               |
|tractor.JPEG     |[tractor]                                                 |
|ox.JPEG          |[ox]                                                      |
+-----------------+----------------------------------------------------------+
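The `image_name` column in this output comes from the expression `reverse(split(image.origin, '/'))[0]`, which splits the file path on `/`, reverses the pieces, and keeps the first one, that is, the file name. The plain-Python sketch below mirrors that logic. It also includes a small helper that keeps only the first synonym of a label, assuming each prediction is a single comma-separated string such as "junco, snowbird". Both function names and the sample path are hypothetical.

```python
def image_name(origin):
    """Mirror reverse(split(image.origin, '/'))[0]: keep the last path segment."""
    segments = origin.split("/")
    return segments[::-1][0]

def primary_label(label):
    """Keep the first synonym of a comma-separated label (assumed format)."""
    return label.split(",")[0].strip()

# Hypothetical origin path and a label taken from the table above.
print(image_name("file:///data/images/hen.JPEG"))
print(primary_label("junco, snowbird"))
```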

Try Real-Time Demos!

If you want to explore real-time image classification outputs, visit our interactive demos:

Useful Resources

To dive deeper into image classification using Spark NLP, check out these useful resources:

Notebooks
