Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 36 additions & 13 deletions evalbench/scorers/setmatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,42 @@ def compare(
return 0, None
else:
try:
# Current results are a list of Dict. Converting to Tuple for set comparison
golden_execution_result_tuple = [
tuple(d.values()) for d in golden_execution_result
]
generated_execution_result_tuple = [
tuple(d.values()) for d in generated_execution_result
]
score = (
100
if set(golden_execution_result_tuple)
== set(generated_execution_result_tuple)
else 0
)
def _is_document_structure(data):
    """Return True when *data* looks like nested (document-style) results.

    A result counts as document-style when it is a list and at least one
    of its dict items holds a dict or list as a value. Anything that is
    not a list, and list items that are not dicts, never trigger a match.

    Args:
        data: An execution result, expected to be a list of row dicts.

    Returns:
        bool: True if a nested dict/list value was found, else False.
    """
    if not isinstance(data, list):
        return False
    # Flatten "every value of every dict row" into one lazy scan; any()
    # stops at the first nested container, like the original early return.
    return any(
        isinstance(value, (dict, list))
        for row in data
        if isinstance(row, dict)
        for value in row.values()
    )

if _is_document_structure(golden_execution_result) or _is_document_structure(generated_execution_result):
def _make_hashable(item):
    """Recursively convert a nested document value into a hashable form.

    Lists and tuples become tuples of converted elements. Dicts become a
    sorted tuple of ``(key, converted_value)`` pairs, so two dicts with
    the same content produce the same result regardless of insertion
    order. Every other value is returned unchanged.

    Args:
        item: A value taken from an execution result (a dict, list,
            tuple, or a leaf value).

    Returns:
        A structure made only of tuples and the original leaf values.

    Raises:
        TypeError: If a dict has keys that cannot be ordered against
            each other (for example a mix of ``str`` and ``int`` keys).

    NOTE(review): the caller compares ``sorted()`` lists of these results.
    That sort raises TypeError when two rows differ only by values that
    cannot be ordered (e.g. ``None`` vs ``1``), which is then reported as
    a score of 0. Comparing ``collections.Counter`` objects built from the
    converted rows would keep the multiset semantics without needing an
    ordering -- worth confirming and changing at the call site.
    """
    if isinstance(item, (list, tuple)):
        # Tuples are walked as well: a tuple that contains a dict or list
        # would otherwise be returned as-is and remain unhashable.
        return tuple(_make_hashable(element) for element in item)
    if isinstance(item, dict):
        pairs = ((key, _make_hashable(value)) for key, value in item.items())
        # Dict keys are unique, so sorting the pairs orders them by key.
        return tuple(sorted(pairs))
    return item

h1 = [_make_hashable(d) for d in golden_execution_result]
h2 = [_make_hashable(d) for d in generated_execution_result]
score = 100 if sorted(h1) == sorted(h2) else 0
else:
# SQL Model: flat primitives, ignore column names, remove duplicates
golden_execution_result_tuple = [
tuple(d.values()) for d in golden_execution_result
]
generated_execution_result_tuple = [
tuple(d.values()) for d in generated_execution_result
]
score = (
100
if set(golden_execution_result_tuple)
== set(generated_execution_result_tuple)
else 0
)
except Exception as e:
return 0, str(e)

Expand Down
70 changes: 70 additions & 0 deletions evalbench/test/set_matcher_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import unittest
from scorers.setmatcher import SetMatcher


class TestSetMatcher(unittest.TestCase):
    """Tests for SetMatcher.compare on flat (SQL) and nested (document) results."""

    def _compare(self, golden, generated):
        """Score *generated* against *golden* with a fresh SetMatcher.

        compare() is called with ten positional arguments; these tests only
        supply the golden (4th) and generated (8th) execution results and
        pass None for everything else.

        Returns:
            The ``(score, error)`` tuple returned by ``SetMatcher.compare``.
        """
        matcher = SetMatcher({})
        return matcher.compare(
            None, None, None, golden, None, None, None, generated, None, None
        )

    # --- Classic SQL Set Cases (Backwards Compatible) ---

    def test_sql_flat_match(self):
        """Identical flat rows match."""
        score, err = self._compare([{"a": 1, "b": 2}], [{"a": 1, "b": 2}])
        self.assertEqual(score, 100)
        self.assertIsNone(err)

    def test_sql_ignore_duplicates(self):
        """Classic SQL removes duplicate rows."""
        score, err = self._compare([{"a": 1}, {"a": 1}], [{"a": 1}])
        self.assertEqual(score, 100)
        self.assertIsNone(err)

    def test_sql_ignore_keys(self):
        """Classic SQL compares values only."""
        score, err = self._compare([{"a": 1}], [{"b": 1}])
        self.assertEqual(score, 100)
        self.assertIsNone(err)

    # --- Document / NoSQL Cases (Auto-detected) ---

    def test_doc_nested_dict_match(self):
        """Identical nested dicts match."""
        score, err = self._compare([{"a": {"x": 1}}], [{"a": {"x": 1}}])
        self.assertEqual(score, 100)
        self.assertIsNone(err)

    def test_doc_multiset_duplication(self):
        """Document evaluation respects duplicate document counts."""
        golden = [{"a": {"x": 1}}, {"a": {"x": 1}}]
        generated = [{"a": {"x": 1}}]
        score, err = self._compare(golden, generated)
        self.assertEqual(score, 0)  # Should fail if counts don't match for docs
        self.assertIsNone(err)

    def test_doc_nested_list_preserve_order(self):
        """Order inside nested lists matters for documents."""
        score, err = self._compare([{"a": [1, 2]}], [{"a": [2, 1]}])
        self.assertEqual(score, 0)  # Order inside lists matters for docs
        # compare() also returns 0 when it catches an exception, so make sure
        # the zero comes from a genuine mismatch and not from a crash.
        self.assertIsNone(err)

    def test_doc_nested_list_match(self):
        """Identical nested lists match."""
        score, err = self._compare([{"a": [1, 2]}], [{"a": [1, 2]}])
        self.assertEqual(score, 100)
        self.assertIsNone(err)


# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
Loading