class PromptAssembler extends Transformer with DefaultParamsWritable with HasOutputAnnotatorType with HasOutputAnnotationCol
Assembles a sequence of messages into a single string using a template. These strings can then be used as prompts for large language models.
This annotator expects the input column to contain an array of two-tuples (one array of tuples per row). The first element of each tuple is the role and the second element is the text of the message. Possible roles are "system", "user" and "assistant".
An assistant header can be added to the end of the generated string by calling setAddAssistant(true).
Currently, this annotator uses llama.cpp as a backend to parse and apply the templates. llama.cpp uses simple pattern matching to detect the type of the template and then applies a basic version of that template to the messages. As a result, more advanced templates are not supported.
For an extended example, see the example notebook.
Example
// Batches (whole conversations) of arrays of messages
val data: Seq[Seq[(String, String)]] = Seq(
  Seq(
    ("system", "You are a helpful assistant."),
    ("assistant", "Hello there, how can I help you?"),
    ("user", "I need help with organizing my room.")))
val dataDF = data.toDF("messages")

// llama3.1
val template =
  "{{- bos_token }} {%- if custom_tools is defined %} {%- set tools = custom_tools %} {%- " +
    "endif %} {%- if not tools_in_user_message is defined %} {%- set tools_in_user_message = true %} {%- " +
    "endif %} {%- if not date_string is defined %} {%- set date_string = \"26 Jul 2024\" %} {%- endif %} " +
    "{%- if not tools is defined %} {%- set tools = none %} {%- endif %} {#- This block extracts the " +
    "system message, so we can slot it into the right place. #} {%- if messages[0]['role'] == 'system' %}" +
    " {%- set system_message = messages[0]['content']|trim %} {%- set messages = messages[1:] %} {%- else" +
    " %} {%- set system_message = \"\" %} {%- endif %} {#- System message + builtin tools #} {{- " +
    "\"<|start_header_id|>system<|end_header_id|>\\n\\n\" }} {%- if builtin_tools is defined or tools is " +
    "not none %} {{- \"Environment: ipython\\n\" }} {%- endif %} {%- if builtin_tools is defined %} {{- " +
    "\"Tools: \" + builtin_tools | reject('equalto', 'code_interpreter') | join(\", \") + \"\\n\\n\"}} " +
    "{%- endif %} {{- \"Cutting Knowledge Date: December 2023\\n\" }} {{- \"Today Date: \" + date_string " +
    "+ \"\\n\\n\" }} {%- if tools is not none and not tools_in_user_message %} {{- \"You have access to " +
    "the following functions. To call a function, please respond with JSON for a function call.\" }} {{- " +
    "'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its" +
    " value}.' }} {{- \"Do not use variables.\\n\\n\" }} {%- for t in tools %} {{- t | tojson(indent=4) " +
    "}} {{- \"\\n\\n\" }} {%- endfor %} {%- endif %} {{- system_message }} {{- \"<|eot_id|>\" }} {#- " +
    "Custom tools are passed in a user message with some extra guidance #} {%- if tools_in_user_message " +
    "and not tools is none %} {#- Extract the first user message so we can plug it in here #} {%- if " +
    "messages | length != 0 %} {%- set first_user_message = messages[0]['content']|trim %} {%- set " +
    "messages = messages[1:] %} {%- else %} {{- raise_exception(\"Cannot put tools in the first user " +
    "message when there's no first user message!\") }} {%- endif %} {{- " +
    "'<|start_header_id|>user<|end_header_id|>\\n\\n' -}} {{- \"Given the following functions, please " +
    "respond with a JSON for a function call \" }} {{- \"with its proper arguments that best answers the " +
    "given prompt.\\n\\n\" }} {{- 'Respond in the format {\"name\": function name, \"parameters\": " +
    "dictionary of argument name and its value}.' }} {{- \"Do not use variables.\\n\\n\" }} {%- for t in " +
    "tools %} {{- t | tojson(indent=4) }} {{- \"\\n\\n\" }} {%- endfor %} {{- first_user_message + " +
    "\"<|eot_id|>\"}} {%- endif %} {%- for message in messages %} {%- if not (message.role == 'ipython' " +
    "or message.role == 'tool' or 'tool_calls' in message) %} {{- '<|start_header_id|>' + message['role']" +
    " + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }} {%- elif 'tool_calls' in " +
    "message %} {%- if not message.tool_calls|length == 1 %} {{- raise_exception(\"This model only " +
    "supports single tool-calls at once!\") }} {%- endif %} {%- set tool_call = message.tool_calls[0]" +
    ".function %} {%- if builtin_tools is defined and tool_call.name in builtin_tools %} {{- " +
    "'<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}} {{- \"<|python_tag|>\" + tool_call.name + " +
    "\".call(\" }} {%- for arg_name, arg_val in tool_call.arguments | items %} {{- arg_name + '=\"' + " +
    "arg_val + '\"' }} {%- if not loop.last %} {{- \", \" }} {%- endif %} {%- endfor %} {{- \")\" }} {%- " +
    "else %} {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}} {{- '{\"name\": \"' + " +
    "tool_call.name + '\", ' }} {{- '\"parameters\": ' }} {{- tool_call.arguments | tojson }} {{- \"}\" " +
    "}} {%- endif %} {%- if builtin_tools is defined %} {#- This means we're in ipython mode #} {{- " +
    "\"<|eom_id|>\" }} {%- else %} {{- \"<|eot_id|>\" }} {%- endif %} {%- elif message.role == \"tool\" " +
    "or message.role == \"ipython\" %} {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }} {%- " +
    "if message.content is mapping or message.content is iterable %} {{- message.content | tojson }} {%- " +
    "else %} {{- message.content }} {%- endif %} {{- \"<|eot_id|>\" }} {%- endif %} {%- endfor %} {%- if " +
    "add_generation_prompt %} {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }} {%- endif %} "

val promptAssembler = new PromptAssembler()
  .setInputCol("messages")
  .setOutputCol("prompt")
.setChatTemplate(template) promptAssembler.transform(dataDF).select("prompt.result").show(truncate = false) +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ |result | +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ |[<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHello there, how can I help you?<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nI need help with organizing my room.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n]| +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
- Grouped
- Alphabetic
- By Inheritance
- PromptAssembler
- HasOutputAnnotationCol
- HasOutputAnnotatorType
- DefaultParamsWritable
- MLWritable
- Transformer
- PipelineStage
- Logging
- Params
- Serializable
- Serializable
- Identifiable
- AnyRef
- Any
- Hide All
- Show All
- Public
- All
Instance Constructors
Type Members
-
type
AnnotatorType = String
- Definition Classes
- HasOutputAnnotatorType
Value Members
-
final
def
!=(arg0: Any): Boolean
- Definition Classes
- AnyRef → Any
-
final
def
##(): Int
- Definition Classes
- AnyRef → Any
-
final
def
$[T](param: Param[T]): T
- Attributes
- protected
- Definition Classes
- Params
-
final
def
==(arg0: Any): Boolean
- Definition Classes
- AnyRef → Any
- val addAssistant: BooleanParam
-
final
def
asInstanceOf[T0]: T0
- Definition Classes
- Any
- val chatTemplate: Param[String]
-
final
def
clear(param: Param[_]): PromptAssembler.this.type
- Definition Classes
- Params
-
def
clone(): AnyRef
- Attributes
- protected[lang]
- Definition Classes
- AnyRef
- Annotations
- @throws( ... ) @native()
-
def
copy(extra: ParamMap): Transformer
- Definition Classes
- PromptAssembler → Transformer → PipelineStage → Params
-
def
copyValues[T <: Params](to: T, extra: ParamMap): T
- Attributes
- protected
- Definition Classes
- Params
-
final
def
defaultCopy[T <: Params](extra: ParamMap): T
- Attributes
- protected
- Definition Classes
- Params
-
final
def
eq(arg0: AnyRef): Boolean
- Definition Classes
- AnyRef
-
def
equals(arg0: Any): Boolean
- Definition Classes
- AnyRef → Any
-
def
explainParam(param: Param[_]): String
- Definition Classes
- Params
-
def
explainParams(): String
- Definition Classes
- Params
-
final
def
extractParamMap(): ParamMap
- Definition Classes
- Params
-
final
def
extractParamMap(extra: ParamMap): ParamMap
- Definition Classes
- Params
-
def
finalize(): Unit
- Attributes
- protected[lang]
- Definition Classes
- AnyRef
- Annotations
- @throws( classOf[java.lang.Throwable] )
-
final
def
get[T](param: Param[T]): Option[T]
- Definition Classes
- Params
-
def
getAddAssistant: Boolean
Whether to add an assistant header to the end of the generated string.
Whether to add an assistant header to the end of the generated string.
- returns
Whether to add the assistant header
-
def
getChatTemplate: String
Gets the chat template to be used for the chat.
Gets the chat template to be used for the chat.
- returns
The template to use
-
final
def
getClass(): Class[_]
- Definition Classes
- AnyRef → Any
- Annotations
- @native()
-
final
def
getDefault[T](param: Param[T]): Option[T]
- Definition Classes
- Params
- def getInputCol: String
-
final
def
getOrDefault[T](param: Param[T]): T
- Definition Classes
- Params
-
final
def
getOutputCol: String
Gets the name of the annotation column to be generated
Gets the name of the annotation column to be generated
- Definition Classes
- HasOutputAnnotationCol
-
def
getParam(paramName: String): Param[Any]
- Definition Classes
- Params
-
final
def
hasDefault[T](param: Param[T]): Boolean
- Definition Classes
- Params
-
def
hasParam(paramName: String): Boolean
- Definition Classes
- Params
-
def
hashCode(): Int
- Definition Classes
- AnyRef → Any
- Annotations
- @native()
-
def
initializeLogIfNecessary(isInterpreter: Boolean, silent: Boolean): Boolean
- Attributes
- protected
- Definition Classes
- Logging
-
def
initializeLogIfNecessary(isInterpreter: Boolean): Unit
- Attributes
- protected
- Definition Classes
- Logging
- val inputCol: Param[String]
-
final
def
isDefined(param: Param[_]): Boolean
- Definition Classes
- Params
-
final
def
isInstanceOf[T0]: Boolean
- Definition Classes
- Any
-
final
def
isSet(param: Param[_]): Boolean
- Definition Classes
- Params
-
def
isTraceEnabled(): Boolean
- Attributes
- protected
- Definition Classes
- Logging
-
def
log: Logger
- Attributes
- protected
- Definition Classes
- Logging
-
def
logDebug(msg: ⇒ String, throwable: Throwable): Unit
- Attributes
- protected
- Definition Classes
- Logging
-
def
logDebug(msg: ⇒ String): Unit
- Attributes
- protected
- Definition Classes
- Logging
-
def
logError(msg: ⇒ String, throwable: Throwable): Unit
- Attributes
- protected
- Definition Classes
- Logging
-
def
logError(msg: ⇒ String): Unit
- Attributes
- protected
- Definition Classes
- Logging
-
def
logInfo(msg: ⇒ String, throwable: Throwable): Unit
- Attributes
- protected
- Definition Classes
- Logging
-
def
logInfo(msg: ⇒ String): Unit
- Attributes
- protected
- Definition Classes
- Logging
-
def
logName: String
- Attributes
- protected
- Definition Classes
- Logging
-
def
logTrace(msg: ⇒ String, throwable: Throwable): Unit
- Attributes
- protected
- Definition Classes
- Logging
-
def
logTrace(msg: ⇒ String): Unit
- Attributes
- protected
- Definition Classes
- Logging
-
def
logWarning(msg: ⇒ String, throwable: Throwable): Unit
- Attributes
- protected
- Definition Classes
- Logging
-
def
logWarning(msg: ⇒ String): Unit
- Attributes
- protected
- Definition Classes
- Logging
-
final
def
ne(arg0: AnyRef): Boolean
- Definition Classes
- AnyRef
-
final
def
notify(): Unit
- Definition Classes
- AnyRef
- Annotations
- @native()
-
final
def
notifyAll(): Unit
- Definition Classes
- AnyRef
- Annotations
- @native()
-
val
outputAnnotatorType: AnnotatorType
- Definition Classes
- PromptAssembler → HasOutputAnnotatorType
-
final
val
outputCol: Param[String]
- Attributes
- protected
- Definition Classes
- HasOutputAnnotationCol
-
lazy val
params: Array[Param[_]]
- Definition Classes
- Params
-
def
save(path: String): Unit
- Definition Classes
- MLWritable
- Annotations
- @Since( "1.6.0" ) @throws( ... )
-
final
def
set(paramPair: ParamPair[_]): PromptAssembler.this.type
- Attributes
- protected
- Definition Classes
- Params
-
final
def
set(param: String, value: Any): PromptAssembler.this.type
- Attributes
- protected
- Definition Classes
- Params
-
final
def
set[T](param: Param[T], value: T): PromptAssembler.this.type
- Definition Classes
- Params
-
def
setAddAssistant(value: Boolean): PromptAssembler.this.type
Whether to add an assistant header to the end of the generated string.
Whether to add an assistant header to the end of the generated string.
- value
Whether to add the assistant header
-
def
setChatTemplate(value: String): PromptAssembler.this.type
Sets the chat template to be used for the chat.
Sets the chat template to be used for the chat. The template must be one that llama.cpp can parse.
- value
The template to use
-
final
def
setDefault(paramPairs: ParamPair[_]*): PromptAssembler.this.type
- Attributes
- protected
- Definition Classes
- Params
-
final
def
setDefault[T](param: Param[T], value: T): PromptAssembler.this.type
- Attributes
- protected[org.apache.spark.ml]
- Definition Classes
- Params
-
def
setInputCol(value: String): PromptAssembler.this.type
Sets the input column containing the arrays of (role, message) tuples to process
-
final
def
setOutputCol(value: String): PromptAssembler.this.type
Overrides the name of the annotation column produced when transforming
Overrides the name of the annotation column produced when transforming
- Definition Classes
- HasOutputAnnotationCol
-
final
def
synchronized[T0](arg0: ⇒ T0): T0
- Definition Classes
- AnyRef
-
def
toString(): String
- Definition Classes
- Identifiable → AnyRef → Any
-
def
transform(dataset: Dataset[_]): DataFrame
- Definition Classes
- PromptAssembler → Transformer
-
def
transform(dataset: Dataset[_], paramMap: ParamMap): DataFrame
- Definition Classes
- Transformer
- Annotations
- @Since( "2.0.0" )
-
def
transform(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): DataFrame
- Definition Classes
- Transformer
- Annotations
- @Since( "2.0.0" ) @varargs()
-
final
def
transformSchema(schema: StructType): StructType
Adds the result Annotation type to the schema.
Adds the result Annotation type to the schema.
Required for pipeline transformation validation. It is called on fit().
- Definition Classes
- PromptAssembler → PipelineStage
-
def
transformSchema(schema: StructType, logging: Boolean): StructType
- Attributes
- protected
- Definition Classes
- PipelineStage
- Annotations
- @DeveloperApi()
-
val
uid: String
- Definition Classes
- PromptAssembler → Identifiable
-
final
def
wait(): Unit
- Definition Classes
- AnyRef
- Annotations
- @throws( ... )
-
final
def
wait(arg0: Long, arg1: Int): Unit
- Definition Classes
- AnyRef
- Annotations
- @throws( ... )
-
final
def
wait(arg0: Long): Unit
- Definition Classes
- AnyRef
- Annotations
- @throws( ... ) @native()
-
def
write: MLWriter
- Definition Classes
- DefaultParamsWritable → MLWritable