Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ CAIS consists of three main stages, powered by a **decision-tree-driven reasonin
- Traverses a predefined causal inference decision tree (Fig. B in paper).
- Maps detected properties to the most appropriate estimation method.

### **Stage 1.5: IV Discovery (Optional)**
If the **Instrumental Variable (IV)** method is selected and the `--iv_llm` pipeline is enabled:
1. **Hypothesis Generation**: The LLM hypothesizes potential instruments based on dataset context and variable names.
2. **Confounder Mining**: Identifies potential confounders that might violate the independence or exclusion restrictions.
3. **Critic Validation**: Uses specialized LLM "critics" (Exclusion, Independence) to reason about the validity of each candidate instrument.
4. **Final Selection**: Selects the most robust instrument for the estimation stage.

---

### **Stage 2: Causal Inference Execution**
Expand Down Expand Up @@ -152,22 +159,24 @@ All datasets used to evaluate CAIs and the baseline models are available in the
## Run
To execute CAIS, run
```bash
python main/run_cais.py \
python run_cais.py \
--metadata_path {path_to_metadata} \
--data_dir {path_to_data_folder} \
--output_dir {output_folder} \
--output_name {output_filename} \
--llm_name {llm_name}
--llm_provider {llm_provider}
--llm_name {llm_name} \
--llm_provider {llm_provider} \
[--iv_llm]
```
Args:

* metadata_path (str): Path to the CSV file containing the queries, dataset descriptions, and data file names
* data_dir (str): Path to the folder containing the data in CSV format
* output_dir (str): Path to the folder where the output JSON results will be saved
* output_name (str): Name of the JSON file where the outputs will be saved
* llm_name (str): Name of the LLM to be used (e.g., 'gpt-4', 'claude-3', etc.)
* llm_name (str): Name of the LLM to be used (e.g., 'gpt-4o', 'claude-3-5-sonnet', etc.)
* llm_provider (str): Name of the LLM service provider (e.g., 'openai', 'anthropic', 'together', etc.)
* iv_llm (bool, optional): If flag is present, enables the advanced experimental IV LLM pipeline for instrument discovery and validation.

A specific example,
```bash
Expand Down
1 change: 1 addition & 0 deletions cais/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
input_parser_tool,
dataset_analyzer_tool,
query_interpreter_tool,
iv_discovery_tool,
method_selector_tool,
method_validator_tool,
method_executor_tool,
Expand Down
65 changes: 61 additions & 4 deletions cais/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from cais.tools.input_parser_tool import input_parser_tool
from cais.tools.dataset_analyzer_tool import dataset_analyzer_tool
from cais.tools.query_interpreter_tool import query_interpreter_tool
from cais.tools.iv_discovery_tool import iv_discovery_tool
from cais.tools.method_selector_tool import method_selector_tool
from cais.tools.controls_selector_tool import controls_selector_tool
from cais.tools.method_validator_tool import method_validator_tool
Expand Down Expand Up @@ -49,6 +50,11 @@

# Set up basic logging
# Set up a module-level logger only. Configuring the root logger
# (logging.basicConfig) at import time would silently reconfigure global
# logging for every consumer of this library and for the whole test suite,
# so file-handler configuration is deferred to script execution.
logger = logging.getLogger(__name__)

if __name__ == "__main__":
    # Configure file logging only when this module is run as a script.
    os.makedirs('./logs/', exist_ok=True)
    logging.basicConfig(
        filename='./logs/agent_debug.log',
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )


Comment on lines 51 to 60
Copy link

Copilot AI Mar 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logging.basicConfig(...) at import time reconfigures global logging for any library consumer and for the whole test suite. Prefer leaving logging configuration to the application entrypoint/CLI; here, set a module logger and emit logs without calling basicConfig (or gate it behind if __name__ == "__main__").

Suggested change
# Set up basic logging
os.makedirs('./logs/', exist_ok=True)
logging.basicConfig(
filename='./logs/agent_debug.log',
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Set up module logger
logger = logging.getLogger(__name__)
if __name__ == "__main__":
# Configure basic logging only when this module is executed as a script
os.makedirs('./logs/', exist_ok=True)
logging.basicConfig(
filename='./logs/agent_debug.log',
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

Copilot uses AI. Check for mistakes.
Expand All @@ -60,9 +66,11 @@ def __init__(
dataset_description: Optional[str] = None, # Description of the dataset
model_name: Optional[str] = None,
provider: Optional[str] = None,
use_iv_pipeline: bool = False,
):
# Query not passed to constructor or saved so we can rerun different queries on the same dataset

self.use_iv_pipeline = use_iv_pipeline
self.llm_info = {
'model_name' : model_name,
'provider' : provider
Expand Down Expand Up @@ -121,6 +129,7 @@ def analyse_dataset(self, query=None):
dataset_path=self.dataset_path,
dataset_description=self.dataset_description,
original_query=query,
use_iv_pipeline=self.use_iv_pipeline,
llm=self.llm
).analysis_results

Expand Down Expand Up @@ -153,6 +162,23 @@ def select_method(self, query=None, llm_decision=True):
self.selected_method = self.method_info.selected_method
return self.selected_method

def discover_instruments(self, query=None):
    """Run the IV discovery pipeline and merge its result into the agent state.

    Args:
        query: Optional natural-language query; when omitted, falls back to
            the agent's stored query via ``self.checkq``.

    Returns:
        Variables: the updated variable set with the discovered/validated
        instrument(s) applied.
    """
    query = self.checkq(query)

    # Pass the agent's configured LLM so instrument discovery runs with the
    # same provider/model as every other stage; otherwise the tool builds
    # its own default client, which can diverge from (or fail under) a
    # non-default provider configuration.
    # TODO(review): confirm iv_discovery_tool.func accepts an `llm` kwarg.
    iv_discovery_output = iv_discovery_tool.func(
        variables=self.variables,
        dataset_analysis=self.dataset_analysis,
        dataset_description=self.dataset_description,
        original_query=query,
        llm=self.llm,
    )

    # The tool may return a pydantic model or a plain dict; normalize first.
    if hasattr(iv_discovery_output, "model_dump"):
        iv_discovery_output_dict = iv_discovery_output.model_dump()
    else:
        iv_discovery_output_dict = iv_discovery_output

    self.variables = Variables(**iv_discovery_output_dict["variables"])
    return self.variables

def validate_method(self, query=None):
'''
Expand All @@ -176,7 +202,7 @@ def select_controls(self, query=None) -> list:

query = self.checkq(query)

controls_selector_output = controls_selector_tool(
controls_selector_output = controls_selector_tool.func(
method_name=self.selected_method,
variables=self.variables,
dataset_analysis=self.dataset_analysis,
Expand All @@ -197,7 +223,16 @@ def clean_dataset(self, query=None):
original_query=query,
causal_method=self.selected_method
)
self.cleaned_dataset_path = cleaning_output.get("cleaned_dataset_path", self.dataset_path)
self.cleaned_dataset_path = cleaning_output.get("cleaned_dataset_path")

# Check if file was actually created/returned
if not self.cleaned_dataset_path or not os.path.exists(self.cleaned_dataset_path):
stderr = cleaning_output.get("stderr", "No stderr available.")
logger.error(f"Dataset cleaning failed to produce a file at {self.cleaned_dataset_path}. Stderr: {stderr}")
# Fallback to original dataset if cleaning failed but we want to attempt execution?
# Or raise error. Let's raise for now to be safe.
raise FileNotFoundError(f"Cleaned dataset NOT found at {self.cleaned_dataset_path}. Cleaning stderr: {stderr}")

return self.cleaned_dataset_path

def execute_method(self, query=None, remove_cleaned=True):
Expand Down Expand Up @@ -253,6 +288,11 @@ def run_analysis(self, query, llm_method_selection: Optional[bool] = True):
query=query,
llm_decision=llm_method_selection
)
if self.selected_method == INSTRUMENTAL_VARIABLE and self.use_iv_pipeline:
logger.info("Instrumental Variable method selected. Running IV Discovery...")
self.discover_instruments(
query=query
)
self.select_controls(
query=query
)
Expand All @@ -275,7 +315,8 @@ def run_analysis(self, query, llm_method_selection: Optional[bool] = True):
def run_causal_analysis(query: str, dataset_path: str,
dataset_description: Optional[str] = None,
api_key: Optional[str] = None,
use_method_validator: bool = True) -> Dict[str, Any]:
use_method_validator: bool = True,
use_iv_pipeline: bool = False) -> Dict[str, Any]:
"""
Run causal analysis on a dataset based on a user query.

Expand Down Expand Up @@ -331,7 +372,7 @@ def run_causal_analysis(query: str, dataset_path: str,
# This just returns query, dataset_path for the csv file and dataset_description
# and workflow state update but that's probably not needed

dataset_analysis_result = dataset_analyzer_tool.func(dataset_path=input_parsing_result["dataset_path"], dataset_description=input_parsing_result["dataset_description"], original_query=input_parsing_result["original_query"]).analysis_results
dataset_analysis_result = dataset_analyzer_tool.func(dataset_path=input_parsing_result["dataset_path"], dataset_description=input_parsing_result["dataset_description"], original_query=input_parsing_result["original_query"], use_iv_pipeline=use_iv_pipeline).analysis_results

query_info = QueryInfo(
query_text=input_parsing_result["original_query"],
Expand Down Expand Up @@ -402,6 +443,22 @@ def run_causal_analysis(query: str, dataset_path: str,
"suggestions": []
}
}

if method_name == INSTRUMENTAL_VARIABLE and use_iv_pipeline:
logger.info("Instrumental Variable method selected. Running IV Discovery...")
iv_discovery_output = iv_discovery_tool.func(
variables=query_interpreter_output,
dataset_analysis=dataset_analysis_result,
dataset_description=input_parsing_result["dataset_description"],
original_query=input_parsing_result["original_query"]
)
# update variables
if hasattr(iv_discovery_output, "model_dump"):
iv_discovery_output_dict = iv_discovery_output.model_dump()
else:
iv_discovery_output_dict = iv_discovery_output
query_interpreter_output = Variables(**iv_discovery_output_dict["variables"])

controls_selector_output = controls_selector_tool.func(
method_name=method_name,
variables=query_interpreter_output,
Expand Down
8 changes: 6 additions & 2 deletions cais/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def main(argv: Optional[list[str]] = None) -> None:
single.add_argument("--llm-provider", dest="llm_provider", default=None, help="LLM provider (openai, anthropic, together, gemini, deepseek)")
single.add_argument("--skip-method-validator", action="store_true", help="Skip method validation step")
single.add_argument("--use-llm-rule-engine", action="store_true", help="Use LLM-based method selection")
single.add_argument("--iv_llm", action="store_true", help="Use the new IV_LLM pipeline")

# Batch run compatible with existing metadata CSVs
batch = subparsers.add_parser("batch", help="Run batch analyses from a metadata CSV")
Expand All @@ -28,6 +29,7 @@ def main(argv: Optional[list[str]] = None) -> None:
batch.add_argument("--llm-provider", dest="llm_provider", default=None)
batch.add_argument("--skip-method-validator", action="store_true", help="Skip method validation step")
batch.add_argument("--use-llm-rule-engine", action="store_true", help="Use LLM-based method selection")
batch.add_argument("--iv_llm", action="store_true", help="Use the new IV_LLM pipeline")

args = parser.parse_args(argv)

Expand All @@ -47,7 +49,8 @@ def main(argv: Optional[list[str]] = None) -> None:
query=args.query,
dataset_path=args.dataset,
dataset_description=args.description,
use_method_validator=not args.skip_method_validator
use_method_validator=not args.skip_method_validator,
use_iv_pipeline=args.iv_llm
)
import json
print(json.dumps(result, indent=2))
Expand All @@ -67,7 +70,8 @@ def main(argv: Optional[list[str]] = None) -> None:
query=row.get("natural_language_query"),
dataset_path=data_path,
dataset_description=row.get("data_description"),
use_method_validator=not args.skip_method_validator
use_method_validator=not args.skip_method_validator,
use_iv_pipeline=args.iv_llm
)
results[idx] = {
"query": row.get("natural_language_query"),
Expand Down
38 changes: 35 additions & 3 deletions cais/components/dataset_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ def analyze_dataset(
dataset_path: str,
llm_client: Optional[BaseChatModel] = None,
dataset_description: Optional[str] = None,
original_query: Optional[str] = None
original_query: Optional[str] = None,
use_iv_pipeline: bool = False
) -> Dict[str, Any]:
"""
Analyze a dataset to identify important characteristics for causal inference.
Expand Down Expand Up @@ -166,7 +167,8 @@ def analyze_dataset(
llm_client=llm_client,
potential_treatments=potential_variables.get("potential_treatments", []),
potential_outcomes=potential_variables.get("potential_outcomes", []),
dataset_description=dataset_description
dataset_description=dataset_description,
use_iv_pipeline=use_iv_pipeline
)

# Other analyses
Expand Down Expand Up @@ -637,7 +639,8 @@ def find_potential_instruments(
llm_client: Optional[BaseChatModel] = None,
potential_treatments: List[str] = None,
potential_outcomes: List[str] = None,
dataset_description: Optional[str] = None
dataset_description: Optional[str] = None,
use_iv_pipeline: bool = False
) -> List[Dict[str, Any]]:
"""
Find potential instrumental variables in the dataset, using LLM if available.
Expand All @@ -653,6 +656,35 @@ def find_potential_instruments(
Returns:
List of potential instrumental variables with their properties
"""
if use_iv_pipeline and potential_treatments and potential_outcomes:
try:
from cais.components.iv_discovery import IVDiscovery
logger.info("Using IV LLM Pipeline to discover instrumental variables")
discovery = IVDiscovery()
treatment = potential_treatments[0]
outcome = potential_outcomes[0]
context = f"Dataset Description: {dataset_description}" if dataset_description else ""

result = discovery.discover_instruments(treatment, outcome, context=context)
valid_ivs = result.get('valid_ivs', [])

iv_list = []
for iv in valid_ivs:
if iv in df.columns:
iv_list.append({
"variable": iv,
"reason": "Discovered and validated by IV LLM Pipeline critics",
"data_type": str(df[iv].dtype)
})
if iv_list:
logger.info(f"IV LLM Pipeline identified {len(iv_list)} valid instruments: {valid_ivs}")
return iv_list
else:
logger.warning("IV LLM Pipeline found no valid instruments, falling back to standard LLM or heuristic method")
except Exception as e:
logger.error(f"Error using IV LLM Pipeline: {e}", exc_info=True)
logger.info("Falling back to standard LLM or heuristic method")

# Try LLM approach if client is provided
if llm_client:
try:
Expand Down
20 changes: 16 additions & 4 deletions cais/components/dataset_cleaner.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,14 @@
Input: (dataset_path, Transformation Spec JSON).
Output: A SINGLE Python script as text that:
- imports only: json, os, pandas as pd, numpy as np
- loads the dataset from dataset_path (infer CSV/Parquet by extension)
- NEVER RE-DEFINE `__DATASET_PATH__` or `__CLEANED_PATH__`. They are already provided.
- loads the dataset from the path provided in global variable `__DATASET_PATH__`
- applies ONLY what the Spec asks for (row_filters, column_ops, method_constructs, etc.)
- keeps all original columns unless Spec explicitly drops them
- creates any new columns explicitly; suffix where needed; no silent overwrite
- produces a dataframe named clean_df
- writes:
- cleaned_df.csv (same directory as dataset_path)
- clean_df.csv to the path provided in global variable `__CLEANED_PATH__`
- preprocessing_manifest.json (the Spec actually executed)
- derived_columns.json (list of new columns with one-line descriptions)
- prints a concise, human-readable summary report to stdout
Expand Down Expand Up @@ -181,7 +182,15 @@ def _run_script_text(script: str, dataset_path: str, cleaned_path: str) -> Tuple
with contextlib.redirect_stdout(stdout_io), contextlib.redirect_stderr(stderr_io):
# Provide dataset_path as a global the script can read (it should anyway use the passed JSON)
gbls["__DATASET_PATH__"] = dataset_path
script=script.replace('cleaned_df.csv', cleaned_path)
gbls["__CLEANED_PATH__"] = cleaned_path

# We restore the replacements but use json.dumps for safe quoting on Windows
# This handles models that hardcode the path despite instructions
for placeholder in ["cleaned_df.csv", "clean_df.csv", "manifest.json", "derived_columns.json"]:
if placeholder in script:
script = script.replace(f'"{placeholder}"', json.dumps(cleaned_path if "csv" in placeholder else placeholder))
script = script.replace(f"'{placeholder}'", json.dumps(cleaned_path if "csv" in placeholder else placeholder))

exec(script, gbls, lcls)
except Exception as e:
tb = traceback.format_exc()
Expand Down Expand Up @@ -246,7 +255,10 @@ def run_cleaning_stage(dataset_path: str,
"""
llm = get_llm_client()

cleaned_path = os.path.join(os.path.dirname(os.path.abspath(dataset_path)) or ".", f"{dataset_path.split('/')[-1][:-4]}_cleaned_{os.getpid()}.csv")
dataset_path = dataset_path.replace("\\", "/")
base_name = os.path.basename(dataset_path)
file_stem = os.path.splitext(base_name)[0]
cleaned_path = os.path.join(os.path.dirname(os.path.abspath(dataset_path)) or ".", f"{file_stem}_cleaned_{os.getpid()}.csv").replace("\\", "/")

# 1) PLAN
method = causal_method or variables.get("method") or ""
Expand Down
Loading