diff --git a/src/openlayer/lib/core/base_model.py b/src/openlayer/lib/core/base_model.py index 306526ff..e847e2bd 100644 --- a/src/openlayer/lib/core/base_model.py +++ b/src/openlayer/lib/core/base_model.py @@ -42,7 +42,9 @@ class OpenlayerModel(abc.ABC): def run_from_cli(self) -> None: """Run the model from the command line.""" parser = argparse.ArgumentParser(description="Run data through a model.") - parser.add_argument("--dataset-path", type=str, required=True, help="Path to the dataset") + parser.add_argument( + "--dataset-path", type=str, required=True, help="Path to the dataset" + ) parser.add_argument( "--output-dir", type=str, @@ -85,7 +87,9 @@ def run_batch_from_df(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, dict]: # Filter row_dict to only include keys that are valid parameters # for the 'run' method row_dict = row.to_dict() - filtered_kwargs = {k: v for k, v in row_dict.items() if k in run_signature.parameters} + filtered_kwargs = { + k: v for k, v in row_dict.items() if k in run_signature.parameters + } # Call the run method with filtered kwargs output = self.run(**filtered_kwargs) @@ -108,7 +112,8 @@ def run_batch_from_df(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, dict]: if "tokens" in processed_trace: df.at[index, "tokens"] = processed_trace["tokens"] if "context" in processed_trace: - df.at[index, "context"] = processed_trace["context"] + # Serialize the context list to a JSON string; pandas can raise when a list is assigned to a single cell via .at + df.at[index, "context"] = json.dumps(processed_trace["context"]) config = { "outputColumnName": "output",