sparknlp.annotator.embeddings.late_chunk_embeddings

Contains classes for LateChunkEmbeddings.

Module Contents

Classes

LateChunkEmbeddings

Produces contextual chunk-level embeddings using the Late Chunking technique.

class LateChunkEmbeddings

Produces contextual chunk-level embeddings using the Late Chunking technique described in Jin et al. (2024).

Unlike ChunkEmbeddings, which embeds each chunk in isolation, LateChunkEmbeddings expects the upstream token-embedding stage (e.g. ModernBertEmbeddings or LongformerEmbeddings) to have already processed the full document in a single forward pass, producing contextual token representations. This annotator then locates the tokens that fall within each chunk’s character span and pools them (by averaging, by default) into a single SENTENCE_EMBEDDINGS vector. Because the token representations were computed over the whole document, every chunk embedding is informed by the complete document context.

By default, token selection is sentence-aware: a token embedding is selected only if it falls inside the chunk span and has the same sentence id as the chunk. Set sentenceAwareFiltering to False to use span-only filtering.
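The selection-and-pooling step can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the annotator's Spark implementation: `Span` and `late_chunk_embedding` are hypothetical names, offsets are treated as inclusive character positions (as in Spark NLP annotations), and the toy vectors stand in for contextual token embeddings that the upstream model would compute over the whole document. Returning an empty list for a chunk with no matching tokens is also an assumption of this sketch.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Span:
    """Hypothetical token or chunk with inclusive character offsets."""
    begin: int
    end: int
    sentence: int
    vector: Optional[List[float]] = None  # only tokens carry a vector


def late_chunk_embedding(tokens: List[Span], chunk: Span,
                         sentence_aware: bool = True) -> List[float]:
    """Mean-pool the contextual token vectors that fall inside the chunk span."""
    selected = [
        t.vector for t in tokens
        if t.begin >= chunk.begin and t.end <= chunk.end
        and (not sentence_aware or t.sentence == chunk.sentence)
    ]
    if not selected:
        return []  # assumption: a chunk with no tokens yields an empty vector
    dim = len(selected[0])
    return [sum(v[i] for v in selected) / len(selected) for i in range(dim)]


# Text: "Drug caused nausea". Toy vectors; in the real pipeline they come
# from a single forward pass over the entire document.
tokens = [
    Span(0, 3, 0, [1.0, 0.0]),    # "Drug"
    Span(5, 10, 0, [0.0, 1.0]),   # "caused"
    Span(12, 17, 0, [1.0, 1.0]),  # "nausea"
]
chunk = Span(5, 17, 0)  # "caused nausea"
print(late_chunk_embedding(tokens, chunk))  # [0.5, 1.0]
```

The point of the technique is visible in the inputs rather than the pooling: the pooled vectors were produced with the whole document in view, so "caused nausea" still carries information about "Drug" even though that token lies outside the chunk.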

Note

LateChunkEmbeddings must appear after the token-embedding stage in the pipeline. Placing it before the embedding stage will raise a runtime error.

Note

The contextual benefit is bounded by the upstream model’s maximum sequence length (e.g. 8,192 tokens for ModernBertEmbeddings). Documents that exceed this limit are truncated before embedding, which reduces cross-chunk context for tokens near the end of very long documents.
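As a rough illustration of this bound, the hypothetical helper below estimates what fraction of each chunk's tokens fits inside the model's window. It assumes contiguous, non-overlapping chunks listed in document order, counts measured in the upstream model's own tokens (which can differ from the words produced by Tokenizer), and it ignores any special tokens the model may add. It does not model how the annotator treats chunks that lie past the cut.

```python
from typing import List


def chunk_coverage(chunk_token_counts: List[int], max_len: int = 8192) -> List[float]:
    """Fraction of each chunk's tokens that fall inside the first max_len positions.

    Hypothetical helper: chunks are assumed contiguous, non-overlapping and in
    document order, with counts in the upstream model's tokens.
    """
    coverage = []
    start = 0
    for count in chunk_token_counts:
        # Tokens of this chunk that land before the truncation point (never negative).
        inside = max(0, min(start + count, max_len) - start)
        coverage.append(inside / count if count else 1.0)
        start += count
    return coverage


# Three chunks of 4,000 tokens each: only 192 tokens of the last chunk fit.
print(chunk_coverage([4000, 4000, 4000]))
```

A check like this can help decide whether a long document should be split into several rows before running the pipeline, at the cost of losing context across the split.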

Input Annotation types: DOCUMENT, CHUNK, WORD_EMBEDDINGS

Output Annotation type: SENTENCE_EMBEDDINGS

Parameters:
poolingStrategy

Strategy used to aggregate token embeddings within each chunk span, by default AVERAGE. Possible values: AVERAGE, SUM.

skipOOV

Whether to exclude the default zero vectors of out-of-vocabulary (OOV) tokens from the pool, by default True.

sentenceAwareFiltering

Whether to restrict token embeddings to the same sentence as the chunk when pooling, by default True.
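The two pooling strategies differ only in a final division. The sketch below is a plain-Python illustration, not the annotator's code; `pool` is a hypothetical name, and it assumes all token vectors have the same dimension.

```python
from typing import List


def pool(vectors: List[List[float]], strategy: str = "AVERAGE") -> List[float]:
    """Aggregate equally sized vectors element-wise with AVERAGE or SUM."""
    if strategy not in ("AVERAGE", "SUM"):
        raise ValueError(f"unknown pooling strategy: {strategy}")
    if not vectors:
        return []
    dim = len(vectors[0])
    totals = [sum(v[i] for v in vectors) for i in range(dim)]
    if strategy == "SUM":
        return totals
    return [t / len(vectors) for t in totals]


vectors = [[1.0, 2.0], [3.0, 4.0]]
print(pool(vectors, "AVERAGE"))  # [2.0, 3.0]
print(pool(vectors, "SUM"))      # [4.0, 6.0]
```

Because SUM is the AVERAGE vector scaled by the number of pooled tokens, both point in the same direction and give identical cosine similarities; they differ only for magnitude-sensitive measures such as dot product or Euclidean distance, where SUM favors longer chunks.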

References

Jin et al., Late Chunking: Contextual Chunk Embeddings Using Long-Context Embedding Models, arXiv:2409.04701 (2024).

Examples

>>> import sparknlp
>>> from sparknlp.base import *
>>> from sparknlp.annotator import *
>>> from pyspark.ml import Pipeline
>>> spark = sparknlp.start()

Build a late-chunking retrieval pipeline

>>> documentAssembler = DocumentAssembler() \
...     .setInputCol("text") \
...     .setOutputCol("document")
>>> tokenizer = Tokenizer() \
...     .setInputCols(["document"]) \
...     .setOutputCol("token")
>>> tokenEmbeddings = ModernBertEmbeddings.pretrained("modernbert-base", "en") \
...     .setInputCols(["document", "token"]) \
...     .setOutputCol("token_embeddings") \
...     .setMaxSentenceLength(8192)
>>> chunker = Doc2Chunk() \
...     .setInputCols(["document"]) \
...     .setChunkCol("chunks") \
...     .setIsArray(True) \
...     .setOutputCol("chunk")
>>> lateChunkEmbeddings = LateChunkEmbeddings() \
...     .setInputCols(["document", "chunk", "token_embeddings"]) \
...     .setOutputCol("late_chunk_embeddings") \
...     .setPoolingStrategy("AVERAGE")
>>> pipeline = Pipeline() \
...     .setStages([
...       documentAssembler,
...       tokenizer,
...       tokenEmbeddings,
...       chunker,
...       lateChunkEmbeddings
...     ])
>>> data = spark.createDataFrame([(
...     "AcmeDrug was prescribed for migraine in March. The patient took two doses.\n\n"
...     "It caused severe nausea the next day, and therapy was stopped.",
...     [
...         "AcmeDrug was prescribed for migraine in March. The patient took two doses.",
...         "It caused severe nausea the next day, and therapy was stopped."
...     ]
... )], ["text", "chunks"])
>>> result = pipeline.fit(data).transform(data)
>>> result.selectExpr("explode(late_chunk_embeddings) as r") \
...     .select("r.annotatorType", "r.result", "r.embeddings") \
...     .show(5, 80)
+-------------------+--------------------------------------------------------------------------+--------------------------------------------------------------------------------+
|      annotatorType|                                                                    result|                                                                      embeddings|
+-------------------+--------------------------------------------------------------------------+--------------------------------------------------------------------------------+
|sentence_embeddings|AcmeDrug was prescribed for migraine in March. The patient took two doses.|[0.050471008, -0.07595207, 0.031268876, 0.15105441, -0.013697156, 0.08131724,...|
|sentence_embeddings|            It caused severe nausea the next day, and therapy was stopped.|[0.0735685, 0.0060829176, 0.12051964, 0.22399232, 0.055884164, 0.066795066, 0...|
+-------------------+--------------------------------------------------------------------------+--------------------------------------------------------------------------------+
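Since the example is framed as a retrieval pipeline, one possible next step is to rank chunks by cosine similarity to a query embedding. The sketch below does this in plain Python on toy vectors; `cosine` and `rank_chunks` are hypothetical helpers, and it assumes the query vector comes from a model compatible with the chunk embeddings (for instance, by embedding the query text with the same upstream model).

```python
import math
from typing import List, Sequence, Tuple


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity; returns 0.0 if either vector has zero norm."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def rank_chunks(query: Sequence[float],
                chunks: List[Tuple[str, Sequence[float]]]) -> List[Tuple[str, float]]:
    """Return (chunk_text, score) pairs sorted from most to least similar."""
    scored = [(text, cosine(query, vec)) for text, vec in chunks]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)


# Toy 3-dimensional vectors; real embeddings have the upstream model's dimension.
chunks = [
    ("chunk A", [1.0, 0.0, 0.0]),
    ("chunk B", [0.0, 1.0, 0.0]),
]
query = [0.2, 0.9, 0.0]
for text, score in rank_chunks(query, chunks):
    print(f"{text}: {score:.3f}")
```

To feed it real data, the pairs can be collected to the driver from the exploded DataFrame in the example, e.g. `[(row.result, row.embeddings) for row in result.selectExpr("explode(late_chunk_embeddings) as r").select("r.result", "r.embeddings").collect()]`. For large corpora a vector index would replace this linear scan.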

name = 'LateChunkEmbeddings'
inputAnnotatorTypes
outputAnnotatorType = 'sentence_embeddings'
poolingStrategy
skipOOV
sentenceAwareFiltering

setPoolingStrategy(strategy)

Sets the strategy used to aggregate token embeddings within each chunk span.

Parameters:
strategy : str

Pooling strategy. One of AVERAGE (default) or SUM.

setSkipOOV(value)

Sets whether to discard the default zero vectors of out-of-vocabulary (OOV) tokens during pooling.

Parameters:
value : bool

If True (default), OOV zero-vectors are excluded from the pool so that they do not dilute the chunk embedding.
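The dilution effect is easy to see with a single OOV token. This plain-Python sketch is illustrative only: `average_skipping_oov` is a hypothetical name, and it treats an all-zero vector as the OOV placeholder, whereas the annotator itself may identify OOV tokens differently (for example, from token metadata).

```python
from typing import List


def average_skipping_oov(vectors: List[List[float]], skip_oov: bool = True) -> List[float]:
    """Average token vectors, optionally dropping all-zero (OOV placeholder) vectors."""
    if skip_oov:
        vectors = [v for v in vectors if any(x != 0.0 for x in v)]
    if not vectors:
        return []
    dim = len(vectors[0])
    return [sum(v[i] for v in vectors) / len(vectors) for i in range(dim)]


tokens = [[2.0, 4.0], [0.0, 0.0]]  # second token stands in for an OOV zero vector
print(average_skipping_oov(tokens, skip_oov=True))   # [2.0, 4.0]
print(average_skipping_oov(tokens, skip_oov=False))  # [1.0, 2.0]
```

With skipping disabled, the zero vector still counts in the denominator, so the average shrinks toward the origin by the share of OOV tokens in the chunk.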

setSentenceAwareFiltering(value)

Sets whether token filtering should also require the token’s sentence id to match the chunk’s.

Parameters:
value : bool

If True (default), selected token embeddings must fall within the chunk span and have the same sentence id as the chunk. If False, filtering is span-only.
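The difference between the two modes shows up when a chunk crosses a sentence boundary. In the plain-Python sketch below, `Token` and `select_tokens` are hypothetical names, offsets are inclusive, and the chunk covering both sentences is assumed to carry sentence id 0; how the annotator assigns a sentence id to such a chunk is not specified here.

```python
from typing import List, NamedTuple


class Token(NamedTuple):
    text: str
    begin: int     # inclusive character offset
    end: int       # inclusive character offset
    sentence: int  # sentence id


def select_tokens(tokens: List[Token], chunk_begin: int, chunk_end: int,
                  chunk_sentence: int, sentence_aware: bool = True) -> List[str]:
    """Return the texts of the tokens that would be pooled for one chunk."""
    return [
        t.text for t in tokens
        if chunk_begin <= t.begin and t.end <= chunk_end
        and (not sentence_aware or t.sentence == chunk_sentence)
    ]


# "It hurt. Stop." as two sentences; the chunk spans both (offsets 0-13).
tokens = [
    Token("It", 0, 1, 0), Token("hurt", 3, 6, 0), Token(".", 7, 7, 0),
    Token("Stop", 9, 12, 1), Token(".", 13, 13, 1),
]
print(select_tokens(tokens, 0, 13, 0, sentence_aware=True))   # ['It', 'hurt', '.']
print(select_tokens(tokens, 0, 13, 0, sentence_aware=False))  # ['It', 'hurt', '.', 'Stop', '.']
```

Span-only filtering is the natural choice when chunks are deliberately built to span several sentences; otherwise the default keeps pooling confined to the chunk's own sentence.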