Skip to content

Commit 7ded1ef

Browse files
committed
feat: add --dataset flag so custom metrics can be forced to run on only specific datasets
1 parent 061ae36 commit 7ded1ef

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

src/openlayer/lib/core/metrics.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,19 @@ def _parse_args(self) -> None:
9595
"provided (assuming location is metrics/metric_name/run.py)."
9696
),
9797
)
98+
parser.add_argument(
99+
"--dataset",
100+
type=str,
101+
required=False,
102+
default="",
103+
help="The name of the dataset to compute the metric on. Runs on all "
104+
"datasets if not provided.",
105+
)
98106

99107
# Parse the arguments
100108
args = parser.parse_args()
101109
self.config_path = args.config_path
110+
self.dataset_name = args.dataset
102111
self.likely_dir = os.path.dirname(os.path.dirname(os.getcwd()))
103112

104113
def _load_openlayer_json(self) -> None:
@@ -122,6 +131,12 @@ def _load_datasets(self) -> None:
122131
model = self.config["model"]
123132
datasets_list = self.config["datasets"]
124133
dataset_names = [dataset["name"] for dataset in datasets_list]
134+
if self.dataset_name:
135+
if self.dataset_name not in dataset_names:
136+
raise ValueError(
137+
f"Dataset {self.dataset_name} not found in the openlayer.json."
138+
)
139+
dataset_names = [self.dataset_name]
125140
output_directory = model["outputDirectory"]
126141
# Read the outputs directory for dataset folders. For each, load
127142
# the config.json and the dataset.json files into a dict and a dataframe

0 commit comments

Comments
 (0)