# Source code for sparknlp.annotator.param.classifier_encoder
# Copyright 2017-2022 John Snow Labs
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from sparknlp.common import *
from sparknlp.internal import ParamsGettersSetters
class ClassifierEncoder(ParamsGettersSetters):
    """Shared training parameters for classifier-encoder annotators.

    Declares the ``maxEpochs``, ``lr``, ``batchSize``, ``labelColumn``,
    ``randomSeed`` and ``configProtoBytes`` params, with one fluent setter
    for each. Every setter returns the instance, so calls can be chained.

    Notes
    -----
    This class does not set defaults for any of its params. The "by default"
    values quoted in the setter docstrings presumably come from the concrete
    annotators that use this class. TODO confirm against those annotators.
    """

    # Class-level Param declarations, using the PySpark ``Params._dummy()``
    # idiom for params defined outside an instance.
    maxEpochs = Param(Params._dummy(), "maxEpochs", "Maximum number of epochs to train", TypeConverters.toInt)

    lr = Param(Params._dummy(), "lr", "Learning Rate", TypeConverters.toFloat)

    batchSize = Param(Params._dummy(), "batchSize", "Batch size", TypeConverters.toInt)

    # NOTE(review): "per each token" reads like text copied from a token-level
    # annotator, while setLabelColumn describes this param simply as the column
    # for data labels. The wording is user-visible Param help text, so confirm
    # the intended granularity before changing it.
    labelColumn = Param(Params._dummy(),
                        "labelColumn",
                        "Column with label per each token",
                        typeConverter=TypeConverters.toString)

    randomSeed = Param(Params._dummy(), "randomSeed", "Random seed", TypeConverters.toInt)

    # The serialized ConfigProto bytes are carried as a list of ints
    # (converted with TypeConverters.toListInt).
    configProtoBytes = Param(Params._dummy(), "configProtoBytes",
                             "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()",
                             TypeConverters.toListInt)

    def setMaxEpochs(self, epochs):
        """Sets maximum number of epochs to train, by default 30.

        Parameters
        ----------
        epochs : int
            Maximum number of epochs to train

        Returns
        -------
        ClassifierEncoder
            This instance, for call chaining.
        """
        return self._set(maxEpochs=epochs)

    def setLr(self, v):
        """Sets Learning Rate, by default 0.005.

        Parameters
        ----------
        v : float
            Learning Rate

        Returns
        -------
        ClassifierEncoder
            This instance, for call chaining.
        """
        # Same form as the other setters. ``_set`` returns the instance, so
        # this is equivalent to the former ``self._set(...); return self``.
        return self._set(lr=v)

    def setBatchSize(self, v):
        """Sets batch size, by default 64.

        Parameters
        ----------
        v : int
            Batch size

        Returns
        -------
        ClassifierEncoder
            This instance, for call chaining.
        """
        return self._set(batchSize=v)

    def setLabelColumn(self, value):
        """Sets name of column for data labels.

        Parameters
        ----------
        value : str
            Column for data labels

        Returns
        -------
        ClassifierEncoder
            This instance, for call chaining.
        """
        return self._set(labelColumn=value)

    def setRandomSeed(self, seed):
        """Sets random seed for shuffling.

        Parameters
        ----------
        seed : int
            Random seed for shuffling

        Returns
        -------
        ClassifierEncoder
            This instance, for call chaining.
        """
        return self._set(randomSeed=seed)

    def setConfigProtoBytes(self, b):
        """Sets configProto from tensorflow, serialized into byte array.

        Parameters
        ----------
        b : List[int]
            ConfigProto from tensorflow, serialized into byte array
            (e.g. obtained with ``config_proto.SerializeToString()``)

        Returns
        -------
        ClassifierEncoder
            This instance, for call chaining.
        """
        return self._set(configProtoBytes=b)