Token Classification

Token classification is a natural language understanding task in which a label is assigned to each token in a text. Common subtasks include Named Entity Recognition (NER) and Part-of-Speech (PoS) tagging. NER models can be trained to detect entities such as dates, people, and locations, while PoS tagging identifies whether a word functions as a noun, verb, punctuation mark, or another grammatical category. For instance, in the sentence "Sarah lives in London", an NER model would label "Sarah" as a person and "London" as a location, while the remaining tokens receive the outside label (O).

Picking a Model

When picking a model for token classification, start from the type of task you need, such as Named Entity Recognition (NER) for tagging names of people, places, or organizations, Part-of-Speech (PoS) tagging for grammatical structure, or slot filling in chatbots. For small or less complex datasets, lighter models like DistilBERT or pretrained pipelines give fast, practical results. If you have more data or need higher accuracy, larger models like BERT, RoBERTa, or XLM-R are strong baselines, and domain-specialized versions like BioBERT (for biomedical text) or Legal-BERT (for legal text) often perform best in their fields. Keep the trade-offs in mind: smaller models are faster and easier to deploy, while larger transformers offer richer contextual understanding at a higher compute cost.

You can explore and select models for your token classification tasks at Spark NLP Models.
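
For example, to trade some accuracy for speed, you can load a lighter DistilBERT checkpoint instead of the default BERT model. The snippet below is a minimal sketch; the model name used here is one possible identifier, so confirm the exact name on the Models Hub before using it:

# The model name below is an assumption; verify it on the Models Hub.
tokenClassifier = DistilBertForTokenClassification \
    .pretrained("distilbert_base_token_classifier_conll03", "en") \
    .setInputCols(["token", "document"]) \
    .setOutputCol("label")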

If existing models do not meet your requirements, you can train your own custom model using the Spark NLP Training Documentation.
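
As a quick orientation, a custom NER model can be trained on CoNLL-formatted data with NerDLApproach. The sketch below is a minimal outline, assuming a local train.conll file and the default GloVe embeddings; see the Training Documentation for the full workflow:

from sparknlp.training import CoNLL
from sparknlp.annotator import WordEmbeddingsModel, NerDLApproach
from pyspark.ml import Pipeline

# Hypothetical path; replace with your own CoNLL-formatted training file.
trainingData = CoNLL().readDataset(spark, "path/to/train.conll")

# Word embeddings consumed by the NER trainer
embeddings = WordEmbeddingsModel.pretrained("glove_100d") \
    .setInputCols(["sentence", "token"]) \
    .setOutputCol("embeddings")

# Trainable NER annotator; the label column comes from the CoNLL reader
nerTagger = NerDLApproach() \
    .setInputCols(["sentence", "token", "embeddings"]) \
    .setLabelColumn("label") \
    .setOutputCol("ner") \
    .setMaxEpochs(5)

model = Pipeline().setStages([embeddings, nerTagger]).fit(trainingData)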

How to use

from sparknlp.base import *
from sparknlp.annotator import *
from pyspark.ml import Pipeline

# Converts raw text into a `document` annotation
documentAssembler = DocumentAssembler() \
    .setInputCol("text") \
    .setOutputCol("document")

# Splits the document into individual tokens
tokenizer = Tokenizer() \
    .setInputCols(["document"]) \
    .setOutputCol("token")

# Loads the default pretrained BERT model for token classification
tokenClassifier = BertForTokenClassification.pretrained() \
    .setInputCols(["token", "document"]) \
    .setOutputCol("label") \
    .setCaseSensitive(True)

pipeline = Pipeline().setStages([
    documentAssembler,
    tokenizer,
    tokenClassifier
])

data = spark.createDataFrame([["John Lennon was born in London and lived in Paris. My name is Sarah and I live in London"]]).toDF("text")

model = pipeline.fit(data)
result = model.transform(data)

result.select("label.result").show(truncate=False)

import spark.implicits._
import com.johnsnowlabs.nlp.base._
import com.johnsnowlabs.nlp.annotator._
import org.apache.spark.ml.Pipeline

// Converts raw text into a `document` annotation
val documentAssembler = new DocumentAssembler()
  .setInputCol("text")
  .setOutputCol("document")

// Splits the document into individual tokens
val tokenizer = new Tokenizer()
  .setInputCols("document")
  .setOutputCol("token")

// Loads the default pretrained BERT model for token classification
val tokenClassifier = BertForTokenClassification.pretrained()
  .setInputCols("token", "document")
  .setOutputCol("label")
  .setCaseSensitive(true)

val pipeline = new Pipeline().setStages(Array(
  documentAssembler,
  tokenizer,
  tokenClassifier
))

val data = Seq("John Lennon was born in London and lived in Paris. My name is Sarah and I live in London").toDF("text")

val model = pipeline.fit(data)
val result = model.transform(data)

result.select("label.result").show(false)

+------------------------------------------------------------------------------------+
|result                                                                              |
+------------------------------------------------------------------------------------+
|[B-PER, I-PER, O, O, O, B-LOC, O, O, O, B-LOC, O, O, O, O, B-PER, O, O, O, O, B-LOC]|
+------------------------------------------------------------------------------------+
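
In the output, labels follow the BIO scheme: B- marks the first token of an entity, I- marks a continuation of it, and O marks tokens outside any entity, which is why John Lennon is tagged B-PER, I-PER. To view each token next to its predicted label, you can zip the two output arrays; this is a minimal PySpark sketch assuming the column names used above:

from pyspark.sql import functions as F

# Pair each token with its predicted label for easier inspection
result.select(
    F.explode(F.arrays_zip(result.token.result, result.label.result)).alias("cols")
).select(
    F.expr("cols['0']").alias("token"),
    F.expr("cols['1']").alias("label")
).show(truncate=False)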

Try Real-Time Demos!

If you want to see the outputs of token classification models in real time, visit our interactive demos:

Useful Resources

Want to dive deeper into token classification with Spark NLP? Here are some curated resources to help you get started and explore further:

Articles and Guides

Notebooks

Training Scripts
