class PromptAssembler extends Transformer with DefaultParamsWritable with HasOutputAnnotatorType with HasOutputAnnotationCol
Assembles a sequence of messages into a single string using a template. These strings can then be used as prompts for large language models.
This annotator expects an array of two-tuples as the type of the input column (one array of tuples per row). The first element of each tuple should be the role and the second element should be the text of the message. Possible roles are "system", "user" and "assistant".
An assistant header can be added to the end of the generated string by using
setAddAssistant(true).
At the moment, this annotator uses llama.cpp as a backend to parse and apply the templates. llama.cpp uses basic pattern matching to determine the type of the template, then applies a basic version of the template to the messages. This means that more advanced templates are not supported.
For an extended example see the example notebook.
Example
// Batches (whole conversations) of arrays of messages val data: Seq[Seq[(String, String)]] = Seq( Seq( ("system", "You are a helpful assistant."), ("assistant", "Hello there, how can I help you?"), ("user", "I need help with organizing my room."))) val dataDF = data.toDF("messages") // llama3.1 val template = "{{- bos_token }} {%- if custom_tools is defined %} {%- set tools = custom_tools %} {%- " + "endif %} {%- if not tools_in_user_message is defined %} {%- set tools_in_user_message = true %} {%- " + "endif %} {%- if not date_string is defined %} {%- set date_string = \"26 Jul 2024\" %} {%- endif %} " + "{%- if not tools is defined %} {%- set tools = none %} {%- endif %} {#- This block extracts the " + "system message, so we can slot it into the right place. #} {%- if messages[0]['role'] == 'system' %}" + " {%- set system_message = messages[0]['content']|trim %} {%- set messages = messages[1:] %} {%- else" + " %} {%- set system_message = \"\" %} {%- endif %} {#- System message + builtin tools #} {{- " + "\"<|start_header_id|>system<|end_header_id|>\\n\\n\" }} {%- if builtin_tools is defined or tools is " + "not none %} {{- \"Environment: ipython\\n\" }} {%- endif %} {%- if builtin_tools is defined %} {{- " + "\"Tools: \" + builtin_tools | reject('equalto', 'code_interpreter') | join(\", \") + \"\\n\\n\"}} " + "{%- endif %} {{- \"Cutting Knowledge Date: December 2023\\n\" }} {{- \"Today Date: \" + date_string " + "+ \"\\n\\n\" }} {%- if tools is not none and not tools_in_user_message %} {{- \"You have access to " + "the following functions. To call a function, please respond with JSON for a function call.\" }} {{- " + "'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its" + " value}.' 
}} {{- \"Do not use variables.\\n\\n\" }} {%- for t in tools %} {{- t | tojson(indent=4) " + "}} {{- \"\\n\\n\" }} {%- endfor %} {%- endif %} {{- system_message }} {{- \"<|eot_id|>\" }} {#- " + "Custom tools are passed in a user message with some extra guidance #} {%- if tools_in_user_message " + "and not tools is none %} {#- Extract the first user message so we can plug it in here #} {%- if " + "messages | length != 0 %} {%- set first_user_message = messages[0]['content']|trim %} {%- set " + "messages = messages[1:] %} {%- else %} {{- raise_exception(\"Cannot put tools in the first user " + "message when there's no first user message!\") }} {%- endif %} {{- " + "'<|start_header_id|>user<|end_header_id|>\\n\\n' -}} {{- \"Given the following functions, please " + "respond with a JSON for a function call \" }} {{- \"with its proper arguments that best answers the " + "given prompt.\\n\\n\" }} {{- 'Respond in the format {\"name\": function name, \"parameters\": " + "dictionary of argument name and its value}.' 
}} {{- \"Do not use variables.\\n\\n\" }} {%- for t in " + "tools %} {{- t | tojson(indent=4) }} {{- \"\\n\\n\" }} {%- endfor %} {{- first_user_message + " + "\"<|eot_id|>\"}} {%- endif %} {%- for message in messages %} {%- if not (message.role == 'ipython' " + "or message.role == 'tool' or 'tool_calls' in message) %} {{- '<|start_header_id|>' + message['role']" + " + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }} {%- elif 'tool_calls' in " + "message %} {%- if not message.tool_calls|length == 1 %} {{- raise_exception(\"This model only " + "supports single tool-calls at once!\") }} {%- endif %} {%- set tool_call = message.tool_calls[0]" + ".function %} {%- if builtin_tools is defined and tool_call.name in builtin_tools %} {{- " + "'<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}} {{- \"<|python_tag|>\" + tool_call.name + " + "\".call(\" }} {%- for arg_name, arg_val in tool_call.arguments | items %} {{- arg_name + '=\"' + " + "arg_val + '\"' }} {%- if not loop.last %} {{- \", \" }} {%- endif %} {%- endfor %} {{- \")\" }} {%- " + "else %} {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}} {{- '{\"name\": \"' + " + "tool_call.name + '\", ' }} {{- '\"parameters\": ' }} {{- tool_call.arguments | tojson }} {{- \"}\" " + "}} {%- endif %} {%- if builtin_tools is defined %} {#- This means we're in ipython mode #} {{- " + "\"<|eom_id|>\" }} {%- else %} {{- \"<|eot_id|>\" }} {%- endif %} {%- elif message.role == \"tool\" " + "or message.role == \"ipython\" %} {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }} {%- " + "if message.content is mapping or message.content is iterable %} {{- message.content | tojson }} {%- " + "else %} {{- message.content }} {%- endif %} {{- \"<|eot_id|>\" }} {%- endif %} {%- endfor %} {%- if " + "add_generation_prompt %} {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }} {%- endif %} " val promptAssembler = new PromptAssembler() .setInputCol("messages") .setOutputCol("prompt") 
.setChatTemplate(template) promptAssembler.transform(dataDF).select("prompt.result").show(truncate = false) +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ |result | +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ |[<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHello there, how can I help you?<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nI need help with organizing my room.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n]| +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
- Grouped
- Alphabetic
- By Inheritance
- PromptAssembler
- HasOutputAnnotationCol
- HasOutputAnnotatorType
- DefaultParamsWritable
- MLWritable
- Transformer
- PipelineStage
- Logging
- Params
- Serializable
- Serializable
- Identifiable
- AnyRef
- Any
- Hide All
- Show All
- Public
- All
Instance Constructors
Type Members
- 
      
      
      
        
      
    
      
        
        type
      
      
        AnnotatorType = String
      
      
      - Definition Classes
- HasOutputAnnotatorType
 
Value Members
- 
      
      
      
        
      
    
      
        final 
        def
      
      
        !=(arg0: Any): Boolean
      
      
      - Definition Classes
- AnyRef → Any
 
- 
      
      
      
        
      
    
      
        final 
        def
      
      
        ##(): Int
      
      
      - Definition Classes
- AnyRef → Any
 
- 
      
      
      
        
      
    
      
        final 
        def
      
      
        $[T](param: Param[T]): T
      
      
      - Attributes
- protected
- Definition Classes
- Params
 
- 
      
      
      
        
      
    
      
        final 
        def
      
      
        ==(arg0: Any): Boolean
      
      
      - Definition Classes
- AnyRef → Any
 
-  val addAssistant: BooleanParam
- 
      
      
      
        
      
    
      
        final 
        def
      
      
        asInstanceOf[T0]: T0
      
      
      - Definition Classes
- Any
 
-  val chatTemplate: Param[String]
- 
      
      
      
        
      
    
      
        final 
        def
      
      
        clear(param: Param[_]): PromptAssembler.this.type
      
      
      - Definition Classes
- Params
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        clone(): AnyRef
      
      
      - Attributes
- protected[lang]
- Definition Classes
- AnyRef
- Annotations
- @throws( ... ) @native()
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        copy(extra: ParamMap): Transformer
      
      
      - Definition Classes
- PromptAssembler → Transformer → PipelineStage → Params
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        copyValues[T <: Params](to: T, extra: ParamMap): T
      
      
      - Attributes
- protected
- Definition Classes
- Params
 
- 
      
      
      
        
      
    
      
        final 
        def
      
      
        defaultCopy[T <: Params](extra: ParamMap): T
      
      
      - Attributes
- protected
- Definition Classes
- Params
 
- 
      
      
      
        
      
    
      
        final 
        def
      
      
        eq(arg0: AnyRef): Boolean
      
      
      - Definition Classes
- AnyRef
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        equals(arg0: Any): Boolean
      
      
      - Definition Classes
- AnyRef → Any
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        explainParam(param: Param[_]): String
      
      
      - Definition Classes
- Params
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        explainParams(): String
      
      
      - Definition Classes
- Params
 
- 
      
      
      
        
      
    
      
        final 
        def
      
      
        extractParamMap(): ParamMap
      
      
      - Definition Classes
- Params
 
- 
      
      
      
        
      
    
      
        final 
        def
      
      
        extractParamMap(extra: ParamMap): ParamMap
      
      
      - Definition Classes
- Params
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        finalize(): Unit
      
      
      - Attributes
- protected[lang]
- Definition Classes
- AnyRef
- Annotations
- @throws( classOf[java.lang.Throwable] )
 
- 
      
      
      
        
      
    
      
        final 
        def
      
      
        get[T](param: Param[T]): Option[T]
      
      
      - Definition Classes
- Params
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        getAddAssistant: Boolean
      
      
      Whether to add an assistant header to the end of the generated string. - returns
- Whether to add the assistant header 
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        getChatTemplate: String
      
      
      Gets the chat template to be used for the chat. - returns
- The template to use 
 
- 
      
      
      
        
      
    
      
        final 
        def
      
      
        getClass(): Class[_]
      
      
      - Definition Classes
- AnyRef → Any
- Annotations
- @native()
 
- 
      
      
      
        
      
    
      
        final 
        def
      
      
        getDefault[T](param: Param[T]): Option[T]
      
      
      - Definition Classes
- Params
 
-  def getInputCol: String
- 
      
      
      
        
      
    
      
        final 
        def
      
      
        getOrDefault[T](param: Param[T]): T
      
      
      - Definition Classes
- Params
 
- 
      
      
      
        
      
    
      
        final 
        def
      
      
        getOutputCol: String
      
      
      Gets the name of the annotation column that will be generated. - Definition Classes
- HasOutputAnnotationCol
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        getParam(paramName: String): Param[Any]
      
      
      - Definition Classes
- Params
 
- 
      
      
      
        
      
    
      
        final 
        def
      
      
        hasDefault[T](param: Param[T]): Boolean
      
      
      - Definition Classes
- Params
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        hasParam(paramName: String): Boolean
      
      
      - Definition Classes
- Params
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        hashCode(): Int
      
      
      - Definition Classes
- AnyRef → Any
- Annotations
- @native()
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        initializeLogIfNecessary(isInterpreter: Boolean, silent: Boolean): Boolean
      
      
      - Attributes
- protected
- Definition Classes
- Logging
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        initializeLogIfNecessary(isInterpreter: Boolean): Unit
      
      
      - Attributes
- protected
- Definition Classes
- Logging
 
-  val inputCol: Param[String]
- 
      
      
      
        
      
    
      
        final 
        def
      
      
        isDefined(param: Param[_]): Boolean
      
      
      - Definition Classes
- Params
 
- 
      
      
      
        
      
    
      
        final 
        def
      
      
        isInstanceOf[T0]: Boolean
      
      
      - Definition Classes
- Any
 
- 
      
      
      
        
      
    
      
        final 
        def
      
      
        isSet(param: Param[_]): Boolean
      
      
      - Definition Classes
- Params
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        isTraceEnabled(): Boolean
      
      
      - Attributes
- protected
- Definition Classes
- Logging
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        log: Logger
      
      
      - Attributes
- protected
- Definition Classes
- Logging
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        logDebug(msg: ⇒ String, throwable: Throwable): Unit
      
      
      - Attributes
- protected
- Definition Classes
- Logging
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        logDebug(msg: ⇒ String): Unit
      
      
      - Attributes
- protected
- Definition Classes
- Logging
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        logError(msg: ⇒ String, throwable: Throwable): Unit
      
      
      - Attributes
- protected
- Definition Classes
- Logging
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        logError(msg: ⇒ String): Unit
      
      
      - Attributes
- protected
- Definition Classes
- Logging
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        logInfo(msg: ⇒ String, throwable: Throwable): Unit
      
      
      - Attributes
- protected
- Definition Classes
- Logging
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        logInfo(msg: ⇒ String): Unit
      
      
      - Attributes
- protected
- Definition Classes
- Logging
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        logName: String
      
      
      - Attributes
- protected
- Definition Classes
- Logging
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        logTrace(msg: ⇒ String, throwable: Throwable): Unit
      
      
      - Attributes
- protected
- Definition Classes
- Logging
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        logTrace(msg: ⇒ String): Unit
      
      
      - Attributes
- protected
- Definition Classes
- Logging
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        logWarning(msg: ⇒ String, throwable: Throwable): Unit
      
      
      - Attributes
- protected
- Definition Classes
- Logging
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        logWarning(msg: ⇒ String): Unit
      
      
      - Attributes
- protected
- Definition Classes
- Logging
 
- 
      
      
      
        
      
    
      
        final 
        def
      
      
        ne(arg0: AnyRef): Boolean
      
      
      - Definition Classes
- AnyRef
 
- 
      
      
      
        
      
    
      
        final 
        def
      
      
        notify(): Unit
      
      
      - Definition Classes
- AnyRef
- Annotations
- @native()
 
- 
      
      
      
        
      
    
      
        final 
        def
      
      
        notifyAll(): Unit
      
      
      - Definition Classes
- AnyRef
- Annotations
- @native()
 
- 
      
      
      
        
      
    
      
        
        val
      
      
        outputAnnotatorType: AnnotatorType
      
      
      - Definition Classes
- PromptAssembler → HasOutputAnnotatorType
 
- 
      
      
      
        
      
    
      
        final 
        val
      
      
        outputCol: Param[String]
      
      
      - Attributes
- protected
- Definition Classes
- HasOutputAnnotationCol
 
- 
      
      
      
        
      
    
      
        
        lazy val
      
      
        params: Array[Param[_]]
      
      
      - Definition Classes
- Params
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        save(path: String): Unit
      
      
      - Definition Classes
- MLWritable
- Annotations
- @Since( "1.6.0" ) @throws( ... )
 
- 
      
      
      
        
      
    
      
        final 
        def
      
      
        set(paramPair: ParamPair[_]): PromptAssembler.this.type
      
      
      - Attributes
- protected
- Definition Classes
- Params
 
- 
      
      
      
        
      
    
      
        final 
        def
      
      
        set(param: String, value: Any): PromptAssembler.this.type
      
      
      - Attributes
- protected
- Definition Classes
- Params
 
- 
      
      
      
        
      
    
      
        final 
        def
      
      
        set[T](param: Param[T], value: T): PromptAssembler.this.type
      
      
      - Definition Classes
- Params
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        setAddAssistant(value: Boolean): PromptAssembler.this.type
      
      
      Whether to add an assistant header to the end of the generated string. - value
- Whether to add the assistant header 
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        setChatTemplate(value: String): PromptAssembler.this.type
      
      
      Sets the chat template to be used for the chat. Should be something that llama.cpp can parse. - value
- The template to use 
 
- 
      
      
      
        
      
    
      
        final 
        def
      
      
        setDefault(paramPairs: ParamPair[_]*): PromptAssembler.this.type
      
      
      - Attributes
- protected
- Definition Classes
- Params
 
- 
      
      
      
        
      
    
      
        final 
        def
      
      
        setDefault[T](param: Param[T], value: T): PromptAssembler.this.type
      
      
      - Attributes
- protected[org.apache.spark.ml]
- Definition Classes
- Params
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        setInputCol(value: String): PromptAssembler.this.type
      
      
      Sets the input text column for processing 
- 
      
      
      
        
      
    
      
        final 
        def
      
      
        setOutputCol(value: String): PromptAssembler.this.type
      
      
      Overrides the annotation column name used when transforming. - Definition Classes
- HasOutputAnnotationCol
 
- 
      
      
      
        
      
    
      
        final 
        def
      
      
        synchronized[T0](arg0: ⇒ T0): T0
      
      
      - Definition Classes
- AnyRef
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        toString(): String
      
      
      - Definition Classes
- Identifiable → AnyRef → Any
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        transform(dataset: Dataset[_]): DataFrame
      
      
      - Definition Classes
- PromptAssembler → Transformer
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        transform(dataset: Dataset[_], paramMap: ParamMap): DataFrame
      
      
      - Definition Classes
- Transformer
- Annotations
- @Since( "2.0.0" )
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        transform(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): DataFrame
      
      
      - Definition Classes
- Transformer
- Annotations
- @Since( "2.0.0" ) @varargs()
 
- 
      
      
      
        
      
    
      
        final 
        def
      
      
        transformSchema(schema: StructType): StructType
      
      
      Adds the result Annotation type to the schema. Required for pipeline transformation validation; it is called on fit(). - Definition Classes
- PromptAssembler → PipelineStage
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        transformSchema(schema: StructType, logging: Boolean): StructType
      
      
      - Attributes
- protected
- Definition Classes
- PipelineStage
- Annotations
- @DeveloperApi()
 
- 
      
      
      
        
      
    
      
        
        val
      
      
        uid: String
      
      
      - Definition Classes
- PromptAssembler → Identifiable
 
- 
      
      
      
        
      
    
      
        final 
        def
      
      
        wait(): Unit
      
      
      - Definition Classes
- AnyRef
- Annotations
- @throws( ... )
 
- 
      
      
      
        
      
    
      
        final 
        def
      
      
        wait(arg0: Long, arg1: Int): Unit
      
      
      - Definition Classes
- AnyRef
- Annotations
- @throws( ... )
 
- 
      
      
      
        
      
    
      
        final 
        def
      
      
        wait(arg0: Long): Unit
      
      
      - Definition Classes
- AnyRef
- Annotations
- @throws( ... ) @native()
 
- 
      
      
      
        
      
    
      
        
        def
      
      
        write: MLWriter
      
      
      - Definition Classes
- DefaultParamsWritable → MLWritable