@@ -7,7 +7,7 @@
 import json
 import os
 from dataclasses import asdict, dataclass, field
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union, Set
 
 import pandas as pd
 
@@ -25,6 +25,9 @@ class MetricReturn:
     meta: Dict[str, Any] = field(default_factory=dict)
     """Any useful metadata in a JSON serializable dict."""
 
+    added_cols: Set[str] = field(default_factory=set)
+    """Columns added to the dataset."""
+
 
 @dataclass
 class Dataset:
@@ -42,6 +45,12 @@ class Dataset:
     output_path: str
     """The path to the dataset outputs."""
 
+    data_format: str
+    """The format of the written dataset, e.g. 'csv' or 'json'."""
+
+    added_cols: Set[str] = field(default_factory=set)
+    """Columns added to the dataset."""
+
 
 class MetricRunner:
     """A class to run a list of metrics."""
@@ -68,6 +77,9 @@ def run_metrics(self, metrics: List[BaseMetric]) -> None:
 
         self._compute_metrics(metrics)
 
+        # Write the updated datasets to the output location
+        self._write_updated_datasets_to_output()
+
     def _parse_args(self) -> None:
         parser = argparse.ArgumentParser(description="Compute custom metrics.")
         parser.add_argument(
@@ -124,13 +136,21 @@ def _load_datasets(self) -> None:
                 # Load the dataset into a pandas DataFrame
                 if os.path.exists(os.path.join(dataset_path, "dataset.csv")):
                     dataset_df = pd.read_csv(os.path.join(dataset_path, "dataset.csv"))
+                    data_format = "csv"
                 elif os.path.exists(os.path.join(dataset_path, "dataset.json")):
                     dataset_df = pd.read_json(os.path.join(dataset_path, "dataset.json"), orient="records")
+                    data_format = "json"
                 else:
                     raise ValueError(f"No dataset found in {dataset_folder}.")
 
                 datasets.append(
-                    Dataset(name=dataset_folder, config=dataset_config, df=dataset_df, output_path=dataset_path)
+                    Dataset(
+                        name=dataset_folder,
+                        config=dataset_config,
+                        df=dataset_df,
+                        output_path=dataset_path,
+                        data_format=data_format,
+                    )
                 )
         else:
             raise ValueError("No model found in the openlayer.json file. Cannot compute metric.")
@@ -148,6 +168,31 @@ def _compute_metrics(self, metrics: List[BaseMetric]) -> None:
                 continue
             metric.compute(self.datasets)
 
+    def _write_updated_datasets_to_output(self) -> None:
+        """Write the updated datasets to the output location."""
+        for dataset in self.datasets:
+            if dataset.added_cols:
+                self._write_updated_dataset_to_output(dataset)
+
+    def _write_updated_dataset_to_output(self, dataset: Dataset) -> None:
+        """Write the updated dataset to the output location."""
+
+        # Determine the filename based on the dataset format
+        filename = f"dataset.{dataset.data_format}"
+        data_path = os.path.join(dataset.output_path, filename)
+
+        # TODO: Read the dataset again and only include the added columns
+
+        # Write the DataFrame to the file based on the specified format
+        if dataset.data_format == "csv":
+            dataset.df.to_csv(data_path, index=False)
+        elif dataset.data_format == "json":
+            dataset.df.to_json(data_path, orient="records", indent=4)
+        else:
+            raise ValueError("Unsupported format. Please choose 'csv' or 'json'.")
+
+        print(f"Updated dataset {dataset.name} written to {data_path}")
+
 
 class BaseMetric(abc.ABC):
     """Interface for the Base metric.
@@ -163,7 +208,7 @@ def key(self) -> str:
     def compute(self, datasets: List[Dataset]) -> None:
         """Compute the metric on the model outputs."""
         for dataset in datasets:
-            metric_return = self.compute_on_dataset(dataset.config, dataset.df)
+            metric_return = self.compute_on_dataset(dataset)
             metric_value = metric_return.value
             if metric_return.unit:
                 metric_value = f"{metric_value} {metric_return.unit}"
@@ -172,8 +217,12 @@ def compute(self, datasets: List[Dataset]) -> None:
             output_dir = os.path.join(dataset.output_path, "metrics")
             self._write_metric_return_to_file(metric_return, output_dir)
 
+            # Record the columns the metric added to the dataset
+            if metric_return.added_cols:
+                dataset.added_cols.update(metric_return.added_cols)
+
     @abc.abstractmethod
-    def compute_on_dataset(self, config: dict, df: pd.DataFrame) -> MetricReturn:
+    def compute_on_dataset(self, dataset: Dataset) -> MetricReturn:
         """Compute the metric on a specific dataset."""
         pass
 
@@ -183,6 +232,11 @@ def _write_metric_return_to_file(self, metric_return: MetricReturn, output_dir:
         # Create the directory if it doesn't exist
         os.makedirs(output_dir, exist_ok=True)
 
+        # Turn the metric return into a dict
+        metric_return_dict = asdict(metric_return)
+        # Sets are not JSON serializable, so convert added_cols to a sorted list
+        metric_return_dict["added_cols"] = sorted(metric_return_dict["added_cols"])
+
         with open(os.path.join(output_dir, f"{self.key}.json"), "w", encoding="utf-8") as f:
-            json.dump(asdict(metric_return), f, indent=4)
+            json.dump(metric_return_dict, f, indent=4)
         print(f"Metric ({self.key}) value written to {output_dir}/{self.key}.json")
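
For reference, here is a minimal sketch of a metric written against the changed interface: `compute_on_dataset` now receives the whole `Dataset` instead of `(config, df)`, and any columns the metric appends to `dataset.df` can be declared through `MetricReturn.added_cols` so the runner persists them. The `RowCountMetric` name and the `row_index` column are hypothetical, not part of this diff:

```python
class RowCountMetric(BaseMetric):
    """Hypothetical metric: counts rows and appends a row-index column."""

    @property
    def key(self) -> str:
        return "row_count"

    def compute_on_dataset(self, dataset: Dataset) -> MetricReturn:
        # Mutate the shared DataFrame in place...
        dataset.df["row_index"] = range(len(dataset.df))
        # ...and declare the new column, so that MetricRunner rewrites the
        # dataset file in _write_updated_datasets_to_output.
        return MetricReturn(value=len(dataset.df), unit="rows", added_cols={"row_index"})
```

Because `compute` updates `dataset.added_cols` from each `MetricReturn`, only columns a metric explicitly declares trigger a rewrite of the dataset file.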
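The `TODO` inside `_write_updated_dataset_to_output` is left open by this diff. A sketch of one way to fill it in, assuming the file on disk is the source of truth and that `dataset.df` preserves its row order (the `write_added_columns` helper name is hypothetical):

```python
import os

import pandas as pd


def write_added_columns(dataset: Dataset) -> None:
    """Re-read the dataset file and persist only the declared new columns."""
    data_path = os.path.join(dataset.output_path, f"dataset.{dataset.data_format}")

    # Start from the file as previously written, so columns that metrics
    # created on dataset.df in memory but never declared are not persisted.
    if dataset.data_format == "csv":
        original_df = pd.read_csv(data_path)
    else:
        original_df = pd.read_json(data_path, orient="records")

    # Attach only the declared columns, relying on positional alignment
    # between the in-memory DataFrame and the file on disk.
    for col in sorted(dataset.added_cols):
        original_df[col] = dataset.df[col].to_numpy()

    if dataset.data_format == "csv":
        original_df.to_csv(data_path, index=False)
    else:
        original_df.to_json(data_path, orient="records", indent=4)
```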