Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions cais/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def execute_method(self, query=None, remove_cleaned=True):

if self.cleaned_dataset_path and remove_cleaned:
if isinstance(self.load_dataset(cleaned=True), pd.DataFrame):
os.remove(self.cleaned_dataset_path)
# os.remove(self.cleaned_dataset_path)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

commenting this out is fine since users can manually remove the cleaned files, but we may want to instead save them to a dedicated cleaned-subset folder for easier removal if we don't already do so.

self.cleaned_dataset_path=None
logger.info("Succesfully Removed Cleaned Dataset.")

Expand Down Expand Up @@ -341,7 +341,7 @@ def run_causal_analysis(query: str, dataset_path: str,
instrument_hints=input_parsing_result["extracted_variables"].get("instruments_mentioned")
)

query_interpreter_output = query_interpreter_tool.func(query_info=query_info, dataset_analysis=dataset_analysis_result, dataset_description=input_parsing_result["dataset_description"], original_query = input_parsing_result["original_query"]).variables
query_interpreter_output = query_interpreter_tool.func(dataset_analysis=dataset_analysis_result, dataset_description=input_parsing_result["dataset_description"], original_query=input_parsing_result["original_query"]).variables

# print('LOG RESULTS')
# print(input_parsing_result['extracted_variables'])
Expand Down Expand Up @@ -447,13 +447,23 @@ def run_causal_analysis(query: str, dataset_path: str,
dataset_description=input_parsing_result["dataset_description"],
original_query = input_parsing_result["original_query"])
result = explainer_output

# include query_info in result
result["query_info"] = {
"query_text": input_parsing_result["original_query"],
"potential_treatments": input_parsing_result["extracted_variables"].get("treatment"),
"potential_outcomes": input_parsing_result["extracted_variables"].get("outcome"),
"covariates_hints": input_parsing_result["extracted_variables"].get("covariates_mentioned"),
"instrument_hints": input_parsing_result["extracted_variables"].get("instruments_mentioned")
}

#result['results']['results']["method_used"] = method_validator_output.get('method')
logger.debug(result)
logger.info("Causal analysis run finished.")

# Remove the cleaned csv
logger.info("Removing cleaned csv.")
os.remove(cleaned_path)
# logger.info("Removing cleaned csv.")
# os.remove(cleaned_path)

# Ensure result is a dict and extract the 'output' part
if isinstance(result, dict):
Expand Down
4 changes: 4 additions & 0 deletions cais/components/dataset_cleaner.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,10 @@ def run_cleaning_stage(dataset_path: str,
if ("Traceback" in stderr_all) or ("Error" in stderr_all):
report.append("\n⚠️ LLM pipeline produced errors. Check stderr; artifacts may be missing or partial.")

if not os.path.exists(cleaned_path):
report.append(f"\n⚠️ Cleaned file not found at expected path; falling back to original dataset.")
cleaned_path = dataset_path

return {
"cleaned_dataset_path": cleaned_path,
"cleaning_report_md": "\n".join(report),
Expand Down
22 changes: 17 additions & 5 deletions run_cais_new.py
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, my mistake for forgetting to update our CLI.

Original file line number Diff line number Diff line change
Expand Up @@ -161,14 +161,26 @@ def main():

print('Starting run!')

cais = CausalAgent()

cais.run_analysis(
query=row["natural_language_query"],
cais = CausalAgent(
dataset_path=data_path,
dataset_description=desc,
use_decision_tree=True
)

res = cais.run_analysis(
query=row["natural_language_query"],
)

# write result to file
formatted_result = {
"query": row["natural_language_query"],
"method": row["method"],
"answer": row["answer"],
"dataset_description": desc,
"dataset_path": data_path,
"keywords": row.get("keywords", "Causality, Average treatment effect"),
"final_result": res
}
file.write(json.dumps({idx: formatted_result}) + "\n")

except Exception as e:
logging.error(f"[row {idx}] Error: {e}")
Expand Down