# Copyright 2017-2024 John Snow Labs
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains classes for LateChunkEmbeddings."""
from sparknlp.common import *
# Names exported by a wildcard import of this module.
__all__ = ["LateChunkEmbeddings"]
class LateChunkEmbeddings(AnnotatorModel):
    """Produces contextual chunk-level embeddings using the **Late Chunking** technique
    described in `Günther et al. (2024) <https://arxiv.org/abs/2409.04701>`__.

    Unlike :class:`.ChunkEmbeddings`, which embeds each chunk in isolation,
    ``LateChunkEmbeddings`` expects that the upstream token-embedding stage (e.g.
    :class:`.ModernBertEmbeddings` or :class:`.LongformerEmbeddings`) has already
    processed the **full document** in a single forward pass, producing contextual
    token representations. This annotator then locates the tokens that fall within
    each chunk's character span and pools them (mean-pooling by default) into a
    single ``SENTENCE_EMBEDDINGS`` vector — so every chunk embedding is informed by
    the complete document context rather than being isolated. By default, token
    selection is sentence-aware: selected token embeddings must fall inside the
    chunk span and have the same sentence id as the chunk. Set
    ``sentenceAwareFiltering`` to ``False`` to use span-only filtering.

    .. note::
        ``LateChunkEmbeddings`` **must** appear **after** the token-embedding stage
        in the pipeline. Placing it before the embedding stage will raise a runtime
        error.

    .. note::
        The contextual benefit is bounded by the upstream model's maximum sequence
        length (e.g. 8 192 tokens for ``ModernBertEmbeddings``). Documents that
        exceed this limit are truncated before embedding, which reduces cross-chunk
        context for tokens near the end of very long documents.

    ====================================== ======================
    Input Annotation types                 Output Annotation type
    ====================================== ======================
    ``DOCUMENT, CHUNK, WORD_EMBEDDINGS``   ``SENTENCE_EMBEDDINGS``
    ====================================== ======================

    Parameters
    ----------
    poolingStrategy
        Strategy to aggregate token embeddings within each chunk span, by default
        ``AVERAGE``.
        Possible values: ``AVERAGE``, ``SUM``
    skipOOV
        Whether to discard default zero-vectors for OOV tokens from the pool,
        by default ``True``.
    sentenceAwareFiltering
        Whether to restrict token embeddings to the same sentence as the chunk
        when pooling, by default ``True``.

    References
    ----------
    Günther et al., *Late Chunking: Contextual Chunk Embeddings Using Long-Context
    Embedding Models*, arXiv:2409.04701 (2024).

    Examples
    --------
    >>> import sparknlp
    >>> from sparknlp.base import *
    >>> from sparknlp.annotator import *
    >>> from pyspark.ml import Pipeline

    Build a late-chunking retrieval pipeline

    >>> documentAssembler = DocumentAssembler() \\
    ...     .setInputCol("text") \\
    ...     .setOutputCol("document")
    >>> tokenizer = Tokenizer() \\
    ...     .setInputCols(["document"]) \\
    ...     .setOutputCol("token")
    >>> tokenEmbeddings = ModernBertEmbeddings.pretrained("modernbert-base", "en") \\
    ...     .setInputCols(["document", "token"]) \\
    ...     .setOutputCol("token_embeddings") \\
    ...     .setMaxSentenceLength(8192)
    >>> chunker = Doc2Chunk() \\
    ...     .setInputCols(["document"]) \\
    ...     .setChunkCol("chunks") \\
    ...     .setIsArray(True) \\
    ...     .setOutputCol("chunk")
    >>> lateChunkEmbeddings = LateChunkEmbeddings() \\
    ...     .setInputCols(["document", "chunk", "token_embeddings"]) \\
    ...     .setOutputCol("late_chunk_embeddings") \\
    ...     .setPoolingStrategy("AVERAGE")
    >>> pipeline = Pipeline() \\
    ...     .setStages([
    ...       documentAssembler,
    ...       tokenizer,
    ...       tokenEmbeddings,
    ...       chunker,
    ...       lateChunkEmbeddings
    ...     ])
    >>> data = spark.createDataFrame([(
    ...     "AcmeDrug was prescribed for migraine in March. The patient took two doses.\\n\\n"
    ...     "It caused severe nausea the next day, and therapy was stopped.",
    ...     [
    ...         "AcmeDrug was prescribed for migraine in March. The patient took two doses.",
    ...         "It caused severe nausea the next day, and therapy was stopped."
    ...     ]
    ... )], ["text", "chunks"])
    >>> result = pipeline.fit(data).transform(data)
    >>> result.selectExpr("explode(late_chunk_embeddings) as r") \\
    ...     .select("r.annotatorType", "r.result", "r.embeddings") \\
    ...     .show(5, 80)
    +-------------------+--------------------------------------------------------------------------+--------------------------------------------------------------------------------+
    |      annotatorType|                                                                    result|                                                                      embeddings|
    +-------------------+--------------------------------------------------------------------------+--------------------------------------------------------------------------------+
    |sentence_embeddings|AcmeDrug was prescribed for migraine in March. The patient took two doses.|[0.050471008, -0.07595207, 0.031268876, 0.15105441, -0.013697156, 0.08131724,...|
    |sentence_embeddings|            It caused severe nausea the next day, and therapy was stopped.|[0.0735685, 0.0060829176, 0.12051964, 0.22399232, 0.055884164, 0.066795066, 0...|
    +-------------------+--------------------------------------------------------------------------+--------------------------------------------------------------------------------+
    """

    # Identifier of this annotator.
    name = "LateChunkEmbeddings"

    # Annotation type emitted by this annotator.
    outputAnnotatorType = AnnotatorType.SENTENCE_EMBEDDINGS

    @keyword_only
    def __init__(self):
        # The fully-qualified class name is handed to the parent constructor,
        # which (presumably) binds this wrapper to its JVM counterpart.
        super(LateChunkEmbeddings, self).__init__(
            classname="com.johnsnowlabs.nlp.embeddings.LateChunkEmbeddings"
        )
        # Defaults mirror the values documented in the class docstring.
        self._setDefault(
            poolingStrategy="AVERAGE", skipOOV=True, sentenceAwareFiltering=True
        )

    poolingStrategy = Param(
        Params._dummy(),
        "poolingStrategy",
        "Strategy to aggregate token embeddings into a chunk embedding: AVERAGE or SUM",
        typeConverter=TypeConverters.toString,
    )

    skipOOV = Param(
        Params._dummy(),
        "skipOOV",
        "Whether to discard default vectors for OOV words from the aggregation / pooling",
        typeConverter=TypeConverters.toBoolean,
    )

    sentenceAwareFiltering = Param(
        Params._dummy(),
        "sentenceAwareFiltering",
        "Whether to restrict token embeddings to the same sentence as the chunk when pooling",
        typeConverter=TypeConverters.toBoolean,
    )

    def setPoolingStrategy(self, strategy):
        """Sets the strategy used to aggregate token embeddings within each chunk span.

        The value is matched case-insensitively, so ``"sum"`` and ``"SUM"`` are
        equivalent. Any value that is not a recognised strategy (including
        non-string values) falls back to ``AVERAGE`` instead of raising.

        Parameters
        ----------
        strategy : str
            Pooling strategy. One of ``AVERAGE`` (default) or ``SUM``.

        Returns
        -------
        The result of ``self._set(...)``, so calls can be chained as in the
        class-level example.
        """
        # Normalise case first: without this, a lower-case "sum" would not match
        # and would silently be stored as AVERAGE by the fallback below.
        normalized = strategy.upper() if isinstance(strategy, str) else strategy
        if normalized not in ("AVERAGE", "SUM"):
            # Best-effort fallback kept for backward compatibility.
            normalized = "AVERAGE"
        return self._set(poolingStrategy=normalized)

    def setSkipOOV(self, value):
        """Sets whether to discard default zero-vectors for OOV tokens during pooling.

        Parameters
        ----------
        value : bool
            If ``True`` (default), OOV zero-vectors are excluded from the pool so that
            they do not dilute the chunk embedding.

        Returns
        -------
        The result of ``self._set(...)``, so calls can be chained.
        """
        return self._set(skipOOV=value)

    def setSentenceAwareFiltering(self, value):
        """Sets whether token filtering should also require matching sentence id.

        Parameters
        ----------
        value : bool
            If ``True`` (default), selected token embeddings must fall within the
            chunk span and have the same sentence id as the chunk. If ``False``,
            filtering is span-only.

        Returns
        -------
        The result of ``self._set(...)``, so calls can be chained.
        """
        return self._set(sentenceAwareFiltering=value)