# Source code for sparknlp.annotator.classifier_dl.mpnet_for_question_answering

#  Copyright 2017-2022 John Snow Labs
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

from sparknlp.common import *


class MPNetForQuestionAnswering(AnnotatorModel,
                                HasCaseSensitiveProperties,
                                HasBatchedAnnotate,
                                HasEngine,
                                HasMaxSentenceLengthLimit):
    """MPNetForQuestionAnswering can load MPNet Models with a span
    classification head on top for extractive question-answering tasks like
    SQuAD (a linear layer on top of the hidden-states output to compute span
    start logits and span end logits).

    Pretrained models can be loaded with :meth:`.pretrained` of the companion
    object:

    >>> spanClassifier = MPNetForQuestionAnswering.pretrained() \\
    ...     .setInputCols(["document_question", "document_context"]) \\
    ...     .setOutputCol("answer")

    The default model is ``"mpnet_base_question_answering_squad2"``, if no
    name is provided.

    For available pretrained models please see the `Models Hub
    <https://sparknlp.org/models?task=Question+Answering>`__.

    To see which models are compatible and how to import them see
    `Import Transformers into Spark NLP 🚀
    <https://github.com/JohnSnowLabs/spark-nlp/discussions/5669>`_.

    ====================== ======================
    Input Annotation types Output Annotation type
    ====================== ======================
    ``DOCUMENT, DOCUMENT`` ``CHUNK``
    ====================== ======================

    Parameters
    ----------
    batchSize
        Batch size. Large values allows faster processing but requires more
        memory, by default 8
    caseSensitive
        Whether to ignore case in tokens for embeddings matching, by default
        False
    maxSentenceLength
        Max sentence length to process, by default 384

    Examples
    --------
    >>> import sparknlp
    >>> from sparknlp.base import *
    >>> from sparknlp.annotator import *
    >>> from pyspark.ml import Pipeline
    >>> documentAssembler = MultiDocumentAssembler() \\
    ...     .setInputCols(["question", "context"]) \\
    ...     .setOutputCols(["document_question", "document_context"])
    >>> spanClassifier = MPNetForQuestionAnswering.pretrained() \\
    ...     .setInputCols(["document_question", "document_context"]) \\
    ...     .setOutputCol("answer") \\
    ...     .setCaseSensitive(False)
    >>> pipeline = Pipeline().setStages([
    ...     documentAssembler,
    ...     spanClassifier
    ... ])
    >>> data = spark.createDataFrame([["What's my name?", "My name is Clara and I live in Berkeley."]]).toDF("question", "context")
    >>> result = pipeline.fit(data).transform(data)
    >>> result.select("answer.result").show(truncate=False)
    +--------------------+
    |result              |
    +--------------------+
    |[Clara]             |
    +--------------------+
    """

    name = "MPNetForQuestionAnswering"

    # Takes the question document and the context document, in that order.
    inputAnnotatorTypes = [AnnotatorType.DOCUMENT, AnnotatorType.DOCUMENT]

    outputAnnotatorType = AnnotatorType.CHUNK

    @keyword_only
    def __init__(self,
                 classname="com.johnsnowlabs.nlp.annotators.classifier.dl.MPNetForQuestionAnswering",
                 java_model=None):
        super(MPNetForQuestionAnswering, self).__init__(
            classname=classname,
            java_model=java_model
        )
        # Defaults mirror the Scala-side annotator; the docstring above must
        # stay in sync with these values.
        self._setDefault(
            batchSize=8,
            maxSentenceLength=384,
            caseSensitive=False
        )

    @staticmethod
    def loadSavedModel(folder, spark_session):
        """Loads a locally saved model.

        Parameters
        ----------
        folder : str
            Folder of the saved model
        spark_session : pyspark.sql.SparkSession
            The current SparkSession

        Returns
        -------
        MPNetForQuestionAnswering
            The restored model
        """
        # Deferred import avoids a circular dependency at module load time.
        from sparknlp.internal import _MPNetForQuestionAnsweringLoader
        jModel = _MPNetForQuestionAnsweringLoader(folder, spark_session._jsparkSession)._java_obj
        return MPNetForQuestionAnswering(java_model=jModel)

    @staticmethod
    def pretrained(name="mpnet_base_question_answering_squad2", lang="en", remote_loc=None):
        """Downloads and loads a pretrained model.

        Parameters
        ----------
        name : str, optional
            Name of the pretrained model, by default
            "mpnet_base_question_answering_squad2"
        lang : str, optional
            Language of the pretrained model, by default "en"
        remote_loc : str, optional
            Optional remote address of the resource, by default None. Will use
            Spark NLPs repositories otherwise.

        Returns
        -------
        MPNetForQuestionAnswering
            The restored model
        """
        # Deferred import avoids a circular dependency at module load time.
        from sparknlp.pretrained import ResourceDownloader
        return ResourceDownloader.downloadModel(MPNetForQuestionAnswering, name, lang, remote_loc)