Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cais/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def main(argv: Optional[list[str]] = None) -> None:
os.environ["LLM_MODEL"] = args.llm_name
if getattr(args, "llm_provider", None):
os.environ["LLM_PROVIDER"] = args.llm_provider
if getattr(args, "use_llm_rule_engine", False):
if getattr(args, "use_llm_rule_engine", False): # TODO: Change this to use decision tree in default
os.environ["CAIS_USE_LLM_RULE_ENGINE"] = "1"

if args.command == "run":
Expand Down
162 changes: 162 additions & 0 deletions evaluation/evaluation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
import argparse
import json
import os
import re

import numpy as np
import pandas as pd


def standardize_method_name(method):
    """
    Standardize a free-text method name to a coarse causal family.

    Matching is keyword based and order dependent: the first family whose
    keyword is found in the lower-cased name wins.

    Args:
        method: Raw method description. Anything that is not a string is
            treated as a missing method.

    Returns:
        One of "fd", "iv", "rdd", "glm", "observational", "ols", "did" or
        "other"; ``np.nan`` when no method was produced (non-string input,
        empty string, or an explicit marker such as "null", "na", "n/a",
        "nan" or "none").
    """

    if method is None or not isinstance(method, str):
        return np.nan

    m = method.lower().strip()

    # An empty name carries no method at all.
    if not m:
        return np.nan

    # Alphanumeric words of the name. Short keywords are matched against
    # these instead of as raw substrings, otherwise "na" fires inside
    # "ordinary"/"observational" and "iv" inside "multivariate"/"naive".
    words = set(re.findall(r"[a-z0-9]+", m))

    # Explicit failures (whole-word markers only)
    if words & {"null", "na", "nan", "none"}:
        return np.nan
    if re.search(r"(?<![a-z0-9])n/a(?![a-z0-9])", m):
        return np.nan

    # Frontdoor FIRST (important)
    if "frontdoor" in m or "front door" in m:
        return "fd"

    # IV
    if "iv" in words or any(x in m for x in ["instrument", "encouragement", "2sls"]):
        return "iv"

    # RDD
    if any(x in m for x in ["discontinuity", "rdd", "fuzzy"]):
        return "rdd"

    # GLMs
    if any(x in m for x in ["logistic", "probit", "logit", "glm"]):
        return "glm"

    # Observational adjustment
    if any(x in m for x in ["weighting", "ipw", "propensity", "matching", "observational"]):
        return "observational"

    # OLS / means / RCT-style
    if any(x in m for x in ["linear", "means", "ordinary", "ols", "wls", "rct"]):
        return "ols"

    # DiD / panels
    if any(x in m for x in ["difference", "did", "fixed effects", "panel"]):
        return "did"

    return "other"


def match_method(gt_method, pred_method, split):
    """
    Decide whether a predicted method counts as matching the ground truth.

    Both names are first reduced to a coarse causal family. Identical
    families always match; beyond that, each split allows a small set of
    (ground truth, prediction) family pairs.

    Args:
        gt_method: Ground-truth method name.
        pred_method: Predicted method name.
        split: "real", "qr" or "synthetic"; any other value allows no
            relaxation.

    Returns:
        bool
    """
    gt = standardize_method_name(gt_method)
    pr = standardize_method_name(pred_method)

    # A missing method on either side is always scored as incorrect.
    if any(pd.isna(family) for family in (gt, pr)):
        return False

    # Same family: nothing further to check.
    if gt == pr:
        return True

    families = (gt, pr)

    if split == "real":
        # Only GLM-for-OLS is tolerated; observational vs OLS stays strict.
        return families == ("ols", "glm")

    if split == "qr":
        # QR ground truth is loose: OLS and observational are interchangeable.
        interchangeable = {"ols", "observational"}
        return gt in interchangeable and pr in interchangeable

    if split == "synthetic":
        # OLS is accepted for RCT and for observational ground truth.
        # NOTE(review): the "rct" pair looks unreachable while the
        # standardizer folds "rct" into "ols" - confirm before relying on it.
        return families in {("rct", "ols"), ("observational", "ols")}

    return False


def _infer_split(fname):
    """Return the split named in the path, preferring the file name itself."""
    # Look at the base name first so a directory such as "real_runs/" does
    # not override the split encoded in the file name; fall back to the
    # whole path for layouts that encode the split in a directory.
    for candidate in (os.path.basename(fname), fname):
        for split in ("qr", "real", "synthetic"):
            if split in candidate:
                return split
    raise ValueError(f"Unknown split for file: {fname}")


def evaluate_result(fname):
    """
    Print the method-selection accuracy for one JSON-lines result file.

    Every non-blank line is parsed as a JSON object. The first value of that
    object is taken as the entry; its ground truth is read from
    ``entry['method']`` and its prediction from
    ``entry['final_result']['method']``. Entries that cannot be read are
    reported and counted as incorrect.

    Args:
        fname: Path of the result file. The split ("qr", "real" or
            "synthetic") is inferred from this path.

    Returns:
        The accuracy as a percentage, or ``None`` when the file contains no
        entries.

    Raises:
        ValueError: If no known split name appears in the path.
    """

    with open(fname, "r", encoding="utf-8") as f:
        data = [json.loads(line) for line in f if line.strip()]

    split = _infer_split(fname)

    # Nothing to score: avoid dividing by zero below.
    if not data:
        print(f"[WARNING] No entries found in file: {fname}")
        print('--------')
        return None

    correct_method = 0
    for i, record in enumerate(data):

        try:
            entry = next(iter(record.values()))
            gt_method = entry['method']
            pred_method = entry['final_result']['method']

            if match_method(gt_method, pred_method, split=split):
                correct_method += 1

        except Exception as e:
            # Best effort: a malformed entry is logged and scored as wrong.
            print(f"[ERROR] Failed to evaluate entry {i}: {e}")
            continue

    accuracy = (correct_method / len(data)) * 100

    # Print result name and method accuracy
    print("Method Selection Accuracy: ", accuracy)
    print('--------')
    return accuracy

def main():
    """
    Evaluate every ``.json`` result file found in ``--results-dir``.

    Files with any other extension are reported and skipped. Files are
    processed in sorted name order so the report is reproducible.
    """

    parser = argparse.ArgumentParser()
    parser.add_argument("--results-dir", required=True)
    args = parser.parse_args()

    # Kept from the original: a missing directory is created (and then simply
    # yields no files) instead of failing.
    os.makedirs(args.results_dir, exist_ok=True)

    # Load json files from args.results_dir
    for fname in sorted(os.listdir(args.results_dir)):
        if not fname.endswith(".json"):
            print(f"[WARNING] Skipping non-json file: {fname}")
            continue

        print(f"[INFO] Evaluating file: {fname}")
        evaluate_result(os.path.join(args.results_dir, fname))


if __name__ == "__main__":
    main()