Packages

class PromptAssembler extends Transformer with DefaultParamsWritable with HasOutputAnnotatorType with HasOutputAnnotationCol

Assembles a sequence of messages into a single string using a template. These strings can then be used as prompts for large language models.

This annotator expects the input column to contain an array of two-tuples (one array of tuples per row). The first element of each tuple is the role and the second element is the text of the message. Possible roles are "system", "user" and "assistant".

An assistant header can be added to the end of the generated string by using setAddAssistant(true).

At the moment, this annotator uses llama.cpp as a backend to parse and apply the templates. llama.cpp uses basic pattern matching to determine the type of the template, then applies a basic version of the template to the messages. This means that more advanced templates are not supported.

For an extended example see the example notebook.

Example

// Batches (whole conversations) of arrays of messages
val data: Seq[Seq[(String, String)]] = Seq(
  Seq(
    ("system", "You are a helpful assistant."),
    ("assistant", "Hello there, how can I help you?"),
    ("user", "I need help with organizing my room.")))

val dataDF = data.toDF("messages")

// llama3.1
val template =
  "{{- bos_token }} {%- if custom_tools is defined %} {%- set tools = custom_tools %} {%- " +
    "endif %} {%- if not tools_in_user_message is defined %} {%- set tools_in_user_message = true %} {%- " +
    "endif %} {%- if not date_string is defined %} {%- set date_string = \"26 Jul 2024\" %} {%- endif %} " +
    "{%- if not tools is defined %} {%- set tools = none %} {%- endif %} {#- This block extracts the " +
    "system message, so we can slot it into the right place. #} {%- if messages[0]['role'] == 'system' %}" +
    " {%- set system_message = messages[0]['content']|trim %} {%- set messages = messages[1:] %} {%- else" +
    " %} {%- set system_message = \"\" %} {%- endif %} {#- System message + builtin tools #} {{- " +
    "\"<|start_header_id|>system<|end_header_id|>\\n\\n\" }} {%- if builtin_tools is defined or tools is " +
    "not none %} {{- \"Environment: ipython\\n\" }} {%- endif %} {%- if builtin_tools is defined %} {{- " +
    "\"Tools: \" + builtin_tools | reject('equalto', 'code_interpreter') | join(\", \") + \"\\n\\n\"}} " +
    "{%- endif %} {{- \"Cutting Knowledge Date: December 2023\\n\" }} {{- \"Today Date: \" + date_string " +
    "+ \"\\n\\n\" }} {%- if tools is not none and not tools_in_user_message %} {{- \"You have access to " +
    "the following functions. To call a function, please respond with JSON for a function call.\" }} {{- " +
    "'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its" +
    " value}.' }} {{- \"Do not use variables.\\n\\n\" }} {%- for t in tools %} {{- t | tojson(indent=4) " +
    "}} {{- \"\\n\\n\" }} {%- endfor %} {%- endif %} {{- system_message }} {{- \"<|eot_id|>\" }} {#- " +
    "Custom tools are passed in a user message with some extra guidance #} {%- if tools_in_user_message " +
    "and not tools is none %} {#- Extract the first user message so we can plug it in here #} {%- if " +
    "messages | length != 0 %} {%- set first_user_message = messages[0]['content']|trim %} {%- set " +
    "messages = messages[1:] %} {%- else %} {{- raise_exception(\"Cannot put tools in the first user " +
    "message when there's no first user message!\") }} {%- endif %} {{- " +
    "'<|start_header_id|>user<|end_header_id|>\\n\\n' -}} {{- \"Given the following functions, please " +
    "respond with a JSON for a function call \" }} {{- \"with its proper arguments that best answers the " +
    "given prompt.\\n\\n\" }} {{- 'Respond in the format {\"name\": function name, \"parameters\": " +
    "dictionary of argument name and its value}.' }} {{- \"Do not use variables.\\n\\n\" }} {%- for t in " +
    "tools %} {{- t | tojson(indent=4) }} {{- \"\\n\\n\" }} {%- endfor %} {{- first_user_message + " +
    "\"<|eot_id|>\"}} {%- endif %} {%- for message in messages %} {%- if not (message.role == 'ipython' " +
    "or message.role == 'tool' or 'tool_calls' in message) %} {{- '<|start_header_id|>' + message['role']" +
    " + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }} {%- elif 'tool_calls' in " +
    "message %} {%- if not message.tool_calls|length == 1 %} {{- raise_exception(\"This model only " +
    "supports single tool-calls at once!\") }} {%- endif %} {%- set tool_call = message.tool_calls[0]" +
    ".function %} {%- if builtin_tools is defined and tool_call.name in builtin_tools %} {{- " +
    "'<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}} {{- \"<|python_tag|>\" + tool_call.name + " +
    "\".call(\" }} {%- for arg_name, arg_val in tool_call.arguments | items %} {{- arg_name + '=\"' + " +
    "arg_val + '\"' }} {%- if not loop.last %} {{- \", \" }} {%- endif %} {%- endfor %} {{- \")\" }} {%- " +
    "else %} {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}} {{- '{\"name\": \"' + " +
    "tool_call.name + '\", ' }} {{- '\"parameters\": ' }} {{- tool_call.arguments | tojson }} {{- \"}\" " +
    "}} {%- endif %} {%- if builtin_tools is defined %} {#- This means we're in ipython mode #} {{- " +
    "\"<|eom_id|>\" }} {%- else %} {{- \"<|eot_id|>\" }} {%- endif %} {%- elif message.role == \"tool\" " +
    "or message.role == \"ipython\" %} {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }} {%- " +
    "if message.content is mapping or message.content is iterable %} {{- message.content | tojson }} {%- " +
    "else %} {{- message.content }} {%- endif %} {{- \"<|eot_id|>\" }} {%- endif %} {%- endfor %} {%- if " +
    "add_generation_prompt %} {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }} {%- endif %} "

val promptAssembler = new PromptAssembler()
  .setInputCol("messages")
  .setOutputCol("prompt")
  .setChatTemplate(template)

promptAssembler.transform(dataDF).select("prompt.result").show(truncate = false)
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|result                                                                                                                                                                                                                                                                                                                      |
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|[<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHello there, how can I help you?<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nI need help with organizing my room.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n]|
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
Linear Supertypes
HasOutputAnnotationCol, HasOutputAnnotatorType, DefaultParamsWritable, MLWritable, Transformer, PipelineStage, Logging, Params, Serializable, Serializable, Identifiable, AnyRef, Any
Ordering
  1. Grouped
  2. Alphabetic
  3. By Inheritance
Inherited
  1. PromptAssembler
  2. HasOutputAnnotationCol
  3. HasOutputAnnotatorType
  4. DefaultParamsWritable
  5. MLWritable
  6. Transformer
  7. PipelineStage
  8. Logging
  9. Params
  10. Serializable
  11. Serializable
  12. Identifiable
  13. AnyRef
  14. Any
  1. Hide All
  2. Show All
Visibility
  1. Public
  2. All

Instance Constructors

  1. new PromptAssembler()
  2. new PromptAssembler(uid: String)

    uid

    Required uid for storing the annotator to disk.

Type Members

  1. type AnnotatorType = String
    Definition Classes
    HasOutputAnnotatorType

Value Members

  1. final def !=(arg0: Any): Boolean
    Definition Classes
    AnyRef → Any
  2. final def ##(): Int
    Definition Classes
    AnyRef → Any
  3. final def $[T](param: Param[T]): T
    Attributes
    protected
    Definition Classes
    Params
  4. final def ==(arg0: Any): Boolean
    Definition Classes
    AnyRef → Any
  5. val addAssistant: BooleanParam
  6. final def asInstanceOf[T0]: T0
    Definition Classes
    Any
  7. val chatTemplate: Param[String]
  8. final def clear(param: Param[_]): PromptAssembler.this.type
    Definition Classes
    Params
  9. def clone(): AnyRef
    Attributes
    protected[lang]
    Definition Classes
    AnyRef
    Annotations
    @throws( ... ) @native()
  10. def copy(extra: ParamMap): Transformer
    Definition Classes
    PromptAssembler → Transformer → PipelineStage → Params
  11. def copyValues[T <: Params](to: T, extra: ParamMap): T
    Attributes
    protected
    Definition Classes
    Params
  12. final def defaultCopy[T <: Params](extra: ParamMap): T
    Attributes
    protected
    Definition Classes
    Params
  13. final def eq(arg0: AnyRef): Boolean
    Definition Classes
    AnyRef
  14. def equals(arg0: Any): Boolean
    Definition Classes
    AnyRef → Any
  15. def explainParam(param: Param[_]): String
    Definition Classes
    Params
  16. def explainParams(): String
    Definition Classes
    Params
  17. final def extractParamMap(): ParamMap
    Definition Classes
    Params
  18. final def extractParamMap(extra: ParamMap): ParamMap
    Definition Classes
    Params
  19. def finalize(): Unit
    Attributes
    protected[lang]
    Definition Classes
    AnyRef
    Annotations
    @throws( classOf[java.lang.Throwable] )
  20. final def get[T](param: Param[T]): Option[T]
    Definition Classes
    Params
  21. def getAddAssistant: Boolean

    Whether to add an assistant header to the end of the generated string.

    Whether to add an assistant header to the end of the generated string.

    returns

    Whether to add the assistant header

  22. def getChatTemplate: String

    Gets the chat template to be used for the chat.

    Gets the chat template to be used for the chat.

    returns

    The template to use

  23. final def getClass(): Class[_]
    Definition Classes
    AnyRef → Any
    Annotations
    @native()
  24. final def getDefault[T](param: Param[T]): Option[T]
    Definition Classes
    Params
  25. def getInputCol: String
  26. final def getOrDefault[T](param: Param[T]): T
    Definition Classes
    Params
  27. final def getOutputCol: String

    Gets the name of the annotation column that will be generated

    Gets the name of the annotation column that will be generated

    Definition Classes
    HasOutputAnnotationCol
  28. def getParam(paramName: String): Param[Any]
    Definition Classes
    Params
  29. final def hasDefault[T](param: Param[T]): Boolean
    Definition Classes
    Params
  30. def hasParam(paramName: String): Boolean
    Definition Classes
    Params
  31. def hashCode(): Int
    Definition Classes
    AnyRef → Any
    Annotations
    @native()
  32. def initializeLogIfNecessary(isInterpreter: Boolean, silent: Boolean): Boolean
    Attributes
    protected
    Definition Classes
    Logging
  33. def initializeLogIfNecessary(isInterpreter: Boolean): Unit
    Attributes
    protected
    Definition Classes
    Logging
  34. val inputCol: Param[String]
  35. final def isDefined(param: Param[_]): Boolean
    Definition Classes
    Params
  36. final def isInstanceOf[T0]: Boolean
    Definition Classes
    Any
  37. final def isSet(param: Param[_]): Boolean
    Definition Classes
    Params
  38. def isTraceEnabled(): Boolean
    Attributes
    protected
    Definition Classes
    Logging
  39. def log: Logger
    Attributes
    protected
    Definition Classes
    Logging
  40. def logDebug(msg: ⇒ String, throwable: Throwable): Unit
    Attributes
    protected
    Definition Classes
    Logging
  41. def logDebug(msg: ⇒ String): Unit
    Attributes
    protected
    Definition Classes
    Logging
  42. def logError(msg: ⇒ String, throwable: Throwable): Unit
    Attributes
    protected
    Definition Classes
    Logging
  43. def logError(msg: ⇒ String): Unit
    Attributes
    protected
    Definition Classes
    Logging
  44. def logInfo(msg: ⇒ String, throwable: Throwable): Unit
    Attributes
    protected
    Definition Classes
    Logging
  45. def logInfo(msg: ⇒ String): Unit
    Attributes
    protected
    Definition Classes
    Logging
  46. def logName: String
    Attributes
    protected
    Definition Classes
    Logging
  47. def logTrace(msg: ⇒ String, throwable: Throwable): Unit
    Attributes
    protected
    Definition Classes
    Logging
  48. def logTrace(msg: ⇒ String): Unit
    Attributes
    protected
    Definition Classes
    Logging
  49. def logWarning(msg: ⇒ String, throwable: Throwable): Unit
    Attributes
    protected
    Definition Classes
    Logging
  50. def logWarning(msg: ⇒ String): Unit
    Attributes
    protected
    Definition Classes
    Logging
  51. final def ne(arg0: AnyRef): Boolean
    Definition Classes
    AnyRef
  52. final def notify(): Unit
    Definition Classes
    AnyRef
    Annotations
    @native()
  53. final def notifyAll(): Unit
    Definition Classes
    AnyRef
    Annotations
    @native()
  54. val outputAnnotatorType: AnnotatorType
  55. final val outputCol: Param[String]
    Attributes
    protected
    Definition Classes
    HasOutputAnnotationCol
  56. lazy val params: Array[Param[_]]
    Definition Classes
    Params
  57. def save(path: String): Unit
    Definition Classes
    MLWritable
    Annotations
    @Since( "1.6.0" ) @throws( ... )
  58. final def set(paramPair: ParamPair[_]): PromptAssembler.this.type
    Attributes
    protected
    Definition Classes
    Params
  59. final def set(param: String, value: Any): PromptAssembler.this.type
    Attributes
    protected
    Definition Classes
    Params
  60. final def set[T](param: Param[T], value: T): PromptAssembler.this.type
    Definition Classes
    Params
  61. def setAddAssistant(value: Boolean): PromptAssembler.this.type

    Whether to add an assistant header to the end of the generated string.

    Whether to add an assistant header to the end of the generated string.

    value

    Whether to add the assistant header

  62. def setChatTemplate(value: String): PromptAssembler.this.type

    Sets the chat template to be used for the chat.

    Sets the chat template to be used for the chat. Should be something that llama.cpp can parse.

    value

    The template to use

  63. final def setDefault(paramPairs: ParamPair[_]*): PromptAssembler.this.type
    Attributes
    protected
    Definition Classes
    Params
  64. final def setDefault[T](param: Param[T], value: T): PromptAssembler.this.type
    Attributes
    protected[org.apache.spark.ml]
    Definition Classes
    Params
  65. def setInputCol(value: String): PromptAssembler.this.type

    Sets the input text column for processing

  66. final def setOutputCol(value: String): PromptAssembler.this.type

    Overrides annotation column name when transforming

    Overrides annotation column name when transforming

    Definition Classes
    HasOutputAnnotationCol
  67. final def synchronized[T0](arg0: ⇒ T0): T0
    Definition Classes
    AnyRef
  68. def toString(): String
    Definition Classes
    Identifiable → AnyRef → Any
  69. def transform(dataset: Dataset[_]): DataFrame
    Definition Classes
    PromptAssembler → Transformer
  70. def transform(dataset: Dataset[_], paramMap: ParamMap): DataFrame
    Definition Classes
    Transformer
    Annotations
    @Since( "2.0.0" )
  71. def transform(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): DataFrame
    Definition Classes
    Transformer
    Annotations
    @Since( "2.0.0" ) @varargs()
  72. final def transformSchema(schema: StructType): StructType

    Adds the result Annotation type to the schema.

    Adds the result Annotation type to the schema.

    Required for pipeline transformation validation. It is called on fit().

    Definition Classes
    PromptAssembler → PipelineStage
  73. def transformSchema(schema: StructType, logging: Boolean): StructType
    Attributes
    protected
    Definition Classes
    PipelineStage
    Annotations
    @DeveloperApi()
  74. val uid: String
    Definition Classes
    PromptAssembler → Identifiable
  75. final def wait(): Unit
    Definition Classes
    AnyRef
    Annotations
    @throws( ... )
  76. final def wait(arg0: Long, arg1: Int): Unit
    Definition Classes
    AnyRef
    Annotations
    @throws( ... )
  77. final def wait(arg0: Long): Unit
    Definition Classes
    AnyRef
    Annotations
    @throws( ... ) @native()
  78. def write: MLWriter
    Definition Classes
    DefaultParamsWritable → MLWritable

Inherited from HasOutputAnnotationCol

Inherited from HasOutputAnnotatorType

Inherited from DefaultParamsWritable

Inherited from MLWritable

Inherited from Transformer

Inherited from PipelineStage

Inherited from Logging

Inherited from Params

Inherited from Serializable

Inherited from Serializable

Inherited from Identifiable

Inherited from AnyRef

Inherited from Any

Members

Parameter setters