# Source code for sparknlp.annotator.classifier_dl.bert_for_multiple_choice
# Copyright 2017-2024 John Snow Labs
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from sparknlp.common import *
class BertForMultipleChoice(AnnotatorModel,
                            HasCaseSensitiveProperties,
                            HasBatchedAnnotate,
                            HasEngine,
                            HasMaxSentenceLengthLimit):
    """BertForMultipleChoice can load BERT Models with a multiple choice classification head on top
    (a linear layer on top of the pooled output and a softmax) e.g. for RocStories/SWAG tasks.

    Pretrained models can be loaded with :meth:`.pretrained` of the companion
    object:

    >>> spanClassifier = BertForMultipleChoice.pretrained() \\
    ...     .setInputCols(["document_question", "document_context"]) \\
    ...     .setOutputCol("answer")

    The default model is ``"bert_base_uncased_multiple_choice"``, if no name is
    provided.

    For available pretrained models please see the `Models Hub
    <https://sparknlp.org/models?task=Multiple+Choice>`__.

    To see which models are compatible and how to import them see
    `Import Transformers into Spark NLP 🚀
    <https://github.com/JohnSnowLabs/spark-nlp/discussions/5669>`_.

    ====================== ======================
    Input Annotation types Output Annotation type
    ====================== ======================
    ``DOCUMENT, DOCUMENT`` ``CHUNK``
    ====================== ======================

    Parameters
    ----------
    batchSize
        Batch size. Large values allows faster processing but requires more
        memory, by default 4
    caseSensitive
        Whether to ignore case in tokens for embeddings matching, by default
        False
    maxSentenceLength
        Max sentence length to process, by default 512
    choicesDelimiter
        Delimiter character used to split the choices, by default ","

    Examples
    --------
    >>> import sparknlp
    >>> from sparknlp.base import *
    >>> from sparknlp.annotator import *
    >>> from pyspark.ml import Pipeline
    >>> documentAssembler = MultiDocumentAssembler() \\
    ...     .setInputCols(["question", "context"]) \\
    ...     .setOutputCols(["document_question", "document_context"])
    >>> questionAnswering = BertForMultipleChoice.pretrained() \\
    ...     .setInputCols(["document_question", "document_context"]) \\
    ...     .setOutputCol("answer") \\
    ...     .setCaseSensitive(False)
    >>> pipeline = Pipeline().setStages([
    ...     documentAssembler,
    ...     questionAnswering
    ... ])
    >>> data = spark.createDataFrame([["The Eiffel Tower is located in which country??", "Germany, France, Italy"]]).toDF("question", "context")
    >>> result = pipeline.fit(data).transform(data)
    >>> result.select("answer.result").show(truncate=False)
    +--------------------+
    |result              |
    +--------------------+
    |[France]            |
    +--------------------+
    """
    name = "BertForMultipleChoice"

    inputAnnotatorTypes = [AnnotatorType.DOCUMENT, AnnotatorType.DOCUMENT]

    outputAnnotatorType = AnnotatorType.CHUNK

    choicesDelimiter = Param(Params._dummy(),
                             "choicesDelimiter",
                             "Delimiter character use to split the choices",
                             TypeConverters.toString)

    def setChoicesDelimiter(self, value):
        """Sets the delimiter character used to split the choices.

        Parameters
        ----------
        value : str
            Delimiter character used to split the choices
        """
        # Must target the choicesDelimiter Param, not caseSensitive.
        return self._set(choicesDelimiter=value)

    @keyword_only
    def __init__(self, classname="com.johnsnowlabs.nlp.annotators.classifier.dl.BertForMultipleChoice",
                 java_model=None):
        super(BertForMultipleChoice, self).__init__(
            classname=classname,
            java_model=java_model
        )
        self._setDefault(
            batchSize=4,
            maxSentenceLength=512,
            caseSensitive=False,
            choicesDelimiter=","
        )

    @staticmethod
    def loadSavedModel(folder, spark_session):
        """Loads a locally saved model.

        Parameters
        ----------
        folder : str
            Folder of the saved model
        spark_session : pyspark.sql.SparkSession
            The current SparkSession

        Returns
        -------
        BertForMultipleChoice
            The restored model
        """
        from sparknlp.internal import _BertMultipleChoiceLoader
        jModel = _BertMultipleChoiceLoader(folder, spark_session._jsparkSession)._java_obj
        return BertForMultipleChoice(java_model=jModel)

    @staticmethod
    def pretrained(name="bert_base_uncased_multiple_choice", lang="en", remote_loc=None):
        """Downloads and loads a pretrained model.

        Parameters
        ----------
        name : str, optional
            Name of the pretrained model, by default
            "bert_base_uncased_multiple_choice"
        lang : str, optional
            Language of the pretrained model, by default "en"
        remote_loc : str, optional
            Optional remote address of the resource, by default None. Will use
            Spark NLPs repositories otherwise.

        Returns
        -------
        BertForMultipleChoice
            The restored model
        """
        from sparknlp.pretrained import ResourceDownloader
        return ResourceDownloader.downloadModel(BertForMultipleChoice, name, lang, remote_loc)