diff --git a/chunking_evaluation/__init__.py b/chunking_evaluation/__init__.py index 81f3191..dd5f568 100644 --- a/chunking_evaluation/__init__.py +++ b/chunking_evaluation/__init__.py @@ -1,10 +1,13 @@ from .chunking.base_chunker import BaseChunker from .evaluation_framework.general_evaluation import GeneralEvaluation from .evaluation_framework.synthetic_evaluation import SyntheticEvaluation +from .evaluation_framework.dataset_evaluation import DatasetEvaluation, Dataset from .utils import * __all__ = [ 'BaseChunker', 'GeneralEvaluation', 'SyntheticEvaluation', + 'DatasetEvaluation', + 'Dataset' ] \ No newline at end of file diff --git a/chunking_evaluation/evaluation_framework/base_evaluation.py b/chunking_evaluation/evaluation_framework/base_evaluation.py index a5fb1b8..d15bf18 100644 --- a/chunking_evaluation/evaluation_framework/base_evaluation.py +++ b/chunking_evaluation/evaluation_framework/base_evaluation.py @@ -143,7 +143,7 @@ def _get_chunks_and_metadata(self, splitter): return documents, metadatas def _full_precision_score(self, chunk_metadatas): - ioc_scores = [] + ioc_scores = dict() recall_scores = [] highlighted_chunks_count = [] @@ -200,8 +200,8 @@ def _full_precision_score(self, chunk_metadatas): # Calculate ioc_score if there are numerator sets if numerator_sets: ioc_score = sum_of_ranges(numerator_sets) / sum_of_ranges(denominator_sets) - - ioc_scores.append(ioc_score) + + ioc_scores[index] = ioc_score recall_score = 1 - (sum_of_ranges(unused_highlights) / sum_of_ranges([(x['start_index'], x['end_index']) for x in references])) recall_scores.append(recall_score) @@ -209,9 +209,9 @@ def _full_precision_score(self, chunk_metadatas): return ioc_scores, highlighted_chunks_count def _scores_from_dataset_and_retrievals(self, question_metadatas, highlighted_chunks_count): - iou_scores = [] - recall_scores = [] - precision_scores = [] + iou_scores = dict() + recall_scores = dict() + precision_scores = dict() for (index, row), highlighted_chunk_count, metadatas in zip(self.questions_df.iterrows(), highlighted_chunks_count, question_metadatas): # Unpack question and references # question, references = question_references @@ -259,13 +259,13 @@ def _scores_from_dataset_and_retrievals(self, question_metadatas, highlighted_ch iou_denominator = precision_denominator + sum_of_ranges(unused_highlights) recall_score = numerator_value / recall_denominator - recall_scores.append(recall_score) + recall_scores[index] = recall_score precision_score = numerator_value / precision_denominator - precision_scores.append(precision_score) + precision_scores[index] = precision_score iou_score = numerator_value / iou_denominator - iou_scores.append(iou_score) + iou_scores[index] = iou_score return iou_scores, recall_scores, precision_scores @@ -426,18 +426,21 @@ def run(self, chunker, embedding_function=None, retrieve:int = 5, db_to_save_chu corpora_scores[row['corpus_id']]['recall_scores'].append(recall_scores[index]) corpora_scores[row['corpus_id']]['precision_scores'].append(precision_scores[index]) + brute_iou_scores_vals = list(brute_iou_scores.values()) + brute_iou_mean = np.mean(brute_iou_scores_vals) + brute_iou_std = np.std(brute_iou_scores_vals) - brute_iou_mean = np.mean(brute_iou_scores) - brute_iou_std = np.std(brute_iou_scores) - - recall_mean = np.mean(recall_scores) - recall_std = np.std(recall_scores) + recall_scores_vals = list(recall_scores.values()) + recall_mean = np.mean(recall_scores_vals) + recall_std = np.std(recall_scores_vals) - iou_mean = np.mean(iou_scores) - iou_std = np.std(iou_scores) + iou_scores_vals = list(iou_scores.values()) + iou_mean = np.mean(iou_scores_vals) + iou_std = np.std(iou_scores_vals) - precision_mean = np.mean(precision_scores) - precision_std = np.std(precision_scores) + precision_scores_vals = list(precision_scores.values()) + precision_mean = np.mean(precision_scores_vals) + precision_std = np.std(precision_scores_vals) # print("Recall scores: ", recall_scores) # print("Precision scores: ", precision_scores) diff --git a/chunking_evaluation/evaluation_framework/dataset_evaluation.py b/chunking_evaluation/evaluation_framework/dataset_evaluation.py new file mode 100644 index 0000000..164e0b3 --- /dev/null +++ b/chunking_evaluation/evaluation_framework/dataset_evaluation.py @@ -0,0 +1,40 @@ +from enum import Enum + +from .general_evaluation import GeneralEvaluation + + +class Dataset(Enum): + CHATLOGS = 'chatlogs' + FINANCE = 'finance' + PUBMED = 'pubmed' + STATE_OF_THE_UNION = 'state_of_the_union' + WIKITEXTS = 'wikitexts' + + +class DatasetEvaluation(GeneralEvaluation): + + def __init__(self, datasets: list[Dataset], chroma_db_path=None): + # edge cases handling + if len(datasets) == 0: + raise ValueError('The `datasets` list argument is empty') + + for dataset in datasets: + if not isinstance(dataset, Dataset): + raise TypeError('The `datasets` parameter must be a list of Dataset enum instance') + + # maps enums to their values + self._datasets: set[str] = set(map(lambda item: item.value, datasets)) + + super().__init__(chroma_db_path=chroma_db_path) + + def _load_questions_df(self): + """Filters the `corpus_list` and `questions_df` to include only the provided dataset values.""" + super()._load_questions_df() + + # filter questions + filtered_questions = self.questions_df[self.questions_df['corpus_id'].isin(self._datasets)] + self.questions_df = filtered_questions + + # filter corpus list + filtered_corpus = list(filter(lambda item: item in self._datasets, self.corpus_list)) + self.corpus_list = filtered_corpus diff --git a/setup.py b/setup.py index 843f918..2ec2ae3 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,8 @@ "python-Levenshtein", "openai", "anthropic", - "attrs" + "attrs", + "pytest" ], author="Brandon A. Smith", author_email="brandonsmithpmpuk@gmail.com", diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_dataset_evaluation.py b/tests/test_dataset_evaluation.py new file mode 100644 index 0000000..81e10db --- /dev/null +++ b/tests/test_dataset_evaluation.py @@ -0,0 +1,74 @@ +import pytest +import pandas as pd + +from chunking_evaluation.evaluation_framework.dataset_evaluation import DatasetEvaluation, Dataset + +QUESTIONS_DF_PATH = './chunking_evaluation/evaluation_framework/general_evaluation_data/questions_df.csv' + + +def test_does_not_accept_empty_list(): + with pytest.raises(ValueError, match='The `datasets` list argument is empty'): + DatasetEvaluation(datasets=[]) + + +def test_accepts_only_dataset_enum_values(): + with pytest.raises(TypeError, match='The `datasets` parameter must be a list of Dataset enum instance'): + DatasetEvaluation( + datasets=['chatlogs', 'finance'] + ) + + +def test_maps_enum_values_to_datasets(): + dataset_eval = DatasetEvaluation( + datasets=[ + Dataset.FINANCE, + Dataset.CHATLOGS + ] + ) + + assert dataset_eval._datasets == {'finance', 'chatlogs'} + + +def test_ignores_duplicate_dataset_names(): + dataset_eval = DatasetEvaluation( + datasets=[ + Dataset.FINANCE, + Dataset.FINANCE, + Dataset.FINANCE, + Dataset.CHATLOGS, + Dataset.CHATLOGS, + ] + ) + + assert len(dataset_eval._datasets) == 2 + assert dataset_eval._datasets == {'finance', 'chatlogs'} + + +def test_filters_corpus_list_based_on_datasets(): + dataset_eval = DatasetEvaluation( + datasets=[Dataset.PUBMED, Dataset.WIKITEXTS] + ) + + assert sorted(dataset_eval.corpus_list) == sorted(['pubmed', 'wikitexts']) + + +def test_filters_questions_df_based_on_datasets(): + dataset_eval = DatasetEvaluation( + datasets=[Dataset.STATE_OF_THE_UNION, Dataset.FINANCE] + ) + + questions_df = pd.read_csv(QUESTIONS_DF_PATH) + filtered_df = questions_df[questions_df['corpus_id'].isin(['state_of_the_union', 'finance'])] + + assert len(filtered_df) == len(dataset_eval.questions_df) + + +def test_loads_all_questions_df_and_corresponding_corpus(): + dataset_eval = DatasetEvaluation( + datasets=[Dataset.STATE_OF_THE_UNION, Dataset.FINANCE, Dataset.CHATLOGS, Dataset.PUBMED, Dataset.WIKITEXTS] + ) + assert sorted(dataset_eval.corpus_list) == sorted( + ['state_of_the_union', 'finance', 'chatlogs', 'pubmed', 'wikitexts']) + + questions_df = pd.read_csv(QUESTIONS_DF_PATH) + assert len(questions_df) == len(dataset_eval.questions_df)