# Source code for sparknlp.base.recursive_pipeline
# Copyright 2017-2022 John Snow Labs
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains classes for the RecursivePipeline."""
from pyspark import keyword_only
from pyspark.ml import PipelineModel, Estimator, Pipeline, Transformer
from pyspark.ml.wrapper import JavaEstimator
from sparknlp.common import AnnotatorProperties
from sparknlp.internal import RecursiveEstimator
from sparknlp.base import HasRecursiveTransform
class RecursivePipeline(Pipeline, JavaEstimator):
    """Recursive pipelines are Spark NLP specific pipelines that allow a Spark
    ML Pipeline to know about itself on every Pipeline Stage task.

    This allows annotators to utilize this same pipeline against external
    resources to process them in the same way the user decides.

    Only some of the annotators take advantage of this. RecursivePipeline
    behaves exactly the same as normal Spark ML pipelines, so they can be used
    with the same intention.

    Examples
    --------
    >>> from sparknlp.annotator import *
    >>> from sparknlp.base import *
    >>> recursivePipeline = RecursivePipeline(stages=[
    ...     documentAssembler,
    ...     sentenceDetector,
    ...     tokenizer,
    ...     lemmatizer,
    ...     finisher
    ... ])
    """

    @keyword_only
    def __init__(self, *args, **kwargs):
        """Create the pipeline from keyword arguments such as ``stages``.

        ``@keyword_only`` records the keyword arguments in
        ``self._input_kwargs``; they are re-applied through ``setParams``
        after the JVM-side object has been created.
        """
        super(RecursivePipeline, self).__init__(*args, **kwargs)
        # Back this Python wrapper with the JVM RecursivePipeline, sharing our uid.
        self._java_obj = self._new_java_obj("com.johnsnowlabs.nlp.RecursivePipeline", self.uid)
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    def _fit(self, dataset):
        """Fit the configured stages on ``dataset`` in order.

        Stages up to and including the last Estimator are processed
        sequentially:

        - A Transformer is kept as-is and transforms the running dataset.
        - A RecursiveEstimator is fitted with an extra ``pipeline`` argument.
          This is a PipelineModel of every stage that precedes it (already
          fitted or transformer).
        - Any other Estimator is fitted normally.

        An Estimator's resulting model transforms the running dataset, except
        for the last Estimator, whose output is not needed. Stages after the
        last Estimator are appended unchanged.

        Parameters
        ----------
        dataset : pyspark.sql.DataFrame
            Input data used to fit the estimators.

        Returns
        -------
        PipelineModel
            The fitted stages, in pipeline order.

        Raises
        ------
        TypeError
            If any stage is neither an Estimator nor a Transformer.
        """
        stages = self.getStages()
        for stage in stages:
            if not isinstance(stage, (Estimator, Transformer)):
                raise TypeError(
                    "Cannot recognize a pipeline stage of type %s." % type(stage))

        index_of_last_estimator = -1
        for i, stage in enumerate(stages):
            if isinstance(stage, Estimator):
                index_of_last_estimator = i

        transformers = []
        for i, stage in enumerate(stages):
            if i > index_of_last_estimator:
                # Nothing left to fit: remaining stages are transformers kept as-is.
                transformers.append(stage)
                continue

            # Transformer is checked first, so a stage that is both is not refitted.
            if isinstance(stage, Transformer):
                transformers.append(stage)
                dataset = stage.transform(dataset)
                continue

            if isinstance(stage, RecursiveEstimator):
                # Hand over a snapshot of the fitted prefix. PipelineModel keeps
                # a reference to the list it receives, and `transformers` keeps
                # growing below, which would otherwise leak this stage's own
                # model and all later stages into the recursive pipeline.
                model = stage.fit(dataset, pipeline=PipelineModel(list(transformers)))
            else:
                model = stage.fit(dataset)
            transformers.append(model)
            # The last estimator's output is never consumed by a later fit.
            if i < index_of_last_estimator:
                dataset = model.transform(dataset)

        return PipelineModel(transformers)
class RecursivePipelineModel(PipelineModel):
    """Fitted RecursivePipeline.

    Wraps an already-fitted :class:`PipelineModel` and behaves like one,
    except during transformation:

    - Stages implementing ``HasRecursiveTransform`` are run through
      ``transform_recursive``.
    - Lazy annotators are skipped.

    Not intended to be initialized by itself. Note that fitting a
    :class:`.RecursivePipeline` returns a plain ``PipelineModel``; pass that
    model to this constructor to obtain a RecursivePipelineModel.
    """

    def __init__(self, pipeline_model):
        """Adopt the stages of ``pipeline_model``.

        ``super(PipelineModel, self)`` bypasses ``PipelineModel.__init__``.
        The stage list is then shared (not copied) with the wrapped model.
        """
        super(PipelineModel, self).__init__()
        self.stages = pipeline_model.stages

    def _transform(self, dataset):
        """Apply every stage in order and return the resulting dataset.

        Each stage is handled as follows:

        - Recursive stages receive a PipelineModel of all stages except the
          final one.
        - Lazy annotators leave the dataset untouched.
        - Every other stage is applied with ``transform``.
        """
        current = dataset
        for stage in self.stages:
            if isinstance(stage, HasRecursiveTransform):
                # `self.stages[:-1]` removes the *last* stage of the pipeline.
                # It excludes the stage being applied only when that stage is
                # last. NOTE(review): confirm this is intended for recursive
                # stages placed earlier in the pipeline.
                recursive_pipeline = PipelineModel(self.stages[:-1])
                current = stage.transform_recursive(current, recursive_pipeline)
                continue
            if isinstance(stage, AnnotatorProperties) and stage.getLazyAnnotator():
                # Lazy annotators are not applied here.
                continue
            current = stage.transform(current)
        return current