Text Classification

Text classification is a natural language processing task where entire pieces of text, such as sentences, paragraphs, or documents, are assigned predefined labels. Common subtasks include sentiment analysis and topic classification. For instance, sentiment analysis models can determine whether a review is positive, negative, or neutral, while topic classification models can categorize news articles into areas like politics, sports, or technology.
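To make the task concrete, here is a toy rule-based sentiment classifier. This is only an illustration of what "assigning a label to a piece of text" means; it is not how transformer-based models in Spark NLP work, and the cue-word lists are invented for the example.

```python
# Toy illustration of sentiment classification: count positive and
# negative cue words and pick a label. Purely didactic, not a real model.
POSITIVE = {"loved", "great", "excellent"}
NEGATIVE = {"boring", "bad", "terrible"}

def toy_sentiment(text: str) -> str:
    """Label a sentence pos/neg/neutral by counting cue words."""
    words = {w.strip(".,!?").lower() for w in text.split()}
    pos, neg = len(words & POSITIVE), len(words & NEGATIVE)
    if pos > neg:
        return "pos"
    if neg > pos:
        return "neg"
    return "neutral"

print(toy_sentiment("I loved this movie when I was a child."))  # pos
print(toy_sentiment("It was pretty boring."))                   # neg
```

Real models replace the hand-written word lists with representations learned from labeled data, but the input/output contract is the same: text in, label out.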

Picking a Model

When picking a model for text classification, start from the task: detecting spam (binary), analyzing sentiment across positive, neutral, and negative classes (multiclass), or tagging a news article with several topics at once (multi-label). If you only have a small dataset, lightweight options such as DistilBERT are quick to train and deploy, while larger transformers like BERT or RoBERTa generally give better accuracy when fine-tuned on enough data. For specialized fields, domain-trained models such as BioBERT for biomedical research or FinBERT for finance usually outperform general-purpose ones. Finally, keep practical constraints in mind: how much compute you have, whether you need real-time predictions, how important explainability is, and what balance you want between speed, cost, and accuracy.
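The guidance above can be summarized as a small decision helper. This is a hypothetical sketch, not a Spark NLP API: the function name, the dataset-size threshold, and the domain labels are all illustrative assumptions.

```python
# Hypothetical helper encoding the model-selection heuristics above.
# The 10,000-example threshold is illustrative, not an official rule.
def suggest_model(task: str, dataset_size: int, domain: str = "general") -> str:
    if domain == "biomedical":
        return "BioBERT"   # domain-trained models usually win in-domain
    if domain == "finance":
        return "FinBERT"
    if dataset_size < 10_000:
        return "DistilBERT"  # lightweight, quick to train and deploy
    return "RoBERTa"         # larger transformer, needs more data
```

In practice you would also weigh latency, cost, and explainability before committing to a model family.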

To explore and select from a variety of models, visit Spark NLP Models, where you can find models tailored for different tasks and datasets.

If you have specific needs that are not covered by existing models, you can train your own model tailored to your unique requirements. Follow the guidelines provided in the Spark NLP Training Documentation to get started.

How to use

import sparknlp
from sparknlp.base import *
from sparknlp.annotator import *
from pyspark.ml import Pipeline

# Start a Spark session with Spark NLP loaded
spark = sparknlp.start()

# Convert raw text into Spark NLP's document format
documentAssembler = DocumentAssembler() \
    .setInputCol("text") \
    .setOutputCol("document")

# Split documents into tokens
tokenizer = Tokenizer() \
    .setInputCols(["document"]) \
    .setOutputCol("token")

# Load a pretrained BERT model for sequence classification
sequenceClassifier = BertForSequenceClassification.pretrained() \
    .setInputCols(["token", "document"]) \
    .setOutputCol("label") \
    .setCaseSensitive(True)

pipeline = Pipeline().setStages([
    documentAssembler,
    tokenizer,
    sequenceClassifier
])

data = spark.createDataFrame([
    ("I loved this movie when I was a child.",),
    ("It was pretty boring.",)
]).toDF("text")

model = pipeline.fit(data)
result = model.transform(data)

result.select("text", "label.result").show(truncate=False)

import com.johnsnowlabs.nlp.base._
import com.johnsnowlabs.nlp.annotator._
import org.apache.spark.ml.Pipeline
import spark.implicits._

val documentAssembler = new DocumentAssembler()
  .setInputCol("text")
  .setOutputCol("document")

val tokenizer = new Tokenizer()
  .setInputCols("document")
  .setOutputCol("token")

val sequenceClassifier = BertForSequenceClassification.pretrained()
  .setInputCols("token", "document")
  .setOutputCol("label")
  .setCaseSensitive(true)

val pipeline = new Pipeline().setStages(Array(
  documentAssembler,
  tokenizer,
  sequenceClassifier
))

val data = Seq("I loved this movie when I was a child.", "It was pretty boring.").toDF("text")

val model = pipeline.fit(data)
val result = model.transform(data)

result.select("text", "label.result").show(truncate = false)

+--------------------------------------+------+
|text                                  |result|
+--------------------------------------+------+
|I loved this movie when I was a child.|[pos] |
|It was pretty boring.                 |[neg] |
+--------------------------------------+------+
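Beyond the predicted label, Spark NLP annotations carry a `metadata` map that typically includes a score per class. The sketch below shows how you might pick the top label from one such annotation once it has been collected to the driver; the field names mirror Spark NLP's Annotation schema (`result`, `metadata`), but the exact metadata keys depend on the model, so treat the `"sentence"` key filtering here as an assumption.

```python
# Sketch: extract the top label and its confidence from a collected
# Spark NLP annotation, represented here as a plain dict. The metadata
# keys ("sentence", class names) are assumptions; inspect your model's
# actual output before relying on them.
def top_label(annotation: dict) -> tuple[str, float]:
    scores = {k: float(v) for k, v in annotation["metadata"].items()
              if k != "sentence"}  # skip non-score entries
    label = max(scores, key=scores.get)
    return label, scores[label]

# Example annotation shaped like a row of result.select("label").collect()
ann = {"result": "pos", "metadata": {"sentence": "0", "pos": "0.98", "neg": "0.02"}}
print(top_label(ann))  # ('pos', 0.98)
```

Reporting the confidence alongside the label is useful when you want to route low-confidence predictions to human review.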

Try Real-Time Demos!

If you want to see the outputs of text classification models in real time, visit our interactive demos:

Useful Resources

Want to dive deeper into text classification with Spark NLP? Here are some curated resources to help you get started and explore further:

Articles and Guides

Notebooks

Training Scripts
