diff --git a/py/autoevals/__init__.py b/py/autoevals/__init__.py
index 64e3a5c..1d61f3e 100644
--- a/py/autoevals/__init__.py
+++ b/py/autoevals/__init__.py
@@ -120,8 +120,6 @@ async def evaluate_qa():
 See individual module documentation for detailed usage and options.
 """
 
-from braintrust_core.score import Score, Scorer
-
 from .json import *
 from .list import *
 from .llm import *
@@ -129,5 +127,6 @@ async def evaluate_qa():
 from .number import *
 from .oai import init
 from .ragas import *
+from .score import Score, Scorer
 from .string import *
 from .value import ExactMatch
diff --git a/py/autoevals/json.py b/py/autoevals/json.py
index 6e61b98..23147d0 100644
--- a/py/autoevals/json.py
+++ b/py/autoevals/json.py
@@ -15,12 +15,12 @@
 import json
 
-from braintrust_core.score import Score, Scorer
 from jsonschema import ValidationError, validate
 
 from autoevals.partial import ScorerWithPartial
 
 from .number import NumericDiff
+from .score import Score, Scorer
 from .string import Levenshtein
diff --git a/py/autoevals/list.py b/py/autoevals/list.py
index 42f4ff3..6f5d2aa 100644
--- a/py/autoevals/list.py
+++ b/py/autoevals/list.py
@@ -1,9 +1,8 @@
 import sys
 
-from braintrust_core.score import Score
-
 from autoevals.partial import ScorerWithPartial
 
+from .score import Score
 from .string import Levenshtein
diff --git a/py/autoevals/llm.py b/py/autoevals/llm.py
index ad083cf..7759f09 100644
--- a/py/autoevals/llm.py
+++ b/py/autoevals/llm.py
@@ -53,11 +53,11 @@
 import chevron
 import yaml
 
-from braintrust_core.score import Score
 
 from autoevals.partial import ScorerWithPartial
 
 from .oai import Client, arun_cached_request, run_cached_request
+from .score import Score
 
 # Disable HTML escaping in chevron.
 chevron.renderer._html_escape = lambda x: x  # type: ignore[attr-defined]
diff --git a/py/autoevals/moderation.py b/py/autoevals/moderation.py
index 51e8864..76090ba 100644
--- a/py/autoevals/moderation.py
+++ b/py/autoevals/moderation.py
@@ -1,10 +1,9 @@
 from typing import Optional
 
-from braintrust_core.score import Score
-
 from autoevals.llm import OpenAIScorer
 
 from .oai import Client, arun_cached_request, run_cached_request
+from .score import Score
 
 REQUEST_TYPE = "moderation"
diff --git a/py/autoevals/number.py b/py/autoevals/number.py
index 610b627..408e8bc 100644
--- a/py/autoevals/number.py
+++ b/py/autoevals/number.py
@@ -11,10 +11,10 @@
 - Suitable for both small and large number comparisons
 """
 
-from braintrust_core.score import Score
-
 from autoevals.partial import ScorerWithPartial
 
+from .score import Score
+
 
 class NumericDiff(ScorerWithPartial):
     """Numeric similarity scorer using normalized difference.
diff --git a/py/autoevals/partial.py b/py/autoevals/partial.py
index a3fa8a5..290b6be 100644
--- a/py/autoevals/partial.py
+++ b/py/autoevals/partial.py
@@ -1,4 +1,4 @@
-from braintrust_core.score import Scorer
+from .score import Scorer
 
 
 class ScorerWithPartial(Scorer):
diff --git a/py/autoevals/score.py b/py/autoevals/score.py
new file mode 100644
index 0000000..e0acb38
--- /dev/null
+++ b/py/autoevals/score.py
@@ -0,0 +1,64 @@
+import dataclasses
+import sys
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Optional
+
+from .serializable_data_class import SerializableDataClass
+
+
+@dataclasses.dataclass
+class Score(SerializableDataClass):
+    """A score for an evaluation. The score is a float between 0 and 1."""
+
+    name: str
+    """The name of the score. This should be a unique name for the scorer."""
+
+    score: Optional[float]
+    """The score for the evaluation. This should be a float between 0 and 1. If the score is None, the evaluation is considered to be skipped."""
+
+    metadata: Dict[str, Any] = dataclasses.field(default_factory=dict)
+    """Metadata for the score. This can be used to store additional information about the score."""
+
+    # DEPRECATION_NOTICE: this field is deprecated, as errors are propagated up to the caller.
+    error: Optional[Exception] = None
+    """Deprecated: The error field is deprecated, as errors are now propagated to the caller. The field will be removed in a future version of the library."""
+
+    def as_dict(self):
+        return {
+            "score": self.score,
+            "metadata": self.metadata,
+        }
+
+    def __post_init__(self):
+        if self.score is not None and (self.score < 0 or self.score > 1):
+            raise ValueError(f"score ({self.score}) must be between 0 and 1")
+        if self.error is not None:
+            print(
+                "The error field is deprecated, as errors are now propagated to the caller. The field will be removed in a future version of the library",
+                file=sys.stderr,
+            )
+
+
+class Scorer(ABC):
+    async def eval_async(self, output: Any, expected: Any = None, **kwargs: Any) -> Score:
+        return await self._run_eval_async(output, expected, **kwargs)
+
+    def eval(self, output: Any, expected: Any = None, **kwargs: Any) -> Score:
+        return self._run_eval_sync(output, expected, **kwargs)
+
+    def __call__(self, output: Any, expected: Any = None, **kwargs: Any) -> Score:
+        return self.eval(output, expected, **kwargs)
+
+    async def _run_eval_async(self, output: Any, expected: Any = None, **kwargs: Any) -> Score:
+        # By default, just run the sync version directly.
+        return self._run_eval_sync(output, expected, **kwargs)
+
+    def _name(self) -> str:
+        return self.__class__.__name__
+
+    @abstractmethod
+    def _run_eval_sync(self, output: Any, expected: Any = None, **kwargs: Any) -> Score:
+        ...
+
+
+__all__ = ["Score", "Scorer"]
diff --git a/py/autoevals/serializable_data_class.py b/py/autoevals/serializable_data_class.py
new file mode 100644
index 0000000..2fe62f6
--- /dev/null
+++ b/py/autoevals/serializable_data_class.py
@@ -0,0 +1,65 @@
+import dataclasses
+import json
+from typing import Dict, Union, get_origin
+
+
+class SerializableDataClass:
+    def as_dict(self):
+        """Serialize the object to a dictionary."""
+        return dataclasses.asdict(self)
+
+    def as_json(self, **kwargs):
+        """Serialize the object to JSON."""
+        return json.dumps(self.as_dict(), **kwargs)
+
+    def __getitem__(self, item: str):
+        return getattr(self, item)
+
+    @classmethod
+    def from_dict(cls, d: Dict):
+        """Deserialize the object from a dictionary. This method
+        is shallow and will not call from_dict() on nested objects."""
+        fields = set(f.name for f in dataclasses.fields(cls))
+        filtered = {k: v for k, v in d.items() if k in fields}
+        return cls(**filtered)
+
+    @classmethod
+    def from_dict_deep(cls, d: Dict):
+        """Deserialize the object from a dictionary. This method
+        is deep and will call from_dict_deep() on nested objects."""
+        fields = {f.name: f for f in dataclasses.fields(cls)}
+        filtered = {}
+        for k, v in d.items():
+            if k not in fields:
+                continue
+
+            if (
+                isinstance(v, dict)
+                and isinstance(fields[k].type, type)
+                and issubclass(fields[k].type, SerializableDataClass)
+            ):
+                filtered[k] = fields[k].type.from_dict_deep(v)
+            elif get_origin(fields[k].type) == Union:
+                for t in fields[k].type.__args__:
+                    if t == type(None) and v is None:
+                        filtered[k] = None
+                        break
+                    if isinstance(t, type) and issubclass(t, SerializableDataClass) and v is not None:
+                        try:
+                            filtered[k] = t.from_dict_deep(v)
+                            break
+                        except TypeError:
+                            pass
+                else:
+                    filtered[k] = v
+            elif (
+                isinstance(v, list)
+                and get_origin(fields[k].type) == list
+                and len(fields[k].type.__args__) == 1
+                and isinstance(fields[k].type.__args__[0], type)
+                and issubclass(fields[k].type.__args__[0], SerializableDataClass)
+            ):
+                filtered[k] = [fields[k].type.__args__[0].from_dict_deep(i) for i in v]
+            else:
+                filtered[k] = v
+        return cls(**filtered)
diff --git a/py/autoevals/string.py b/py/autoevals/string.py
index 4bfa069..0dcba9a 100644
--- a/py/autoevals/string.py
+++ b/py/autoevals/string.py
@@ -20,13 +20,13 @@
 import threading
 from typing import Optional
 
-from braintrust_core.score import Score
 from polyleven import levenshtein as distance
 
 from autoevals.partial import ScorerWithPartial
 from autoevals.value import normalize_value
 
 from .oai import LLMClient, arun_cached_request, run_cached_request
+from .score import Score
 
 
 class Levenshtein(ScorerWithPartial):
diff --git a/py/autoevals/test_serializable_data_class.py b/py/autoevals/test_serializable_data_class.py
new file mode 100644
index 0000000..0cade6a
--- /dev/null
+++ b/py/autoevals/test_serializable_data_class.py
@@ -0,0 +1,62 @@
+import unittest
+from dataclasses import dataclass
+from typing import List, Optional
+
+from .serializable_data_class import SerializableDataClass
+
+
+@dataclass
+class PromptData(SerializableDataClass):
+    prompt: Optional[str] = None
+    options: Optional[dict] = None
+
+
+@dataclass
+class PromptSchema(SerializableDataClass):
+    id: str
+    project_id: str
+    _xact_id: str
+    name: str
+    slug: str
+    description: Optional[str]
+    prompt_data: PromptData
+    tags: Optional[List[str]]
+
+
+class TestSerializableDataClass(unittest.TestCase):
+    def test_from_dict_deep_with_none_values(self):
+        """Test that from_dict_deep correctly handles None values in nested objects."""
+        test_dict = {
+            "id": "456",
+            "project_id": "123",
+            "_xact_id": "789",
+            "name": "test-prompt",
+            "slug": "test-prompt",
+            "description": None,
+            "prompt_data": {"prompt": None, "options": None},
+            "tags": None,
+        }
+
+        prompt = PromptSchema.from_dict_deep(test_dict)
+
+        # Verify all fields were set correctly.
+        self.assertEqual(prompt.id, "456")
+        self.assertEqual(prompt.project_id, "123")
+        self.assertEqual(prompt._xact_id, "789")
+        self.assertEqual(prompt.name, "test-prompt")
+        self.assertEqual(prompt.slug, "test-prompt")
+        self.assertIsNone(prompt.description)
+        self.assertIsNone(prompt.tags)
+
+        # Verify nested object was created and its fields are None.
+        self.assertIsInstance(prompt.prompt_data, PromptData)
+        self.assertIsNone(prompt.prompt_data.prompt)
+        self.assertIsNone(prompt.prompt_data.options)
+
+        # Verify round-trip serialization works.
+        round_trip = PromptSchema.from_dict_deep(prompt.as_dict())
+        self.assertEqual(round_trip.as_dict(), test_dict)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/py/autoevals/value.py b/py/autoevals/value.py
index 6e79756..7211787 100644
--- a/py/autoevals/value.py
+++ b/py/autoevals/value.py
@@ -38,10 +38,10 @@
 import json
 from typing import Any
 
-from braintrust_core.score import Score
-
 from autoevals.partial import ScorerWithPartial
 
+from .score import Score
+
 
 class ExactMatch(ScorerWithPartial):
     """A scorer that tests for exact equality between values.
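
A minimal usage sketch of the vendored Score/Scorer API introduced by this patch. The WordCount scorer and the example strings are hypothetical, for illustration only; they are not part of the change:

    from autoevals import Score, Scorer


    class WordCount(Scorer):
        """Hypothetical scorer: fraction of expected words present in the output."""

        def _run_eval_sync(self, output, expected=None, **kwargs):
            if expected is None:
                # Returning score=None marks the evaluation as skipped.
                return Score(name=self._name(), score=None)
            words = set(expected.split())
            hit = sum(1 for w in words if w in output.split())
            return Score(
                name=self._name(),
                score=hit / len(words) if words else 1.0,
                metadata={"expected_words": len(words)},
            )


    result = WordCount()(output="the quick brown fox", expected="quick fox")
    print(result.name, result.score)  # WordCount 1.0
    print(result.as_json())           # {"score": 1.0, "metadata": {"expected_words": 2}}

Because Score subclasses SerializableDataClass, as_json() comes for free, and the overridden as_dict() keeps the payload to just score and metadata.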