def _is_document_structure(data) -> bool:
    """Return True when *data* looks like document (NoSQL-style) results.

    *data* counts as a document structure when it is a list and at least
    one of its dict rows holds a nested ``dict`` or ``list`` value. Flat
    rows of primitives (classic SQL results) return False, as does any
    non-list input.
    """
    if not isinstance(data, list):
        return False
    for row in data:
        if isinstance(row, dict) and any(
            isinstance(value, (dict, list)) for value in row.values()
        ):
            return True
    return False


def _make_hashable(item):
    """Recursively convert *item* into a hashable, order-free canonical form.

    * ``list``  -> ``tuple`` of converted elements (element order preserved).
    * ``dict``  -> ``frozenset`` of ``(key, converted value)`` pairs, so key
      order is irrelevant and keys never have to be sorted (sorting raised
      ``TypeError`` for mixed-type keys, and a sorted tuple of pairs was
      indistinguishable from a list of pairs).
    * anything else is returned unchanged.
    """
    if isinstance(item, list):
        return tuple(_make_hashable(element) for element in item)
    if isinstance(item, dict):
        return frozenset(
            (key, _make_hashable(value)) for key, value in item.items()
        )
    return item


def score_execution_results(
    golden_execution_result, generated_execution_result
) -> int:
    """Score two execution results: 100 when they match, otherwise 0.

    Two comparison modes are selected automatically:

    * Document mode (either side contains nested dict/list values): rows are
      compared as a multiset, so duplicate counts and keys matter, and the
      order of elements inside nested lists matters. Row order does not.
    * SQL mode (flat primitive rows): column names are ignored and duplicate
      rows are collapsed; only the set of value tuples is compared.

    Exceptions (e.g. unhashable values, non-iterable inputs) are not caught
    here. In the original hunk this logic sat inside a ``try`` whose
    ``except Exception`` returned ``(0, str(e))``; a caller should keep that
    handling.
    """
    # Local import keeps this block self-contained; the module's import
    # section is not visible from here.
    from collections import Counter

    if _is_document_structure(golden_execution_result) or _is_document_structure(
        generated_execution_result
    ):
        # Counter gives multiset equality via hashing, so rows never need to
        # be ordered. The previous ``sorted(...) == sorted(...)`` raised
        # TypeError whenever two rows differed in a non-comparable value
        # (e.g. 1 vs None), which silently scored valid matches as 0.
        golden_docs = Counter(
            _make_hashable(row) for row in golden_execution_result
        )
        generated_docs = Counter(
            _make_hashable(row) for row in generated_execution_result
        )
        return 100 if golden_docs == generated_docs else 0

    # SQL model: flat primitives, ignore column names, remove duplicates.
    golden_rows = {tuple(row.values()) for row in golden_execution_result}
    generated_rows = {tuple(row.values()) for row in generated_execution_result}
    return 100 if golden_rows == generated_rows else 0
import unittest

from scorers.setmatcher import SetMatcher


class TestSetMatcher(unittest.TestCase):
    """Exercise ``SetMatcher.compare`` on flat (SQL) and nested (document) results."""

    def _compare(self, golden, generated):
        """Call ``compare`` with only the two execution results populated.

        ``compare`` takes ten positional arguments; these tests pass the
        golden result in position 4 and the generated result in position 8
        and leave the rest as ``None``. Returns the ``(score, error)`` pair.
        """
        matcher = SetMatcher({})
        return matcher.compare(
            None, None, None, golden, None, None, None, generated, None, None
        )

    # --- Classic SQL Set Cases (Backwards Compatible) ---

    def test_sql_flat_match(self):
        score, err = self._compare([{"a": 1, "b": 2}], [{"a": 1, "b": 2}])
        self.assertEqual(score, 100)
        self.assertIsNone(err)

    def test_sql_ignore_duplicates(self):
        """Classic SQL removes duplicate rows."""
        score, err = self._compare([{"a": 1}, {"a": 1}], [{"a": 1}])
        self.assertEqual(score, 100)
        self.assertIsNone(err)

    def test_sql_ignore_keys(self):
        """Classic SQL compares values only."""
        score, err = self._compare([{"a": 1}], [{"b": 1}])
        self.assertEqual(score, 100)
        self.assertIsNone(err)

    # --- Document / NoSQL Cases (Auto-detected) ---

    def test_doc_nested_dict_match(self):
        score, err = self._compare([{"a": {"x": 1}}], [{"a": {"x": 1}}])
        self.assertEqual(score, 100)
        self.assertIsNone(err)

    def test_doc_multiset_duplication(self):
        """Document evaluation respects duplicate document counts."""
        score, err = self._compare(
            [{"a": {"x": 1}}, {"a": {"x": 1}}], [{"a": {"x": 1}}]
        )
        self.assertEqual(score, 0)  # Should fail if counts don't match for docs
        self.assertIsNone(err)

    def test_doc_nested_list_preserve_order(self):
        """Order inside nested lists matters for documents."""
        score, err = self._compare([{"a": [1, 2]}], [{"a": [2, 1]}])
        self.assertEqual(score, 0)  # Order inside lists matters for docs
        # A score of 0 is also what compare returns on an internal error, so
        # the error slot must be checked to prove this is a genuine mismatch.
        self.assertIsNone(err)

    def test_doc_nested_list_match(self):
        score, err = self._compare([{"a": [1, 2]}], [{"a": [1, 2]}])
        self.assertEqual(score, 100)
        self.assertIsNone(err)


if __name__ == '__main__':
    unittest.main()