Skip to content

Commit

Permalink
Added more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Erechtheus committed Feb 28, 2025
1 parent c5335b3 commit 28bf836
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 1 deletion.
4 changes: 4 additions & 0 deletions src/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ def check_errors(goldstandard, predictions):
if unknown_labels:
errors.append(f"Unknown labels found in predictions: {set(unknown_labels)}")

unknown_keys = [key for key in predictions.keys() if key not in goldstandard.keys()]
if unknown_keys:
errors.append(f"Unknown keys found in predictions: {set(unknown_keys)}")

if errors:
print("Errors found in predictions:")
for error in errors:
Expand Down
84 changes: 83 additions & 1 deletion tests/test_check_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,86 @@ def mock_print(msg): # Capture print statements

builtins.print = original_print # Restore original print

assert "Missing predictions for IDs:" in errors[1] # Expect an error message
assert "Missing predictions for IDs:" in errors[1] # Expect an error message

def test_check_errors_DuplicatePrediction():
gold_df = pd.DataFrame({
'id': ['de_0', 'en_0', 'fr_0'],
'label': [0, 1, 0]
})
gold_df = gold_df.set_index('id')
predictions = pd.DataFrame({
'id': ['de_0', 'en_0', 'fr_0', 'fr_0'],
'predicted_label': [0, 1, 0, 1]
})
predictions = predictions.set_index('id')['predicted_label'].to_dict()
errors = []

def mock_print(msg): # Capture print statements
errors.append(msg)

# Monkey-patch print to capture output
import builtins
original_print = builtins.print
builtins.print = mock_print

check_errors(gold_df, predictions)

builtins.print = original_print # Restore original print

assert "Duplicate prediction entries for IDs:" in errors[1] # Expect an error message

def test_check_errors_UnknownLabels():
gold_df = pd.DataFrame({
'id': ['de_0', 'en_0', 'fr_0'],
'label': [0, 1, 0]
})
gold_df = gold_df.set_index('id')
predictions = pd.DataFrame({
'id': ['de_0', 'en_0', 'fr_0'],
'predicted_label': [0, 1, 2]
})
predictions = predictions.set_index('id')['predicted_label'].to_dict()
errors = []

def mock_print(msg): # Capture print statements
errors.append(msg)

# Monkey-patch print to capture output
import builtins
original_print = builtins.print
builtins.print = mock_print

check_errors(gold_df, predictions)

builtins.print = original_print # Restore original print

assert "Unknown labels found in predictions:" in errors[1] # Expect error


def test_check_errors_UnknownKeys():
gold_df = pd.DataFrame({
'id': ['de_0', 'en_0', 'fr_0'],
'label': [0, 1, 0]
})
gold_df = gold_df.set_index('id')
predictions = pd.DataFrame({
'id': ['de_0', 'en_0', 'fr_0', 'en_1'],
'predicted_label': [0, 1, 0, 1]
})
predictions = predictions.set_index('id')['predicted_label'].to_dict()
errors = []

def mock_print(msg): # Capture print statements
errors.append(msg)

# Monkey-patch print to capture output
import builtins
original_print = builtins.print
builtins.print = mock_print

check_errors(gold_df, predictions)

builtins.print = original_print # Restore original print

assert "Unknown keys found in predictions:" in errors[1] # Expect error

0 comments on commit 28bf836

Please sign in to comment.