diff --git a/cais/agent.py b/cais/agent.py index dee2aff..3577b5c 100644 --- a/cais/agent.py +++ b/cais/agent.py @@ -234,7 +234,7 @@ def execute_method(self, query=None, remove_cleaned=True): if self.cleaned_dataset_path and remove_cleaned: if isinstance(self.load_dataset(cleaned=True), pd.DataFrame): - os.remove(self.cleaned_dataset_path) + # os.remove(self.cleaned_dataset_path) self.cleaned_dataset_path=None logger.info("Succesfully Removed Cleaned Dataset.") @@ -341,7 +341,7 @@ def run_causal_analysis(query: str, dataset_path: str, instrument_hints=input_parsing_result["extracted_variables"].get("instruments_mentioned") ) - query_interpreter_output = query_interpreter_tool.func(query_info=query_info, dataset_analysis=dataset_analysis_result, dataset_description=input_parsing_result["dataset_description"], original_query = input_parsing_result["original_query"]).variables + query_interpreter_output = query_interpreter_tool.func(dataset_analysis=dataset_analysis_result, dataset_description=input_parsing_result["dataset_description"], original_query=input_parsing_result["original_query"]).variables # print('LOG RESULTS') # print(input_parsing_result['extracted_variables']) @@ -447,13 +447,23 @@ def run_causal_analysis(query: str, dataset_path: str, dataset_description=input_parsing_result["dataset_description"], original_query = input_parsing_result["original_query"]) result = explainer_output + + # include query_info in result + result["query_info"] = { + "query_text": input_parsing_result["original_query"], + "potential_treatments": input_parsing_result["extracted_variables"].get("treatment"), + "potential_outcomes": input_parsing_result["extracted_variables"].get("outcome"), + "covariates_hints": input_parsing_result["extracted_variables"].get("covariates_mentioned"), + "instrument_hints": input_parsing_result["extracted_variables"].get("instruments_mentioned") + } + #result['results']['results']["method_used"] = method_validator_output.get('method') logger.debug(result) logger.info("Causal analysis run finished.") # Remove the cleaned csv - logger.info("Removing cleaned csv.") - os.remove(cleaned_path) + # logger.info("Removing cleaned csv.") + # os.remove(cleaned_path) # Ensure result is a dict and extract the 'output' part if isinstance(result, dict): diff --git a/cais/components/dataset_cleaner.py b/cais/components/dataset_cleaner.py index e8531ce..b1f184d 100644 --- a/cais/components/dataset_cleaner.py +++ b/cais/components/dataset_cleaner.py @@ -279,6 +279,10 @@ def run_cleaning_stage(dataset_path: str, if ("Traceback" in stderr_all) or ("Error" in stderr_all): report.append("\n⚠️ LLM pipeline produced errors. Check stderr; artifacts may be missing or partial.") + if not os.path.exists(cleaned_path): + report.append(f"\n⚠️ Cleaned file not found at expected path; falling back to original dataset.") + cleaned_path = dataset_path + return { "cleaned_dataset_path": cleaned_path, "cleaning_report_md": "\n".join(report), diff --git a/run_cais_new.py b/run_cais_new.py index 93854df..dde4dba 100644 --- a/run_cais_new.py +++ b/run_cais_new.py @@ -161,14 +161,26 @@ def main(): print('Starting run!') - cais = CausalAgent() - - cais.run_analysis( - query=row["natural_language_query"], + cais = CausalAgent( dataset_path=data_path, dataset_description=desc, - use_decision_tree=True ) + + res = cais.run_analysis( + query=row["natural_language_query"], + ) + + # write result to file + formatted_result = { + "query": row["natural_language_query"], + "method": row["method"], + "answer": row["answer"], + "dataset_description": desc, + "dataset_path": data_path, + "keywords": row.get("keywords", "Causality, Average treatment effect"), + "final_result": res + } + file.write(json.dumps({idx: formatted_result}) + "\n") except Exception as e: logging.error(f"[row {idx}] Error: {e}")