diff --git a/ddtrace/llmobs/_llmobs.py b/ddtrace/llmobs/_llmobs.py
index 58b233f00a..fd5d396767 100644
--- a/ddtrace/llmobs/_llmobs.py
+++ b/ddtrace/llmobs/_llmobs.py
@@ -1,7 +1,7 @@
 import json
 import os
 import time
-from typing import Any
+from typing import Any, Tuple
 from typing import Dict
 from typing import List
 from typing import Optional
@@ -66,10 +66,10 @@
 from ddtrace.llmobs._utils import _get_span_name
 from ddtrace.llmobs._utils import _is_evaluation_span
 from ddtrace.llmobs._utils import safe_json
-from ddtrace.llmobs._utils import validate_prompt
 from ddtrace.llmobs._writer import LLMObsEvalMetricWriter
 from ddtrace.llmobs._writer import LLMObsSpanWriter
 from ddtrace.llmobs.utils import Documents
+from ddtrace.llmobs.utils import Prompt
 from ddtrace.llmobs.utils import ExportedLLMObsSpan
 from ddtrace.llmobs.utils import Messages
 from ddtrace.propagation.http import HTTPPropagator
@@ -463,7 +463,7 @@ def _tag_span_links(self, span, span_links):

     @classmethod
     def annotation_context(
-        cls, tags: Optional[Dict[str, Any]] = None, prompt: Optional[dict] = None, name: Optional[str] = None
+        cls, tags: Optional[Dict[str, Any]] = None, prompt: Optional[Prompt] = None, name: Optional[str] = None
     ) -> AnnotationContext:
         """
         Sets specified attributes on all LLMObs spans created while the returned AnnotationContext is active.
@@ -805,11 +805,30 @@ def retrieval(
             log.warning(SPAN_START_WHILE_DISABLED_WARNING)
         return cls._instance._start_span("retrieval", name=name, session_id=session_id, ml_app=ml_app)

+    @classmethod
+    def prompt_context(cls,
+                       name: str,
+                       version: Optional[str] = "1.0.0",
+                       template: Optional[List[Tuple[str, str]]] = None,
+                       variables: Optional[Dict[str, Any]] = None,
+                       example_variable_keys: Optional[List[str]] = None,
+                       constraint_variable_keys: Optional[List[str]] = None,
+                       rag_context_variable_keys: Optional[List[str]] = None,
+                       rag_query_variable_keys: Optional[List[str]] = None,
+                       ml_app: str = "") -> AnnotationContext:
+        """
+        Create a Prompt from the given arguments and return an AnnotationContext that annotates all LLMObs spans created while it is active with that prompt.
+        """
+        # TODO: if a prompt already exists on the span, update it instead of overwriting it.
+        prompt = Prompt(name, version, template, variables, example_variable_keys, constraint_variable_keys,
+                        rag_context_variable_keys, rag_query_variable_keys, ml_app)
+        return cls.annotation_context(prompt=prompt)
+
     @classmethod
     def annotate(
         cls,
         span: Optional[Span] = None,
-        prompt: Optional[dict] = None,
+        prompt: Optional[Prompt] = None,
         input_data: Optional[Any] = None,
         output_data: Optional[Any] = None,
         metadata: Optional[Dict[str, Any]] = None,
@@ -823,15 +842,8 @@ def annotate(

         :param Span span: Span to annotate. If no span is provided, the current active span will be used.
                           Must be an LLMObs-type span, i.e. generated by the LLMObs SDK.
-        :param prompt: A dictionary that represents the prompt used for an LLM call in the following form:
-                        `{"template": "...", "id": "...", "version": "...", "variables": {"variable_1": "...", ...}}`.
-                        Can also be set using the `ddtrace.llmobs.utils.Prompt` constructor class.
-                        - This argument is only applicable to LLM spans.
-                        - The dictionary may contain two optional keys relevant to RAG applications:
-                            `rag_context_variables` - a list of variable key names that contain ground
-                                                        truth context information
-                            `rag_query_variables` - a list of variable key names that contains query
-                                                        information for an LLM call
+        :param prompt: An instance of the `ddtrace.llmobs.utils.Prompt` class that represents the prompt used for an LLM call.
+                        - This argument is only applicable to LLM spans.
         :param input_data: A single input string, dictionary, or a list of dictionaries based on the span kind:
                            - llm spans: accepts a string, or a dictionary of form {"content": "...", "role": "..."},
                              or a list of dictionaries with the same signature.
@@ -883,8 +895,8 @@ def annotate(
             span.name = _name
         if prompt is not None:
             try:
-                validated_prompt = validate_prompt(prompt)
-                cls._set_dict_attribute(span, INPUT_PROMPT, validated_prompt)
+                dict_prompt = prompt.prepare_prompt(ml_app=_get_ml_app(span) or "")
+                cls._set_dict_attribute(span, INPUT_PROMPT, dict_prompt)
             except TypeError:
                 log.warning("Failed to validate prompt with error: ", exc_info=True)
         if not span_kind:
diff --git a/ddtrace/llmobs/_utils.py b/ddtrace/llmobs/_utils.py
index f178582f51..92a54bba54 100644
--- a/ddtrace/llmobs/_utils.py
+++ b/ddtrace/llmobs/_utils.py
@@ -20,59 +20,12 @@
 from ddtrace.llmobs._constants import OPENAI_APM_SPAN_NAME
 from ddtrace.llmobs._constants import SESSION_ID
 from ddtrace.llmobs._constants import VERTEXAI_APM_SPAN_NAME
+from ddtrace.llmobs.utils import Prompt
 from ddtrace.trace import Span


 log = get_logger(__name__)

-
-def validate_prompt(prompt: dict) -> Dict[str, Union[str, dict, List[str]]]:
-    validated_prompt = {}  # type: Dict[str, Union[str, dict, List[str]]]
-    if not isinstance(prompt, dict):
-        raise TypeError("Prompt must be a dictionary")
-    variables = prompt.get("variables")
-    template = prompt.get("template")
-    version = prompt.get("version")
-    prompt_id = prompt.get("id")
-    ctx_variable_keys = prompt.get("rag_context_variables")
-    rag_query_variable_keys = prompt.get("rag_query_variables")
-    if variables is not None:
-        if not isinstance(variables, dict):
-            raise TypeError("Prompt variables must be a dictionary.")
-        if not any(isinstance(k, str) or isinstance(v, str) for k, v in variables.items()):
-            raise TypeError("Prompt variable keys and values must be strings.")
-        validated_prompt["variables"] = variables
-    if template is not None:
-        if not isinstance(template, str):
-            raise TypeError("Prompt template must be a string")
-        validated_prompt["template"] = template
-    if version is not None:
-        if not isinstance(version, str):
-            raise TypeError("Prompt version must be a string.")
-        validated_prompt["version"] = version
-    if prompt_id is not None:
-        if not isinstance(prompt_id, str):
-            raise TypeError("Prompt id must be a string.")
-        validated_prompt["id"] = prompt_id
-    if ctx_variable_keys is not None:
-        if not isinstance(ctx_variable_keys, list):
-            raise TypeError("Prompt field `context_variable_keys` must be a list of strings.")
-        if not all(isinstance(k, str) for k in ctx_variable_keys):
-            raise TypeError("Prompt field `context_variable_keys` must be a list of strings.")
-        validated_prompt[INTERNAL_CONTEXT_VARIABLE_KEYS] = ctx_variable_keys
-    else:
-        validated_prompt[INTERNAL_CONTEXT_VARIABLE_KEYS] = ["context"]
-    if rag_query_variable_keys is not None:
-        if not isinstance(rag_query_variable_keys, list):
-            raise TypeError("Prompt field `rag_query_variables` must be a list of strings.")
-        if not all(isinstance(k, str) for k in rag_query_variable_keys):
-            raise TypeError("Prompt field `rag_query_variables` must be a list of strings.")
-        validated_prompt[INTERNAL_QUERY_VARIABLE_KEYS] = rag_query_variable_keys
-    else:
-        validated_prompt[INTERNAL_QUERY_VARIABLE_KEYS] = ["question"]
-    return validated_prompt
-
-
 class LinkTracker:
     def __init__(self, object_span_links=None):
         self._object_span_links = object_span_links or {}
@@ -185,7 +138,7 @@ def _get_session_id(span: Span) -> Optional[str]:
 def _inject_llmobs_parent_id(span_context):
     """Inject the LLMObs parent ID into the span context for reconnecting distributed LLMObs traces."""
     span = ddtrace.tracer.current_span()
-
+
     if span is None:
         log.warning("No active span to inject LLMObs parent ID info.")
         return
diff --git a/ddtrace/llmobs/experimentation/_experiments.py b/ddtrace/llmobs/experimentation/_experiments.py
index 590d3caae4..b2c5091abf 100644
--- a/ddtrace/llmobs/experimentation/_experiments.py
+++ b/ddtrace/llmobs/experimentation/_experiments.py
@@ -13,6 +13,7 @@
 from .._utils import HTTPResponse
 from .._utils import http_request
+from ..utils import Prompt
 from ..decorators import agent
 from .._llmobs import LLMObs


@@ -624,6 +625,7 @@ class Experiment:
         name (str): Name of the experiment
         task (Callable): Function that processes each dataset record
         dataset (Dataset): Dataset to run the experiment on
+        prompt (Prompt): Prompt template for the experiment
         evaluators (List[Callable]): Functions that evaluate task outputs
         tags (List[str]): Tags for organizing experiments
         description (str): Description of the experiment
@@ -640,6 +642,7 @@ def __init__(
         name: str,
         task: Callable,
         dataset: Dataset,
+        prompt: Prompt,
         evaluators: List[Callable],
         tags: List[str] = [],
         description: str = "",
@@ -649,6 +652,7 @@
         self.name = name
         self.task = task
         self.dataset = dataset
+        self.prompt = prompt
         self.evaluators = evaluators
         self.tags = tags
         self.project_name = ENV_PROJECT_NAME
@@ -997,6 +1001,7 @@ def process_row(idx_row):

                     LLMObs.annotate(
                         span,
+                        prompt=self.prompt,
                         input_data=input_data,
                         output_data=output,
                         tags={
@@ -1033,6 +1038,7 @@ def process_row(idx_row):
                     LLMObs.annotate(
                         span,
                         input_data=input_data,
+                        prompt=self.prompt,
                         tags={
                             "dataset_id": self.dataset._datadog_dataset_id,
                             "dataset_record_id": row["record_id"],
diff --git a/ddtrace/llmobs/utils.py b/ddtrace/llmobs/utils.py
index dac1f3149c..231c9cb752 100644
--- a/ddtrace/llmobs/utils.py
+++ b/ddtrace/llmobs/utils.py
@@ -1,7 +1,14 @@
+from re import match
+from hashlib import sha1
+from typing import Any
 from typing import Dict
 from typing import List
+from typing import Optional
+from typing import Set
+from typing import Tuple
 from typing import Union
-
+from ddtrace.llmobs._constants import INTERNAL_CONTEXT_VARIABLE_KEYS
+from ddtrace.llmobs._constants import INTERNAL_QUERY_VARIABLE_KEYS

 # TypedDict was added to typing in python 3.8
 try:
@@ -19,20 +26,228 @@
 ExportedLLMObsSpan = TypedDict("ExportedLLMObsSpan", {"span_id": str, "trace_id": str})
 Document = TypedDict("Document", {"name": str, "id": str, "text": str, "score": float}, total=False)
 Message = TypedDict("Message", {"content": str, "role": str}, total=False)
-Prompt = TypedDict(
-    "Prompt",
-    {
-        "variables": Dict[str, str],
-        "template": str,
-        "id": str,
-        "version": str,
-        "rag_context_variables": List[
-            str
-        ],  # a list of variable key names that contain ground truth context information
-        "rag_query_variables": List[str],  # a list of variable key names that contains query information
-    },
-    total=False,
-)
+
+class Prompt:
+    """
+    Represents a prompt used for an LLM call.
+
+    Attributes:
+        name (str): The name of the prompt.
+        ml_app (str): The name of the ML application; when not specified, it is retrieved from the active span.
+        version (str): The version of the prompt.
+        prompt_template_id (str): A sha-1 hash of the name and ml_app, used to identify the prompt template.
+        prompt_instance_id (str): A sha-1 hash of all prompt attributes, used to identify the prompt instance.
+        template (Union[List[Tuple[str, str]], str]): The template used for the prompt, either a list of (role, content) tuples or a plain string.
+        variables (Dict[str, str]): A dictionary of variables used in the prompt.
+        example_variable_keys (List[str]): A list of variable names denoting examples. Examples are used to improve the accuracy of the prompt.
+        constraint_variable_keys (List[str]): A list of variable names denoting constraints. Constraints limit how the prompt result is displayed.
+        rag_context_variable_keys (List[str]): A list of variable key names that contain ground truth context information.
+        rag_query_variable_keys (List[str]): A list of variable key names that contain query information for an LLM call.
+    """
+    name: str
+    version: Optional[str]
+    prompt_template_id: str
+    prompt_instance_id: str
+    template: Optional[List[Tuple[str, str]]]
+    variables: Optional[Dict[str, Any]]
+    example_variable_keys: Optional[List[str]]
+    constraint_variable_keys: Optional[List[str]]
+    rag_context_variable_keys: Optional[List[str]]
+    rag_query_variable_keys: Optional[List[str]]
+    ml_app: str
+
+    def __init__(self,
+                 name,
+                 version="1.0.0",
+                 template=None,
+                 variables=None,
+                 example_variable_keys=None,
+                 constraint_variable_keys=None,
+                 rag_context_variable_keys=None,
+                 rag_query_variable_keys=None,
+                 ml_app=""):
+
+        self.__dict__["_is_initialized"] = False
+
+        if name is None:
+            raise TypeError("Prompt name is mandatory and must be a string.")
+
+        self.__dict__["name"] = name
+
+        # Default values
+        template = template or []
+        variables = variables or {}
+        example_variable_keys = example_variable_keys or ["example", "examples"]
+        constraint_variable_keys = constraint_variable_keys or ["constraint", "constraints"]
+        rag_context_variable_keys = rag_context_variable_keys or ["context"]
+        rag_query_variable_keys = rag_query_variable_keys or ["question"]
+        version = version or "1.0.0"
+
+        if version is not None:
+            # Add minor and patch version if not present
+            version_parts = (version.split(".") + ["0", "0"])[:3]
+            version = ".".join(version_parts)
+
+        # Accept simple string templates as user role
+        if isinstance(template, str):
+            template = [("user", template)]
+
+        self.__dict__["ml_app"] = ml_app
+        self.__dict__["version"] = version
+        self.__dict__["template"] = template
+        self.__dict__["variables"] = variables
+        self.__dict__["example_variable_keys"] = example_variable_keys
+        self.__dict__["constraint_variable_keys"] = constraint_variable_keys
+        self.__dict__["rag_context_variable_keys"] = rag_context_variable_keys
+        self.__dict__["rag_query_variable_keys"] = rag_query_variable_keys
+
+        # Compute the ids now so prompt_template_id/prompt_instance_id exist before validate() or to_tags_dict() run
+        self.generate_ids()
+
+        # Unlock id regeneration on subsequent attribute assignments
+        self.__dict__["_is_initialized"] = True
+
+    def generate_ids(self):
+        """
+        Generates prompt_template_id and prompt_instance_id based on the prompt attributes.
+        The prompt_template_id is a sha-1 hash of the prompt name and ml_app.
+        The prompt_instance_id is a sha-1 hash of all prompt attributes.
+ """ + name = str(self.name) + ml_app = str(self.ml_app) + version = str(self.version) + template = str(self.template) + variables = str(self.variables) + example_variable_keys = str(self.example_variable_keys) + constraint_variable_keys = str(self.constraint_variable_keys) + rag_context_variable_keys = str(self.rag_context_variable_keys) + rag_query_variable_keys = str(self.rag_query_variable_keys) + + template_id_str = f"[{ml_app}]{name}" + instance_id_str = f"[{ml_app}]{name}{version}{template}{variables}{example_variable_keys}{constraint_variable_keys}{rag_context_variable_keys}{rag_query_variable_keys}" + + self.__dict__["prompt_template_id"] = sha1(template_id_str.encode()).hexdigest() + self.__dict__["prompt_instance_id"] = sha1(instance_id_str.encode()).hexdigest() + + def validate(self): + errors = [] + prompt_template_id = self.prompt_template_id + prompt_instance_id = self.prompt_instance_id + name = self.name + version = self.version + template = self.template + variables = self.variables + example_variable_keys = self.example_variable_keys + constraint_variable_keys = self.constraint_variable_keys + rag_context_variable_keys = self.rag_context_variable_keys + rag_query_variable_keys = self.rag_query_variable_keys + + + if prompt_template_id is None: + self.generate_ids() + elif not isinstance(prompt_template_id, str): + errors.append("Prompt template id must be a string.") + if prompt_instance_id is None: + self.generate_ids() + elif not isinstance(prompt_instance_id, str): + errors.append("Prompt instance id must be a string.") + + if name is None: + errors.append("Prompt name of type String is mandatory.") + elif not isinstance(name, str): + errors.append("Prompt name must be a string.") + + if version is not None: + # Add minor and patch version if not present + version_parts = (version.split(".") + ["0", "0"])[:3] + version = ".".join(version_parts) + # Official semver regex from https://semver.org/ + semver_regex = ( + r'^(?P0|[1-9]\d*)\.' + r'(?P0|[1-9]\d*)\.' + r'(?P0|[1-9]\d*)' + r'(?:-(?P(?:0|[1-9]\d*|\d*[a-zA-Z-]' + r'[0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-]' + r'[0-9a-zA-Z-]*))*))?' + r'(?:\+(?P[0-9a-zA-Z-]+' + r'(?:\.[0-9a-zA-Z-]+)*))?$' + ) + if not bool(match(semver_regex, version)): + errors.append( + "Prompt version must be semver compatible. 
+                    "Please check https://semver.org/ for more information.")
+
+        # Accept simple string templates
+        if isinstance(template, str):
+            template = [("user", template)]
+
+        # Validate the template
+        if not (isinstance(template, list) and all(isinstance(t, tuple) for t in template)):
+            errors.append("Prompt template must be a list of tuples.")
+        if not all(len(t) == 2 for t in template):
+            errors.append("Prompt template tuples must have exactly two elements.")
+        if not all(isinstance(item[0], str) and isinstance(item[1], str) for item in template):
+            errors.append("Prompt template tuple elements must be strings.")
+
+        if not isinstance(variables, dict):
+            errors.append("Prompt variables must be a dictionary.")
+        if not all(isinstance(k, str) for k in variables):
+            errors.append("Prompt variable keys must be strings.")
+
+        for var_list in [example_variable_keys, constraint_variable_keys,
+                         rag_context_variable_keys, rag_query_variable_keys]:
+            if not all(isinstance(var, str) for var in var_list):
+                errors.append("All variable lists must contain strings only.")
+
+        if errors:
+            raise TypeError("\n".join(errors))
+
+        return errors
+
+    def to_tags_dict(self) -> Dict[str, Union[str, Set[str], Dict[str, str], List[Tuple[str, str]]]]:
+        name = self.name
+        version = self.version
+        prompt_template_id = self.prompt_template_id
+        prompt_instance_id = self.prompt_instance_id
+        template = self.template
+        variables = self.variables
+        example_variable_keys = self.example_variable_keys
+        constraint_variable_keys = self.constraint_variable_keys
+        rag_context_variable_keys = self.rag_context_variable_keys
+        rag_query_variable_keys = self.rag_query_variable_keys
+
+        # Keep only the keys (including the default keys) that are actually present in variables.
+        example_variable_keys_set = {key for key in example_variable_keys if key in variables}
+        constraint_variable_keys_set = {key for key in constraint_variable_keys if key in variables}
+        rag_context_variable_keys_set = {key for key in rag_context_variable_keys if key in variables}
+        rag_query_variable_keys_set = {key for key in rag_query_variable_keys if key in variables}
+
+        return {
+            "name": name,
+            "version": version,
+            "prompt_template_id": prompt_template_id,
+            "prompt_instance_id": prompt_instance_id,
+            "template": template,
+            "variables": variables,
+            "example_variable_keys": example_variable_keys_set,
+            "constraint_variable_keys": constraint_variable_keys_set,
+            # Use the internal constant keys so the existing hallucination detection functionality keeps working.
+            INTERNAL_CONTEXT_VARIABLE_KEYS: rag_context_variable_keys_set,
+            INTERNAL_QUERY_VARIABLE_KEYS: rag_query_variable_keys_set,
+        }
+
+    def prepare_prompt(self, ml_app=None) -> Dict[str, Union[str, List[str], Dict[str, str], List[Tuple[str, str]]]]:
+        if ml_app:
+            # Regenerate the ids if ml_app has changed
+            self.__dict__["ml_app"] = ml_app
+            self.generate_ids()
+        self.validate()
+        return self.to_tags_dict()
+
+    def __setattr__(self, name, value):
+        """
+        Override attribute assignment so that the prompt ids are regenerated whenever an attribute changes.
+        """
+        super().__setattr__(name, value)
+        if self.__dict__.get("_is_initialized"):
+            self.generate_ids()
+
+
 class Messages:
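
Usage sketch (not part of the patch): a minimal example of how the new Prompt class and the prompt_context() helper introduced by this diff might be used once it lands. The LLMObs.enable() settings, the "shopping-assistant" ml_app, the model name, and the prompt contents below are illustrative assumptions, not part of the change.

    from ddtrace.llmobs import LLMObs
    from ddtrace.llmobs.utils import Prompt

    # Hypothetical setup; any ml_app / agentless configuration works here.
    LLMObs.enable(ml_app="shopping-assistant", agentless_enabled=True)

    # Build a Prompt; version "1.2" is normalized to "1.2.0", and the two ids are derived
    # from name/ml_app (template id) and from all attributes (instance id).
    prompt = Prompt(
        name="summarize-ticket",
        version="1.2",
        template=[("system", "You are a support assistant."),
                  ("user", "Summarize the ticket: {ticket}")],
        variables={"ticket": "Customer cannot reset their password.",
                   "question": "What is the customer's issue?"},
        rag_query_variable_keys=["question"],  # default is ["question"]; shown explicitly
    )

    # Option 1: attach the prompt to a single LLM span via annotate().
    with LLMObs.llm(model_name="example-model", name="summarize") as span:
        LLMObs.annotate(span, prompt=prompt, input_data="Summarize the ticket: ...", output_data="...")

    # Option 2: apply an equivalent prompt to every LLMObs span started inside the context.
    with LLMObs.prompt_context(name="summarize-ticket", version="1.2",
                               template="Summarize the ticket: {ticket}",
                               variables={"ticket": "..."}):
        ...  # run instrumented LLM calls here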