Source code for sparknlp.annotator.ner.zero_shot_ner_model

#  Copyright 2017-2023 John Snow Labs
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

from sparknlp.common import *
from sparknlp.annotator.classifier_dl import RoBertaForQuestionAnswering


class ZeroShotNerModel(RoBertaForQuestionAnswering, HasEngine):
    """ZeroShotNerModel implements zero shot named entity recognition by
    utilizing RoBERTa transformer models fine tuned on a question answering
    task.

    Its input is a list of document annotations and it automatically generates
    questions which are used to recognize entities. The definition of the
    entities is given by a dictionary structure, specifying a set of questions
    for each entity. The model is based on
    :class:`.RoBertaForQuestionAnswering`.

    For more extended examples see the `Examples
    <https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/annotation/text/english/named-entity-recognition/ZeroShot_NER.ipynb>`__.

    Pretrained models can be loaded with ``pretrained`` of the companion
    object:

    .. code-block:: python

        zeroShotNer = ZeroShotNerModel.pretrained() \\
            .setInputCols(["document"]) \\
            .setOutputCol("zero_shot_ner")

    ====================== ======================
    Input Annotation types Output Annotation type
    ====================== ======================
    ``DOCUMENT, TOKEN``    ``NAMED_ENTITY``
    ====================== ======================

    Parameters
    ----------
    entityDefinitions
        A dictionary with definitions of named entities. The keys of the
        dictionary are the entity labels and the values are lists of
        questions. For example:
        { "CITY": ["Which city?", "Which town?"],
        "NAME": ["What is her name?", "What is his name?"]}
    predictionThreshold
        Minimal confidence score to encode an entity (Default: ``0.01``)
    ignoreEntities
        A list of entity labels which are discarded from the output.

    References
    ----------
    `RoBERTa: A Robustly Optimized BERT Pretraining Approach
    <https://arxiv.org/abs/1907.11692>`__ : for details about the RoBERTa
    transformer

    :class:`.RoBertaForQuestionAnswering` : for the Spark NLP implementation
    of RoBERTa question answering

    Examples
    --------
    >>> document_assembler = DocumentAssembler() \\
    ...     .setInputCol("text") \\
    ...     .setOutputCol("document")
    >>> sentence_detector = SentenceDetector() \\
    ...     .setInputCols(["document"]) \\
    ...     .setOutputCol("sentence")
    >>> tokenizer = Tokenizer() \\
    ...     .setInputCols(["sentence"]) \\
    ...     .setOutputCol("token")
    >>> zero_shot_ner = ZeroShotNerModel.pretrained() \\
    ...     .setEntityDefinitions(
    ...         {
    ...             "NAME": ["What is his name?", "What is my name?", "What is her name?"],
    ...             "CITY": ["Which city?", "Which is the city?"]
    ...         }) \\
    ...     .setInputCols(["sentence", "token"]) \\
    ...     .setOutputCol("zero_shot_ner")
    >>> data = spark.createDataFrame(
    ...     [["My name is Clara, I live in New York and Hellen lives in Paris."]]
    ... ).toDF("text")
    >>> Pipeline() \\
    ...     .setStages([document_assembler, sentence_detector, tokenizer, zero_shot_ner]) \\
    ...     .fit(data) \\
    ...     .transform(data) \\
    ...     .selectExpr("document", "explode(zero_shot_ner) AS entity") \\
    ...     .select(
    ...         "document.result",
    ...         "entity.result",
    ...         "entity.metadata.word",
    ...         "entity.metadata.confidence",
    ...         "entity.metadata.question") \\
    ...     .show(truncate=False)
    +-----------------------------------------------------------------+------+------+----------+------------------+
    |result                                                           |result|word  |confidence|question          |
    +-----------------------------------------------------------------+------+------+----------+------------------+
    |[My name is Clara, I live in New York and Hellen lives in Paris.]|B-CITY|Paris |0.5328949 |Which is the city?|
    |[My name is Clara, I live in New York and Hellen lives in Paris.]|B-NAME|Clara |0.9360068 |What is my name?  |
    |[My name is Clara, I live in New York and Hellen lives in Paris.]|B-CITY|New   |0.83294415|Which city?       |
    |[My name is Clara, I live in New York and Hellen lives in Paris.]|I-CITY|York  |0.83294415|Which city?       |
    |[My name is Clara, I live in New York and Hellen lives in Paris.]|B-NAME|Hellen|0.45366877|What is her name? |
    +-----------------------------------------------------------------+------+------+----------+------------------+
    """

    inputAnnotatorTypes = [AnnotatorType.DOCUMENT, AnnotatorType.TOKEN]

    outputAnnotatorType = AnnotatorType.NAMED_ENTITY

    name = "ZeroShotNerModel"

    predictionThreshold = Param(Params._dummy(),
                                "predictionThreshold",
                                "Minimal confidence score to encode an entity (default is 0.01)",
                                TypeConverters.toFloat)

    ignoreEntities = Param(Params._dummy(),
                           "ignoreEntities",
                           "List of entities to ignore",
                           TypeConverters.toListString)

    def setPredictionThreshold(self, threshold):
        """Sets the minimal confidence score to encode an entity.

        Parameters
        ----------
        threshold : float
            Minimal confidence score to encode an entity (default is 0.01)
        """
        return self._set(predictionThreshold=threshold)

    def setEntityDefinitions(self, definitions):
        """Sets the definitions of the named entities to recognize.

        Parameters
        ----------
        definitions : dict[str, list[str]]
            A dictionary mapping each entity label to the list of questions
            used to detect that entity.
        """
        self._call_java("setEntityDefinitions", definitions)
        return self

    def getClasses(self):
        """Returns the list of entities which are recognized.
        """
        return self._call_java("getEntities")
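
    # Illustrative sketch (not part of the original source): a typical way to
    # combine the setters above. The entity labels, questions, and threshold
    # value here are assumptions made for the example.
    #
    # >>> ner = ZeroShotNerModel.pretrained() \
    # ...     .setEntityDefinitions({
    # ...         "CITY": ["Which city?"],
    # ...         "NAME": ["What is his name?", "What is her name?"]}) \
    # ...     .setPredictionThreshold(0.1) \
    # ...     .setInputCols(["sentence", "token"]) \
    # ...     .setOutputCol("zero_shot_ner")
    # >>> ner.getClasses()  # e.g. ['CITY', 'NAME'] (order may vary)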

    @keyword_only
    def __init__(self, classname="com.johnsnowlabs.nlp.annotators.ner.dl.ZeroShotNerModel", java_model=None):
        super(ZeroShotNerModel, self).__init__(
            classname=classname,
            java_model=java_model
        )
        self._setDefault(
            predictionThreshold=0.01,
            ignoreEntities=[]
        )

    @staticmethod
    def pretrained(name="zero_shot_ner_roberta", lang="en", remote_loc=None):
        """Downloads and loads a pretrained model.

        Parameters
        ----------
        name : str, optional
            Name of the pretrained model, by default "zero_shot_ner_roberta"
        lang : str, optional
            Language of the pretrained model, by default "en"
        remote_loc : str, optional
            Optional remote address of the resource, by default None. Will use
            Spark NLP's repositories otherwise.

        Returns
        -------
        ZeroShotNerModel
            The restored model
        """
        from sparknlp.pretrained import ResourceDownloader
        return ResourceDownloader.downloadModel(ZeroShotNerModel, name, lang, remote_loc,
                                                j_dwn='PythonResourceDownloader')
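
    # Illustrative sketch (not part of the original source): ``pretrained``
    # can also be called with an explicit model name and language; the values
    # below are the documented defaults.
    #
    # >>> ner = ZeroShotNerModel.pretrained("zero_shot_ner_roberta", "en") \
    # ...     .setInputCols(["sentence", "token"]) \
    # ...     .setOutputCol("zero_shot_ner")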

    @staticmethod
    def load(path):
        """Creates a ZeroShotNerModel from a locally saved
        RoBertaForQuestionAnswering model.

        Parameters
        ----------
        path : str
            Path to the saved RoBertaForQuestionAnswering model

        Returns
        -------
        ZeroShotNerModel
            The wrapped model
        """
        from sparknlp.internal import _RobertaQAToZeroShotNerLoader
        jModel = _RobertaQAToZeroShotNerLoader(path)._java_obj
        return ZeroShotNerModel(java_model=jModel)
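
# Illustrative sketch (not part of the original source): ``load`` wraps a
# locally saved RoBertaForQuestionAnswering model as a ZeroShotNerModel, so a
# custom question answering checkpoint can be reused for zero shot NER. The
# path below is hypothetical and assumes the directory contains a Spark NLP
# RoBertaForQuestionAnswering model saved with ``write().save()``.
#
# >>> zero_shot_ner = ZeroShotNerModel.load("/tmp/roberta_qa_model") \
# ...     .setEntityDefinitions({"NAME": ["What is his name?"]}) \
# ...     .setInputCols(["sentence", "token"]) \
# ...     .setOutputCol("zero_shot_ner")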