Skip to content

Commit 41142fb

Browse files
committed
Fixes bug with base model
1 parent 539d05e commit 41142fb

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

openlayer/model_runners/base_model.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,17 @@ def run_from_cli(self):
4848

4949
def batch(self, dataset_path: str, output_dir: str):
5050
# Load the dataset into a pandas DataFrame
51+
fmt = dataset_path.split(".")[-1]
5152
if dataset_path.endswith(".csv"):
5253
df = pd.read_csv(dataset_path)
5354
elif dataset_path.endswith(".json"):
5455
df = pd.read_json(dataset_path, orient="records")
56+
else:
57+
raise ValueError("Unsupported format. Please choose 'csv' or 'json'.")
5558

5659
# Call the model's run_batch method, passing in the DataFrame
5760
output_df, config = self.run_batch_from_df(df)
58-
self.write_output_to_directory(output_df, config, output_dir)
61+
self.write_output_to_directory(output_df, config, output_dir, fmt=fmt)
5962

6063
def run_batch_from_df(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, dict]:
6164
"""Function that runs the model and returns the result."""

0 commit comments

Comments
 (0)