-
Notifications
You must be signed in to change notification settings - Fork 51
Remove the dependency on braintrust core. #139
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
clutchski
wants to merge
1
commit into
main
Choose a base branch
from
matt/no-btcore-dep
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,4 @@ | ||
| from braintrust_core.score import Scorer | ||
| from .score import Scorer | ||
|
|
||
|
|
||
| class ScorerWithPartial(Scorer): | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,64 @@ | ||
| import dataclasses | ||
| import sys | ||
| from abc import ABC, abstractmethod | ||
| from typing import Any, Dict, Optional | ||
|
|
||
| from .serializable_data_class import SerializableDataClass | ||
|
|
||
|
|
||
| @dataclasses.dataclass | ||
| class Score(SerializableDataClass): | ||
| """A score for an evaluation. The score is a float between 0 and 1.""" | ||
|
|
||
| name: str | ||
| """The name of the score. This should be a unique name for the scorer.""" | ||
|
|
||
| score: Optional[float] | ||
| """The score for the evaluation. This should be a float between 0 and 1. If the score is None, the evaluation is considered to be skipped.""" | ||
|
|
||
| metadata: Dict[str, Any] = dataclasses.field(default_factory=dict) | ||
| """Metadata for the score. This can be used to store additional information about the score.""" | ||
|
|
||
| # DEPRECATION_NOTICE: this field is deprecated, as errors are propagated up to the caller. | ||
| error: Optional[Exception] = None | ||
| """Deprecated: The error field is deprecated, as errors are now propagated to the caller. The field will be removed in a future version of the library.""" | ||
|
|
||
| def as_dict(self): | ||
| return { | ||
| "score": self.score, | ||
| "metadata": self.metadata, | ||
| } | ||
|
|
||
| def __post_init__(self): | ||
| if self.score is not None and (self.score < 0 or self.score > 1): | ||
| raise ValueError(f"score ({self.score}) must be between 0 and 1") | ||
| if self.error is not None: | ||
| print( | ||
| "The error field is deprecated, as errors are now propagated to the caller. The field will be removed in a future version of the library", | ||
| sys.stderr, | ||
| ) | ||
|
|
||
|
|
||
| class Scorer(ABC): | ||
| async def eval_async(self, output: Any, expected: Any = None, **kwargs: Any) -> Score: | ||
| return await self._run_eval_async(output, expected, **kwargs) | ||
|
|
||
| def eval(self, output: Any, expected: Any = None, **kwargs: Any) -> Score: | ||
| return self._run_eval_sync(output, expected, **kwargs) | ||
|
|
||
| def __call__(self, output: Any, expected: Any = None, **kwargs: Any) -> Score: | ||
| return self.eval(output, expected, **kwargs) | ||
|
|
||
| async def _run_eval_async(self, output: Any, expected: Any = None, **kwargs: Any) -> Score: | ||
| # By default we just run the sync version in a thread | ||
| return self._run_eval_sync(output, expected, **kwargs) | ||
|
|
||
| def _name(self) -> str: | ||
| return self.__class__.__name__ | ||
|
|
||
| @abstractmethod | ||
| def _run_eval_sync(self, output: Any, expected: Any = None, **kwargs: Any) -> Score: | ||
| ... | ||
|
|
||
|
|
||
| __all__ = ["Score", "Scorer"] | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,65 @@ | ||
| import dataclasses | ||
| import json | ||
| from typing import Dict, Union, get_origin | ||
|
|
||
|
|
||
| class SerializableDataClass: | ||
| def as_dict(self): | ||
| """Serialize the object to a dictionary.""" | ||
| return dataclasses.asdict(self) | ||
|
|
||
| def as_json(self, **kwargs): | ||
| """Serialize the object to JSON.""" | ||
| return json.dumps(self.as_dict(), **kwargs) | ||
|
|
||
| def __getitem__(self, item: str): | ||
| return getattr(self, item) | ||
|
|
||
| @classmethod | ||
| def from_dict(cls, d: Dict): | ||
| """Deserialize the object from a dictionary. This method | ||
| is shallow and will not call from_dict() on nested objects.""" | ||
| fields = set(f.name for f in dataclasses.fields(cls)) | ||
| filtered = {k: v for k, v in d.items() if k in fields} | ||
| return cls(**filtered) | ||
|
|
||
| @classmethod | ||
| def from_dict_deep(cls, d: Dict): | ||
| """Deserialize the object from a dictionary. This method | ||
| is deep and will call from_dict_deep() on nested objects.""" | ||
| fields = {f.name: f for f in dataclasses.fields(cls)} | ||
| filtered = {} | ||
| for k, v in d.items(): | ||
| if k not in fields: | ||
| continue | ||
|
|
||
| if ( | ||
| isinstance(v, dict) | ||
| and isinstance(fields[k].type, type) | ||
| and issubclass(fields[k].type, SerializableDataClass) | ||
| ): | ||
| filtered[k] = fields[k].type.from_dict_deep(v) | ||
| elif get_origin(fields[k].type) == Union: | ||
| for t in fields[k].type.__args__: | ||
| if t == type(None) and v is None: | ||
| filtered[k] = None | ||
| break | ||
| if isinstance(t, type) and issubclass(t, SerializableDataClass) and v is not None: | ||
| try: | ||
| filtered[k] = t.from_dict_deep(v) | ||
| break | ||
| except TypeError: | ||
| pass | ||
| else: | ||
| filtered[k] = v | ||
| elif ( | ||
| isinstance(v, list) | ||
| and get_origin(fields[k].type) == list | ||
| and len(fields[k].type.__args__) == 1 | ||
| and isinstance(fields[k].type.__args__[0], type) | ||
| and issubclass(fields[k].type.__args__[0], SerializableDataClass) | ||
| ): | ||
| filtered[k] = [fields[k].type.__args__[0].from_dict_deep(i) for i in v] | ||
| else: | ||
| filtered[k] = v | ||
| return cls(**filtered) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,62 @@ | ||
| import unittest | ||
| from dataclasses import dataclass | ||
| from typing import List, Optional | ||
|
|
||
| from .serializable_data_class import SerializableDataClass | ||
|
|
||
|
|
||
| @dataclass | ||
| class PromptData(SerializableDataClass): | ||
| prompt: Optional[str] = None | ||
| options: Optional[dict] = None | ||
|
|
||
|
|
||
| @dataclass | ||
| class PromptSchema(SerializableDataClass): | ||
| id: str | ||
| project_id: str | ||
| _xact_id: str | ||
| name: str | ||
| slug: str | ||
| description: Optional[str] | ||
| prompt_data: PromptData | ||
| tags: Optional[List[str]] | ||
|
|
||
|
|
||
| class TestSerializableDataClass(unittest.TestCase): | ||
| def test_from_dict_deep_with_none_values(self): | ||
| """Test that from_dict_deep correctly handles None values in nested objects.""" | ||
| test_dict = { | ||
| "id": "456", | ||
| "project_id": "123", | ||
| "_xact_id": "789", | ||
| "name": "test-prompt", | ||
| "slug": "test-prompt", | ||
| "description": None, | ||
| "prompt_data": {"prompt": None, "options": None}, | ||
| "tags": None, | ||
| } | ||
|
|
||
| prompt = PromptSchema.from_dict_deep(test_dict) | ||
|
|
||
| # Verify all fields were set correctly. | ||
| self.assertEqual(prompt.id, "456") | ||
| self.assertEqual(prompt.project_id, "123") | ||
| self.assertEqual(prompt._xact_id, "789") | ||
| self.assertEqual(prompt.name, "test-prompt") | ||
| self.assertEqual(prompt.slug, "test-prompt") | ||
| self.assertIsNone(prompt.description) | ||
| self.assertIsNone(prompt.tags) | ||
|
|
||
| # Verify nested object was created and its fields are None. | ||
| self.assertIsInstance(prompt.prompt_data, PromptData) | ||
| self.assertIsNone(prompt.prompt_data.prompt) | ||
| self.assertIsNone(prompt.prompt_data.options) | ||
|
|
||
| # Verify round-trip serialization works. | ||
| round_trip = PromptSchema.from_dict_deep(prompt.as_dict()) | ||
| self.assertEqual(round_trip.as_dict(), test_dict) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
printstatement is usingsys.stderras a positional argument instead of as the file parameter. This will cause the message to print incorrectly. The statement should be modified to:Spotted by Diamond
Is this helpful? React 👍 or 👎 to let us know.