# Copyright 2017-2022 John Snow Labs
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains classes for the ContextSpellChecker."""
from sparknlp.common import *
class ContextSpellCheckerApproach(AnnotatorApproach):
"""Trains a deep-learning based Noisy Channel Model Spell Algorithm.
Correction candidates are extracted combining context information and word
information.
For instantiated/pretrained models, see :class:`.ContextSpellCheckerModel`.
Spell Checking is a sequence to sequence mapping problem. Given an input
sequence, potentially containing a certain number of errors,
``ContextSpellChecker`` will rank correction sequences according to three
things:
#. Different correction candidates for each word — **word level**.
#. The surrounding text of each word, i.e. its context —
**sentence level**.
#. The relative cost of different correction candidates according to the
edit operations at the character level they require — **subword level**.
For extended examples of usage, see the article
`Training a Contextual Spell Checker for Italian Language <https://towardsdatascience.com/training-a-contextual-spell-checker-for-italian-language-66dda528e4bf>`__
and the `Examples <https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/training/italian/Training_Context_Spell_Checker_Italian.ipynb>`__.
====================== ======================
Input Annotation types Output Annotation type
====================== ======================
``TOKEN`` ``TOKEN``
====================== ======================
Parameters
----------
languageModelClasses
Number of classes to use during factorization of the softmax output in
the LM.
wordMaxDistance
Maximum distance for the generated candidates for every word.
maxCandidates
Maximum number of candidates for every word.
caseStrategy
What case combinations to try when generating candidates, by default 2.
Possible values are:
- 0: All uppercase letters
- 1: First letter capitalized
- 2: All letters
errorThreshold
Threshold perplexity for a word to be considered as an error.
epochs
Number of epochs to train the language model.
batchSize
Batch size for training the language model.
initialRate
Initial learning rate for the LM.
finalRate
Final learning rate for the LM.
validationFraction
Percentage of datapoints to use for validation.
minCount
Min number of times a token should appear to be included in vocab.
compoundCount
Min number of times a compound word should appear to be included in
vocab.
classCount
Min number of times a word needs to appear in the corpus to not be
considered part of a special class.
tradeoff
Tradeoff between the cost of a word error and a transition in the
language model.
weightedDistPath
The path to the file containing the weights for the Levenshtein
distance.
maxWindowLen
Maximum size for the window used to remember history prior to every
correction.
configProtoBytes
ConfigProto from tensorflow, serialized into byte array.
maxSentLen
Maximum length for a sentence - internal use during training.
graphFolder
Folder path that contains the external graph files.
References
----------
For an in-depth explanation of the module see the article
`Applying Context Aware Spell Checking in Spark NLP <https://medium.com/spark-nlp/applying-context-aware-spell-checking-in-spark-nlp-3c29c46963bc>`__.
Examples
--------
>>> import sparknlp
>>> from sparknlp.base import *
>>> from sparknlp.annotator import *
>>> from pyspark.ml import Pipeline
For this example, we use the first Sherlock Holmes book as the training dataset.
>>> documentAssembler = DocumentAssembler() \\
... .setInputCol("text") \\
... .setOutputCol("document")
>>> tokenizer = Tokenizer() \\
... .setInputCols("document") \\
... .setOutputCol("token")
>>> spellChecker = ContextSpellCheckerApproach() \\
... .setInputCols("token") \\
... .setOutputCol("corrected") \\
... .setWordMaxDistance(3) \\
... .setBatchSize(24) \\
... .setEpochs(8) \\
... .setLanguageModelClasses(1650) # dependent on vocabulary size
... # .addVocabClass("_NAME_", names) # Extra classes for correction could be added like this
>>> pipeline = Pipeline().setStages([
... documentAssembler,
... tokenizer,
... spellChecker
... ])
>>> path = "sherlockholmes.txt"
>>> dataset = spark.read.text(path) \\
... .toDF("text")
>>> pipelineModel = pipeline.fit(dataset)
See Also
--------
NorvigSweetingApproach, SymmetricDeleteApproach : For alternative approaches to spell checking
"""
name = "ContextSpellCheckerApproach"
inputAnnotatorTypes = [AnnotatorType.TOKEN]
outputAnnotatorType = AnnotatorType.TOKEN
languageModelClasses = Param(Params._dummy(),
"languageModelClasses",
"Number of classes to use during factorization of the softmax output in the LM.",
typeConverter=TypeConverters.toInt)
wordMaxDistance = Param(Params._dummy(),
"wordMaxDistance",
"Maximum distance for the generated candidates for every word.",
typeConverter=TypeConverters.toInt)
maxCandidates = Param(Params._dummy(),
"maxCandidates",
"Maximum number of candidates for every word.",
typeConverter=TypeConverters.toInt)
caseStrategy = Param(Params._dummy(),
"caseStrategy",
"What case combinations to try when generating candidates.",
typeConverter=TypeConverters.toInt)
errorThreshold = Param(Params._dummy(),
"errorThreshold",
"Threshold perplexity for a word to be considered as an error.",
typeConverter=TypeConverters.toFloat)
epochs = Param(Params._dummy(),
"epochs",
"Number of epochs to train the language model.",
typeConverter=TypeConverters.toInt)
batchSize = Param(Params._dummy(),
"batchSize",
"Batch size for the training in NLM.",
typeConverter=TypeConverters.toInt)
initialRate = Param(Params._dummy(),
"initialRate",
"Initial learning rate for the LM.",
typeConverter=TypeConverters.toFloat)
finalRate = Param(Params._dummy(),
"finalRate",
"Final learning rate for the LM.",
typeConverter=TypeConverters.toFloat)
validationFraction = Param(Params._dummy(),
"validationFraction",
"Percentage of datapoints to use for validation.",
typeConverter=TypeConverters.toFloat)
minCount = Param(Params._dummy(),
"minCount",
"Min number of times a token should appear to be included in vocab.",
typeConverter=TypeConverters.toFloat)
compoundCount = Param(Params._dummy(),
"compoundCount",
"Min number of times a compound word should appear to be included in vocab.",
typeConverter=TypeConverters.toInt)
classCount = Param(Params._dummy(),
"classCount",
"Min number of times the word need to appear in corpus to not be considered of a special class.",
typeConverter=TypeConverters.toFloat)
tradeoff = Param(Params._dummy(),
"tradeoff",
"Tradeoff between the cost of a word error and a transition in the language model.",
typeConverter=TypeConverters.toFloat)
weightedDistPath = Param(Params._dummy(),
"weightedDistPath",
"The path to the file containing the weights for the levenshtein distance.",
typeConverter=TypeConverters.toString)
maxWindowLen = Param(Params._dummy(),
"maxWindowLen",
"Maximum size for the window used to remember history prior to every correction.",
typeConverter=TypeConverters.toInt)
configProtoBytes = Param(Params._dummy(), "configProtoBytes",
"ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()",
TypeConverters.toListInt)
maxSentLen = Param(Params._dummy(),
"maxSentLen",
"Maximum length of a sentence to be considered for training.",
typeConverter=TypeConverters.toInt)
graphFolder = Param(Params._dummy(),
"graphFolder",
"Folder path that contain external graph files.",
typeConverter=TypeConverters.toString)
def setLanguageModelClasses(self, count):
"""Sets number of classes to use during factorization of the softmax
output in the Language Model.
Parameters
----------
count : int
Number of classes
"""
return self._set(languageModelClasses=count)
def setWordMaxDistance(self, dist):
"""Sets maximum distance for the generated candidates for every word.
Parameters
----------
dist : int
Maximum distance for the generated candidates for every word
"""
return self._set(wordMaxDistance=dist)
def setMaxCandidates(self, candidates):
"""Sets maximum number of candidates for every word.
Parameters
----------
candidates : int
Maximum number of candidates for every word.
"""
return self._set(maxCandidates=candidates)
def setCaseStrategy(self, strategy):
"""Sets what case combinations to try when generating candidates.
Possible values are:
- 0: All uppercase letters
- 1: First letter capitalized
- 2: All letters
Parameters
----------
strategy : int
Case combinations to try when generating candidates
"""
return self._set(caseStrategy=strategy)
def setErrorThreshold(self, threshold):
"""Sets threshold perplexity for a word to be considered as an error.
Parameters
----------
threshold : float
Threshold perplexity for a word to be considered as an error
"""
return self._set(errorThreshold=threshold)
def setEpochs(self, count):
"""Sets number of epochs to train the language model.
Parameters
----------
count : int
Number of epochs
"""
return self._set(epochs=count)
def setBatchSize(self, size):
"""Sets batch size.
Parameters
----------
size : int
Batch size
"""
return self._set(batchSize=size)
def setInitialRate(self, rate):
"""Sets initial learning rate for the LM.
Parameters
----------
rate : float
Initial learning rate for the LM
"""
return self._set(initialRate=rate)
def setFinalRate(self, rate):
"""Sets final learning rate for the LM.
Parameters
----------
rate : float
Final learning rate for the LM
"""
return self._set(finalRate=rate)
def setValidationFraction(self, fraction):
"""Sets percentage of datapoints to use for validation.
Parameters
----------
fraction : float
Percentage of datapoints to use for validation
"""
return self._set(validationFraction=fraction)
def setMinCount(self, count):
"""Sets min number of times a token should appear to be included in
vocab.
Parameters
----------
count : float
Min number of times a token should appear to be included in vocab
"""
return self._set(minCount=count)
def setCompoundCount(self, count):
"""Sets min number of times a compound word should appear to be included
in vocab.
Parameters
----------
count : int
Min number of times a compound word should appear to be included in
vocab.
"""
return self._set(compoundCount=count)
def setClassCount(self, count):
"""Sets min number of times a word needs to appear in the corpus to not
be considered part of a special class.
Parameters
----------
count : float
Min number of times a word needs to appear in the corpus to not be
considered part of a special class.
"""
return self._set(classCount=count)
def setTradeoff(self, alpha):
"""Sets tradeoff between the cost of a word error and a transition in
the language model.
Parameters
----------
alpha : float
Tradeoff between the cost of a word error and a transition in the
language model
"""
return self._set(tradeoff=alpha)
def setWeightedDistPath(self, path):
"""Sets the path to the file containing the weights for the Levenshtein
distance.
Parameters
----------
path : str
Path to the file containing the weights for the Levenshtein
distance.
"""
return self._set(weightedDistPath=path)
def setMaxWindowLen(self, length):
"""Sets the maximum size for the window used to remember history prior
to every correction.
Parameters
----------
length : int
Maximum size for the window used to remember history prior to
every correction
"""
return self._set(maxWindowLen=length)
def setConfigProtoBytes(self, b):
"""Sets configProto from tensorflow, serialized into byte array.
Parameters
----------
b : List[int]
ConfigProto from tensorflow, serialized into byte array
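As an illustrative sketch only (assuming the ``tensorflow`` package is
installed; it is not bundled with this module), a serialized config can
be produced and passed like this:
>>> import tensorflow as tf
>>> config_proto = tf.compat.v1.ConfigProto(allow_soft_placement=True)
>>> spellChecker = ContextSpellCheckerApproach() \\
... .setInputCols("token") \\
... .setOutputCol("corrected") \\
... .setConfigProtoBytes(list(config_proto.SerializeToString()))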
"""
return self._set(configProtoBytes=b)
def setGraphFolder(self, path):
"""Sets folder path that contains the external graph files.
Parameters
----------
path : str
Folder path that contains the external graph files.
"""
return self._set(graphFolder=path)
def setMaxSentLen(self, sentlen):
"""Sets the maximum length of a sentence.
Parameters
----------
sentlen : int
Maximum length of a sentence
"""
return self._set(maxSentLen=sentlen)
def addVocabClass(self, label, vocab, userdist=3):
"""Adds a new class of words to correct, based on a vocabulary.
Parameters
----------
label : str
Name of the class
vocab : List[str]
Vocabulary as a list
userdist : int, optional
Maximal distance to the word, by default 3
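For example, a small class of person names (mirroring the commented-out
``_NAME_`` class in the class-level example; the names themselves are
purely illustrative) could be added like this:
>>> spellChecker = ContextSpellCheckerApproach() \\
... .setInputCols("token") \\
... .setOutputCol("corrected") \\
... .addVocabClass("_NAME_", ["Sherlock", "Mycroft", "Irene"])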
"""
self._call_java('addVocabClass', label, vocab, userdist)
return self
def addRegexClass(self, label, regex, userdist=3):
"""Adds a new class of words to correct, based on regex.
Parameters
----------
label : str
Name of the class
regex : str
Regex to add
userdist : int, optional
Maximal distance to the word, by default 3
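For example, a class of date-like tokens could be registered with a
pattern (the label and regex here are purely illustrative):
>>> spellChecker = ContextSpellCheckerApproach() \\
... .setInputCols("token") \\
... .setOutputCol("corrected") \\
... .addRegexClass("_DATE_", "([0-9]{2})/([0-9]{2})/([0-9]{4})")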
"""
self._call_java('addRegexClass', label, regex, userdist)
return self
@keyword_only
def __init__(self):
super(ContextSpellCheckerApproach, self). \
__init__(classname="com.johnsnowlabs.nlp.annotators.spell.context.ContextSpellCheckerApproach")
def _create_model(self, java_model):
return ContextSpellCheckerModel(java_model=java_model)
class ContextSpellCheckerModel(AnnotatorModel, HasEngine):
"""Implements a deep-learning based Noisy Channel Model Spell Algorithm.
Correction candidates are extracted combining context information and word
information.
Spell Checking is a sequence to sequence mapping problem. Given an input
sequence, potentially containing a certain number of errors,
``ContextSpellChecker`` will rank correction sequences according to three
things:
#. Different correction candidates for each word — **word level**.
#. The surrounding text of each word, i.e. its context —
**sentence level**.
#. The relative cost of different correction candidates according to the
edit operations at the character level they require — **subword level**.
This is the instantiated model of the :class:`.ContextSpellCheckerApproach`.
For training your own model, please see the documentation of that class.
Pretrained models can be loaded with :meth:`.pretrained` of the companion
object:
>>> spellChecker = ContextSpellCheckerModel.pretrained() \\
... .setInputCols(["token"]) \\
... .setOutputCol("checked")
The default model is ``"spellcheck_dl"``, if no name is provided.
For available pretrained models please see the `Models Hub <https://sparknlp.org/models?task=Spell+Check>`__.
For extended examples of usage, see the `Examples <https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/training/italian/Training_Context_Spell_Checker_Italian.ipynb>`__.
====================== ======================
Input Annotation types Output Annotation type
====================== ======================
``TOKEN`` ``TOKEN``
====================== ======================
Parameters
----------
wordMaxDistance
Maximum distance for the generated candidates for every word.
maxCandidates
Maximum number of candidates for every word.
caseStrategy
What case combinations to try when generating candidates.
errorThreshold
Threshold perplexity for a word to be considered as an error.
tradeoff
Tradeoff between the cost of a word error and a transition in the
language model.
maxWindowLen
Maximum size for the window used to remember history prior to every
correction.
gamma
Controls the influence of individual word frequency in the decision.
correctSymbols
Whether to correct special symbols or skip spell checking for them.
compareLowcase
If true, tokens will be compared in lower case with the vocabulary.
configProtoBytes
ConfigProto from tensorflow, serialized into byte array.
vocabFreq
Frequencies of the words in the vocabulary.
idsVocab
Mapping of ids to vocabulary.
vocabIds
Mapping of vocabulary to ids.
classes
Classes the spell checker recognizes.
weights
Levenshtein weights.
useNewLines
When set to true, new lines will be treated as any other character. When set to false, correction is applied on paragraphs as defined by newline characters.
References
----------
For an in-depth explanation of the module see the article `Applying Context
Aware Spell Checking in Spark NLP
<https://medium.com/spark-nlp/applying-context-aware-spell-checking-in-spark-nlp-3c29c46963bc>`__.
Examples
--------
>>> import sparknlp
>>> from sparknlp.base import *
>>> from sparknlp.annotator import *
>>> from pyspark.ml import Pipeline
>>> documentAssembler = DocumentAssembler() \\
... .setInputCol("text") \\
... .setOutputCol("doc")
>>> tokenizer = Tokenizer() \\
... .setInputCols(["doc"]) \\
... .setOutputCol("token")
>>> spellChecker = ContextSpellCheckerModel \\
... .pretrained() \\
... .setTradeoff(12.0) \\
... .setInputCols("token") \\
... .setOutputCol("checked")
>>> pipeline = Pipeline().setStages([
... documentAssembler,
... tokenizer,
... spellChecker
... ])
>>> data = spark.createDataFrame([["It was a cold , dreary day and the country was white with smow ."]]).toDF("text")
>>> result = pipeline.fit(data).transform(data)
>>> result.select("checked.result").show(truncate=False)
+--------------------------------------------------------------------------------+
|result |
+--------------------------------------------------------------------------------+
|[It, was, a, cold, ,, dreary, day, and, the, country, was, white, with, snow, .]|
+--------------------------------------------------------------------------------+
See Also
--------
NorvigSweetingModel, SymmetricDeleteModel : For alternative approaches to spell checking
"""
name = "ContextSpellCheckerModel"
inputAnnotatorTypes = [AnnotatorType.TOKEN]
outputAnnotatorType = AnnotatorType.TOKEN
wordMaxDistance = Param(Params._dummy(),
"wordMaxDistance",
"Maximum distance for the generated candidates for every word.",
typeConverter=TypeConverters.toInt)
maxCandidates = Param(Params._dummy(),
"maxCandidates",
"Maximum number of candidates for every word.",
typeConverter=TypeConverters.toInt)
caseStrategy = Param(Params._dummy(),
"caseStrategy",
"What case combinations to try when generating candidates.",
typeConverter=TypeConverters.toInt)
errorThreshold = Param(Params._dummy(),
"errorThreshold",
"Threshold perplexity for a word to be considered as an error.",
typeConverter=TypeConverters.toFloat)
tradeoff = Param(Params._dummy(),
"tradeoff",
"Tradeoff between the cost of a word error and a transition in the language model.",
typeConverter=TypeConverters.toFloat)
maxWindowLen = Param(Params._dummy(),
"maxWindowLen",
"Maximum size for the window used to remember history prior to every correction.",
typeConverter=TypeConverters.toInt)
gamma = Param(Params._dummy(),
"gamma",
"Controls the influence of individual word frequency in the decision.",
typeConverter=TypeConverters.toFloat)
correctSymbols = Param(Params._dummy(), "correctSymbols",
"Whether to correct special symbols or skip spell checking for them",
typeConverter=TypeConverters.toBoolean)
compareLowcase = Param(Params._dummy(), "compareLowcase", "If true, tokens will be compared in lower case with the vocabulary.",
typeConverter=TypeConverters.toBoolean)
configProtoBytes = Param(Params._dummy(), "configProtoBytes",
"ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()",
TypeConverters.toListInt)
vocabFreq = Param(
Params._dummy(),
"vocabFreq",
"Frequency words from the vocabulary.",
TypeConverters.identity,
)
idsVocab = Param(
Params._dummy(),
"idsVocab",
"Mapping of ids to vocabulary.",
TypeConverters.identity,
)
vocabIds = Param(
Params._dummy(),
"vocabIds",
"Mapping of vocabulary to ids.",
TypeConverters.identity,
)
classes = Param(
Params._dummy(),
"classes",
"Classes the spell checker recognizes.",
TypeConverters.identity,
)
def setWordMaxDistance(self, dist):
"""Sets maximum distance for the generated candidates for every word.
Parameters
----------
dist : int
Maximum distance for the generated candidates for every word.
"""
return self._set(wordMaxDistance=dist)
def setMaxCandidates(self, candidates):
"""Sets maximum number of candidates for every word.
Parameters
----------
candidates : int
Maximum number of candidates for every word.
"""
return self._set(maxCandidates=candidates)
def setCaseStrategy(self, strategy):
"""Sets what case combinations to try when generating candidates.
Parameters
----------
strategy : int
Case combinations to try when generating candidates.
"""
return self._set(caseStrategy=strategy)
def setErrorThreshold(self, threshold):
"""Sets threshold perplexity for a word to be considered as an error.
Parameters
----------
threshold : float
Threshold perplexity for a word to be considered as an error
"""
return self._set(errorThreshold=threshold)
def setTradeoff(self, alpha):
"""Sets tradeoff between the cost of a word error and a transition in the
language model.
Parameters
----------
alpha : float
Tradeoff between the cost of a word error and a transition in the
language model
"""
return self._set(tradeoff=alpha)
def setWeights(self, weights):
"""Sets weights of each word for Levenshtein distance.
Parameters
----------
weights : Dict[str, float]
Weights for Levenshtein distance as a mapping.
"""
self._call_java('setWeights', weights)
def setMaxWindowLen(self, length):
"""Sets the maximum size for the window used to remember history prior to
every correction.
Parameters
----------
length : int
Maximum size for the window used to remember history prior to
every correction
"""
return self._set(maxWindowLen=length)
def setGamma(self, g):
"""Sets the influence of individual word frequency in the decision.
Parameters
----------
g : float
Controls the influence of individual word frequency in the decision.
"""
return self._set(gamma=g)
def setConfigProtoBytes(self, b):
"""Sets configProto from tensorflow, serialized into byte array.
Parameters
----------
b : List[int]
ConfigProto from tensorflow, serialized into byte array
"""
return self._set(configProtoBytes=b)
def setVocabFreq(self, value: dict):
"""Sets the frequencies of the words in the vocabulary.
Parameters
----------
value : dict
Frequencies of the words in the vocabulary.
"""
return self._set(vocabFreq=value)
def setIdsVocab(self, idsVocab: dict):
"""Sets mapping of ids to vocabulary.
Parameters
----------
idsVocab : dict
Mapping of ids to vocabulary.
"""
return self._set(idsVocab=idsVocab)
def setVocabIds(self, vocabIds: dict):
"""Sets mapping of vocabulary to ids.
Parameters
----------
vocabIds : dict
Mapping of vocabulary to ids.
"""
return self._set(vocabIds=vocabIds)
def setClasses(self, value):
"""Sets classes the spell checker recognizes.
Parameters
----------
value : list
Classes the spell checker recognizes.
"""
return self._set(classes=value)
def getWordClasses(self):
"""Gets the classes of words to be corrected.
Returns
-------
List[str]
Classes of words to be corrected
"""
it = self._call_java('getWordClasses').toIterator()
result = []
while it.hasNext():
result.append(it.next().toString())
return result
def updateRegexClass(self, label, regex):
"""Updates an existing class to correct, based on regex.
Parameters
----------
label : str
Label of the class
regex : str
Regex to parse the class
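A minimal sketch, assuming the loaded model was trained with a ``_DATE_``
regex class (the pattern below is purely illustrative):
>>> spellChecker = ContextSpellCheckerModel.pretrained() \\
... .setInputCols("token") \\
... .setOutputCol("checked") \\
... .updateRegexClass("_DATE_", "([0-9]{2})-([0-9]{2})-([0-9]{4})")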
"""
self._call_java('updateRegexClass', label, regex)
return self
def updateVocabClass(self, label, vocab, append=True):
"""Updates an existing class to correct, based on a vocabulary.
Parameters
----------
label : str
Label of the class
vocab : List[str]
Vocabulary as a list
append : bool, optional
Whether to append to the existing vocabulary, by default True
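A minimal sketch, assuming the loaded model already defines a ``_NAME_``
vocabulary class (the added names are purely illustrative):
>>> spellChecker = ContextSpellCheckerModel.pretrained() \\
... .setInputCols("token") \\
... .setOutputCol("checked") \\
... .updateVocabClass("_NAME_", ["Irene", "Adler"], append=True)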
"""
self._call_java('updateVocabClass', label, vocab, append)
return self
def setCorrectSymbols(self, value):
"""Sets whether to correct special symbols or skip spell checking for
them.
Parameters
----------
value : bool
Whether to correct special symbols or skip spell checking for
them
"""
return self._set(correctSymbols=value)
def setCompareLowcase(self, value):
"""Sets whether to compare tokens in lower case with vocabulary.
Parameters
----------
value : bool
Whether to compare tokens in lower case with vocabulary.
"""
return self._set(compareLowcase=value)
def __init__(self, classname="com.johnsnowlabs.nlp.annotators.spell.context.ContextSpellCheckerModel",
java_model=None):
super(ContextSpellCheckerModel, self).__init__(
classname=classname,
java_model=java_model
)
@staticmethod
def pretrained(name="spellcheck_dl", lang="en", remote_loc=None):
"""Downloads and loads a pretrained model.
Parameters
----------
name : str, optional
Name of the pretrained model, by default "spellcheck_dl"
lang : str, optional
Language of the pretrained model, by default "en"
remote_loc : str, optional
Optional remote address of the resource, by default None. Will use
Spark NLP's repositories otherwise.
Returns
-------
ContextSpellCheckerModel
The restored model
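For example, the default English model can be requested explicitly by
name and language (these are simply the defaults from the signature above):
>>> spellChecker = ContextSpellCheckerModel.pretrained("spellcheck_dl", "en") \\
... .setInputCols("token") \\
... .setOutputCol("checked")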
"""
from sparknlp.pretrained import ResourceDownloader
return ResourceDownloader.downloadModel(ContextSpellCheckerModel, name, lang, remote_loc)