sparknlp.annotator.embeddings.late_chunk_embeddings
Contains classes for LateChunkEmbeddings.
Module Contents
Classes
LateChunkEmbeddings | Produces contextual chunk-level embeddings using the Late Chunking technique
- class LateChunkEmbeddings
Produces contextual chunk-level embeddings using the Late Chunking technique described in Günther et al. (2024).
Unlike ChunkEmbeddings, which embeds each chunk in isolation, LateChunkEmbeddings expects that the upstream token-embedding stage (e.g. ModernBertEmbeddings or LongformerEmbeddings) has already processed the full document in a single forward pass, producing contextual token representations. This annotator then locates the tokens that fall within each chunk’s character span and pools them (mean pooling by default) into a single SENTENCE_EMBEDDINGS vector, so every chunk embedding is informed by the complete document context.
By default, token selection is sentence-aware: a token embedding is selected only if it falls inside the chunk span and has the same sentence id as the chunk. Set sentenceAwareFiltering to False to use span-only filtering.
Note
LateChunkEmbeddings must appear after the token-embedding stage in the pipeline. Placing it before the embedding stage will raise a runtime error.
Note
The contextual benefit is bounded by the upstream model’s maximum sequence length (e.g. 8192 tokens for ModernBertEmbeddings). Documents that exceed this limit are truncated before embedding, which reduces cross-chunk context for tokens near the end of very long documents.
Input Annotation types | Output Annotation type
DOCUMENT, CHUNK, WORD_EMBEDDINGS | SENTENCE_EMBEDDINGS
- Parameters:
- poolingStrategy
Strategy used to aggregate token embeddings within each chunk span, by default AVERAGE. Possible values: AVERAGE, SUM.
- skipOOV
Whether to exclude the default zero vectors of out-of-vocabulary (OOV) tokens from pooling, by default True.
- sentenceAwareFiltering
Whether to restrict pooling to token embeddings from the same sentence as the chunk, by default True.
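The interaction of the three parameters can be sketched in plain Python. This is only an illustration of the behaviour described above, not the annotator's actual implementation: the Token class, the pool_chunk function and the inclusive end offsets are assumptions made for the example.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Token:
    """Hypothetical stand-in for one token-embedding annotation."""
    begin: int           # character offset of the first character
    end: int             # character offset of the last character (assumed inclusive)
    sentence: int        # sentence id assigned upstream
    vector: List[float]  # contextual embedding from the full-document pass


def pool_chunk(tokens, chunk_begin, chunk_end, chunk_sentence,
               strategy="AVERAGE", skip_oov=True, sentence_aware=True):
    """Pool the vectors of the tokens that belong to one chunk."""
    selected = []
    for t in tokens:
        # Span filter: the token must lie entirely inside the chunk span.
        if t.begin < chunk_begin or t.end > chunk_end:
            continue
        # Sentence filter: optional, mirrors sentenceAwareFiltering.
        if sentence_aware and t.sentence != chunk_sentence:
            continue
        # OOV filter: optional, drops all-zero vectors (mirrors skipOOV).
        if skip_oov and not any(t.vector):
            continue
        selected.append(t.vector)
    if not selected:
        return []
    summed = [sum(column) for column in zip(*selected)]
    if strategy == "SUM":
        return summed
    if strategy == "AVERAGE":
        return [x / len(selected) for x in summed]
    raise ValueError(f"unknown pooling strategy: {strategy}")


tokens = [
    Token(0, 3, 0, [1.0, 2.0]),
    Token(5, 8, 0, [3.0, 4.0]),
    Token(10, 13, 0, [0.0, 0.0]),  # zero vector, treated as OOV
    Token(15, 18, 1, [5.0, 6.0]),  # inside the span, but in another sentence
]

default_pool = pool_chunk(tokens, 0, 18, 0)
span_only_pool = pool_chunk(tokens, 0, 18, 0, sentence_aware=False)
sum_pool = pool_chunk(tokens, 0, 18, 0, strategy="SUM")
```

With the defaults, only the first two tokens contribute. Turning off the sentence filter also admits the fourth token, and the zero vector is ignored in both cases because skip_oov is enabled.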
References
Günther et al., Late Chunking: Contextual Chunk Embeddings Using Long-Context Embedding Models, arXiv:2409.04701 (2024).
Examples
>>> import sparknlp
>>> from sparknlp.base import *
>>> from sparknlp.annotator import *
>>> from pyspark.ml import Pipeline
>>> spark = sparknlp.start()
Build a late-chunking retrieval pipeline
>>> documentAssembler = DocumentAssembler() \
...     .setInputCol("text") \
...     .setOutputCol("document")
>>> tokenizer = Tokenizer() \
...     .setInputCols(["document"]) \
...     .setOutputCol("token")
>>> # Embed the whole document in one pass so tokens carry full-document context
>>> tokenEmbeddings = ModernBertEmbeddings.pretrained("modernbert-base", "en") \
...     .setInputCols(["document", "token"]) \
...     .setOutputCol("token_embeddings") \
...     .setMaxSentenceLength(8192)
>>> # Turn the pre-split strings in the "chunks" column into CHUNK annotations
>>> chunker = Doc2Chunk() \
...     .setInputCols(["document"]) \
...     .setChunkCol("chunks") \
...     .setIsArray(True) \
...     .setOutputCol("chunk")
>>> lateChunkEmbeddings = LateChunkEmbeddings() \
...     .setInputCols(["document", "chunk", "token_embeddings"]) \
...     .setOutputCol("late_chunk_embeddings") \
...     .setPoolingStrategy("AVERAGE")
>>> pipeline = Pipeline() \
...     .setStages([
...         documentAssembler,
...         tokenizer,
...         tokenEmbeddings,
...         chunker,
...         lateChunkEmbeddings
...     ])
>>> data = spark.createDataFrame([(
...     "AcmeDrug was prescribed for migraine in March. The patient took two doses.\n\n"
...     "It caused severe nausea the next day, and therapy was stopped.",
...     [
...         "AcmeDrug was prescribed for migraine in March. The patient took two doses.",
...         "It caused severe nausea the next day, and therapy was stopped."
...     ]
... )], ["text", "chunks"])
>>> result = pipeline.fit(data).transform(data)
>>> result.selectExpr("explode(late_chunk_embeddings) as r") \
...     .select("r.annotatorType", "r.result", "r.embeddings") \
...     .show(5, 80)
+-------------------+--------------------------------------------------------------------------+--------------------------------------------------------------------------------+
|      annotatorType|                                                                    result|                                                                      embeddings|
+-------------------+--------------------------------------------------------------------------+--------------------------------------------------------------------------------+
|sentence_embeddings|AcmeDrug was prescribed for migraine in March. The patient took two doses.|[0.050471008, -0.07595207, 0.031268876, 0.15105441, -0.013697156, 0.08131724,...|
|sentence_embeddings|            It caused severe nausea the next day, and therapy was stopped.|[0.0735685, 0.0060829176, 0.12051964, 0.22399232, 0.055884164, 0.066795066, 0...|
+-------------------+--------------------------------------------------------------------------+--------------------------------------------------------------------------------+
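The example supplies the chunk strings by hand alongside the full text. As a rough sketch, the same (text, chunks) row could be derived from the text alone, assuming that chunks are paragraphs separated by blank lines. The helper names paragraph_chunks and to_row are hypothetical and not part of Spark NLP.

```python
import re
from typing import List, Tuple


def paragraph_chunks(text: str) -> List[str]:
    """Split on blank lines, keeping each paragraph verbatim apart from outer whitespace."""
    return [part.strip() for part in re.split(r"\n\s*\n", text) if part.strip()]


def to_row(text: str) -> Tuple[str, List[str]]:
    """Build one (text, chunks) tuple in the layout used by the example DataFrame."""
    return (text, paragraph_chunks(text))


text = (
    "AcmeDrug was prescribed for migraine in March. The patient took two doses.\n\n"
    "It caused severe nausea the next day, and therapy was stopped."
)
row = to_row(text)

# Each chunk is an unmodified substring of the document, as in the hand-written example.
all_verbatim = all(chunk in text for chunk in row[1])
```

Rows built this way could be passed to spark.createDataFrame with the column names ["text", "chunks"], as in the example. Other splitting rules, such as fixed-size windows, would work the same way as long as each chunk remains a verbatim substring of the text.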
- setPoolingStrategy(strategy)
Sets the strategy used to aggregate token embeddings within each chunk span.
- Parameters:
- strategy : str
Pooling strategy. One of AVERAGE (default) or SUM.
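The choice between the two strategies matters mainly for how the chunk vectors are compared later. The sketch below uses made-up token vectors and a made-up query vector, and it relies only on arithmetic, not on Spark NLP internals. A SUM vector is the AVERAGE vector scaled by the number of pooled tokens, so cosine similarity is unaffected while dot-product scores grow with chunk length.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def cosine(a, b):
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))


# Illustrative token vectors for one chunk (not real model output).
token_vectors = [[1.0, 0.0], [0.0, 2.0], [3.0, 2.0]]
query = [1.0, 0.5]

summed = [sum(column) for column in zip(*token_vectors)]
averaged = [x / len(token_vectors) for x in summed]

# Same direction, different length: cosine scores match, dot products do not.
same_cosine = math.isclose(cosine(summed, query), cosine(averaged, query))
dot_ratio = dot(summed, query) / dot(averaged, query)
```

Here dot_ratio equals the number of pooled tokens, so AVERAGE is the safer choice when chunks of different lengths are ranked by dot product.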