Skip to content

Commit 36c56e4

Browse files
committed
feat: add new columns to dataset when running custom metrics
1 parent 3bc5418 commit 36c56e4

File tree

1 file changed

+57
-5
lines changed

1 file changed

+57
-5
lines changed

src/openlayer/lib/core/metrics.py

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import json
88
import os
99
from dataclasses import asdict, dataclass, field
10-
from typing import Any, Dict, List, Optional, Union
10+
from typing import Any, Dict, List, Optional, Union, Set
1111

1212
import pandas as pd
1313

@@ -25,6 +25,9 @@ class MetricReturn:
2525
meta: Dict[str, Any] = field(default_factory=dict)
2626
"""Any useful metadata in a JSON serializable dict."""
2727

28+
added_cols: Set[str] = field(default_factory=set)
29+
"""Columns added to the dataset."""
30+
2831

2932
@dataclass
3033
class Dataset:
@@ -42,6 +45,12 @@ class Dataset:
4245
output_path: str
4346
"""The path to the dataset outputs."""
4447

48+
data_format: str
49+
"""The format of the written dataset. E.g. 'csv' or 'json'."""
50+
51+
added_cols: Set[str] = field(default_factory=set)
52+
"""Columns added to the dataset."""
53+
4554

4655
class MetricRunner:
4756
"""A class to run a list of metrics."""
@@ -68,6 +77,9 @@ def run_metrics(self, metrics: List[BaseMetric]) -> None:
6877

6978
self._compute_metrics(metrics)
7079

80+
# Write the updated datasets to the output location
81+
self._write_updated_datasets_to_output()
82+
7183
def _parse_args(self) -> None:
7284
parser = argparse.ArgumentParser(description="Compute custom metrics.")
7385
parser.add_argument(
@@ -124,13 +136,21 @@ def _load_datasets(self) -> None:
124136
# Load the dataset into a pandas DataFrame
125137
if os.path.exists(os.path.join(dataset_path, "dataset.csv")):
126138
dataset_df = pd.read_csv(os.path.join(dataset_path, "dataset.csv"))
139+
data_format = "csv"
127140
elif os.path.exists(os.path.join(dataset_path, "dataset.json")):
128141
dataset_df = pd.read_json(os.path.join(dataset_path, "dataset.json"), orient="records")
142+
data_format = "json"
129143
else:
130144
raise ValueError(f"No dataset found in {dataset_folder}.")
131145

132146
datasets.append(
133-
Dataset(name=dataset_folder, config=dataset_config, df=dataset_df, output_path=dataset_path)
147+
Dataset(
148+
name=dataset_folder,
149+
config=dataset_config,
150+
df=dataset_df,
151+
output_path=dataset_path,
152+
data_format=data_format,
153+
)
134154
)
135155
else:
136156
raise ValueError("No model found in the openlayer.json file. Cannot compute metric.")
@@ -148,6 +168,31 @@ def _compute_metrics(self, metrics: List[BaseMetric]) -> None:
148168
continue
149169
metric.compute(self.datasets)
150170

171+
def _write_updated_datasets_to_output(self) -> None:
172+
"""Write the updated datasets to the output location."""
173+
for dataset in self.datasets:
174+
if dataset.added_cols:
175+
self._write_updated_dataset_to_output(dataset)
176+
177+
def _write_updated_dataset_to_output(self, dataset: Dataset) -> None:
178+
"""Write the updated dataset to the output location."""
179+
180+
# Determine the filename based on the dataset name and format
181+
filename = f"dataset.{dataset.data_format}"
182+
data_path = os.path.join(dataset.output_path, filename)
183+
184+
# TODO: Read the dataset again and only include the added columns
185+
186+
# Write the DataFrame to the file based on the specified format
187+
if dataset.data_format == "csv":
188+
dataset.df.to_csv(data_path, index=False)
189+
elif dataset.data_format == "json":
190+
dataset.df.to_json(data_path, orient="records", indent=4, index=False)
191+
else:
192+
raise ValueError("Unsupported format. Please choose 'csv' or 'json'.")
193+
194+
print(f"Updated dataset {dataset.name} written to {data_path}")
195+
151196

152197
class BaseMetric(abc.ABC):
153198
"""Interface for the Base metric.
@@ -163,7 +208,7 @@ def key(self) -> str:
163208
def compute(self, datasets: List[Dataset]) -> None:
164209
"""Compute the metric on the model outputs."""
165210
for dataset in datasets:
166-
metric_return = self.compute_on_dataset(dataset.config, dataset.df)
211+
metric_return = self.compute_on_dataset(dataset)
167212
metric_value = metric_return.value
168213
if metric_return.unit:
169214
metric_value = f"{metric_value} {metric_return.unit}"
@@ -172,8 +217,12 @@ def compute(self, datasets: List[Dataset]) -> None:
172217
output_dir = os.path.join(dataset.output_path, "metrics")
173218
self._write_metric_return_to_file(metric_return, output_dir)
174219

220+
# Add the added columns to the dataset
221+
if metric_return.added_cols:
222+
dataset.added_cols.update(metric_return.added_cols)
223+
175224
@abc.abstractmethod
def compute_on_dataset(self, dataset: Dataset) -> MetricReturn:
    """Compute this metric for a single dataset.

    Concrete metric subclasses implement their logic here and return a
    ``MetricReturn`` describing the computed value.
    """
    ...
179228

@@ -183,6 +232,9 @@ def _write_metric_return_to_file(self, metric_return: MetricReturn, output_dir:
183232
# Create the directory if it doesn't exist
184233
os.makedirs(output_dir, exist_ok=True)
185234

235+
# Turn the metric return to a dict
236+
metric_return_dict = asdict(metric_return)
237+
186238
with open(os.path.join(output_dir, f"{self.key}.json"), "w", encoding="utf-8") as f:
187-
json.dump(asdict(metric_return), f, indent=4)
239+
json.dump(metric_return_dict, f, indent=4)
188240
print(f"Metric ({self.key}) value written to {output_dir}/{self.key}.json")

0 commit comments

Comments (0)