Skip to content

Commit a4a3e56

Browse files
authored
fix: v1 to v2 dataset (#1275)
fixes: #1271
1 parent 3076f50 commit a4a3e56

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

src/ragas/evaluation.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,12 @@
3434
)
3535
from ragas.metrics.critique import AspectCritique
3636
from ragas.run_config import RunConfig
37-
from ragas.utils import REQUIRED_COLS_v1, get_feature_language, safe_nanmean
37+
from ragas.utils import (
38+
convert_v1_to_v2_dataset,
39+
convert_v2_to_v1_dataset,
40+
get_feature_language,
41+
safe_nanmean,
42+
)
3843
from ragas.validation import (
3944
remap_column_names,
4045
validate_required_columns,
@@ -164,7 +169,7 @@ def evaluate(
164169
# remap column names from the dataset
165170
v1_input = True
166171
dataset = remap_column_names(dataset, column_map)
167-
dataset = remap_column_names(dataset, REQUIRED_COLS_v1)
172+
dataset = convert_v1_to_v2_dataset(dataset)
168173
# validation
169174
dataset = EvaluationDataset.from_list(dataset.to_list())
170175

@@ -310,8 +315,7 @@ def evaluate(
310315
# convert to v.1 dataset
311316
dataset = dataset.to_hf_dataset()
312317
if v1_input:
313-
cols = {k: v for v, k in REQUIRED_COLS_v1.items()}
314-
dataset = remap_column_names(dataset, cols)
318+
dataset = convert_v2_to_v1_dataset(dataset)
315319

316320
cost_cb = ragas_callbacks["cost_cb"] if "cost_cb" in ragas_callbacks else None
317321
result = Result(

src/ragas/utils.py

+11
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from functools import lru_cache
88

99
import numpy as np
10+
from datasets import Dataset
1011

1112
if t.TYPE_CHECKING:
1213
from ragas.metrics.base import Metric
@@ -197,3 +198,13 @@ def get_required_columns_v1(metric: Metric):
197198
def convert_row_v1_to_v2(row: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
198199
required_cols_v2 = {k: v for v, k in REQUIRED_COLS_v1.items()}
199200
return {required_cols_v2[k]: v for k, v in row.items() if k in required_cols_v2}
201+
202+
203+
def convert_v1_to_v2_dataset(dataset: Dataset) -> Dataset:
204+
columns_map = {v: k for k, v in REQUIRED_COLS_v1.items() if v in dataset.features}
205+
return dataset.rename_columns(columns_map)
206+
207+
208+
def convert_v2_to_v1_dataset(dataset: Dataset) -> Dataset:
209+
columns_map = {k: v for k, v in REQUIRED_COLS_v1.items() if k in dataset.features}
210+
return dataset.rename_columns(columns_map)

0 commit comments

Comments
 (0)