diff --git a/docs/ae-eval-s5.1-silent-issue-detection.md b/docs/ae-eval-s5.1-silent-issue-detection.md index da7d571c..aba371d5 100644 --- a/docs/ae-eval-s5.1-silent-issue-detection.md +++ b/docs/ae-eval-s5.1-silent-issue-detection.md @@ -1,61 +1,79 @@ # Eval: Silent Issue Detection -⏳ **Estimated Completion Time**: ~5 hours (if running everything from scratch) +⏳ **Estimated Completion Time**: ~30 minutes ## 🎯 Goal -In our evaluation, **TrainCheck detects 18 silent issues**. Your goal is to reproduce and validate: +TrainCheck detects **18 real-world silent issues** in our evaluation. Your goal in this artifact evaluation is to **verify detection for the subset of issues that are currently AE-supported** (see [bug table](#-bug-summary-table) below). -1. ✅ TrainCheck successfully detects each issue. -2. ⏱️ Detection occurs within **at most 1 iteration** of the issue being triggered. -3. 🔍 The reported invariant violations are **close to the root cause**. +For each supported bug, you should confirm: -To accomplish this, a full evaluation involves: -1. **Inferring invariants** from clean PyTorch example pipelines. -2. **Running each buggy pipeline** to produce a trace. -3. **Checking invariants** against these traces. -4. **Manually verifying** that each violation corresponds to a silent issue, is timely, and aligns with the root cause. +✅ **TrainCheck successfully detects the issue** by reporting one or more invariant violations on the provided trace. -## 🛠️ Evaluation Adjustments for Reproducibility +The artifact provides all necessary resources to automate this confirmation. +Additional insights—such as when the issue is triggered and how the violation aligns with the root cause—can be explored by examining the scripts, logs, or violation reports, though they are not required for core validation. -To ease your evaluation effort, we’ve made the following simplifications: +## 📂 Resources Provided -### 1. 
🧪 Pre-provided Bug-Detecting Invariants -We provide a curated set of invariants that are known to catch the issues. -➡️ This eliminates the need to manually infer and filter a large number of candidate invariants. -(You may optionally rerun inference yourself to verify reproducibility.) +All files are located in the [`TrainCheck-Evaluation-Workloads`](https://github.com/OrderLab/TrainCheck-Evaluation-Workloads) repository. -### 2. 📦 Pre-collected Buggy Traces -Some bugs require: -- Complex library setups (e.g., DS-1801 requires source builds of HuggingFace's own fork of `Megatron-DeepSpeed` and a DeepSpeed installation that needs you to manually modify a few lines of code.) -- Large datasets (~100 GiB+) +| Resource | Description | +|---------|-------------| +| **Curated Invariants** | Small set of known-effective invariants per bug. | +| **Pre-collected Traces** | Captured execution traces from the buggy pipelines. | +| **Silent Issue Reproduction Scripts and Descriptions** | https://github.com/OrderLab/TrainCheck-Evaluation-Workloads/tree/main/silent-issue-detection/bug-reprod-scripts | -To avoid these barriers: -- We **pre-collected traces** for all 18 bugs. -- You can run TrainCheck directly on these traces without reproducing the full training environments. +### 🐛 Silent Issue Summary Table -> Note: We have setup instructions in the `README.md` doc within each bug's folder. Please let us know if you are following these instructions to collect the traces yourself and have encountered any issues. 
+| **Bug ID** | **Failure Location** | **AE?** | **AE Limitation (if any)** | +|---------------------------|----------------------|--------|------------------------------------------------------------------| +| `baichuan2-86` | HW/Driver | ✅ Yes | Reuses pytorch-84803 trace | +| `deepspeed-1801` | Framework | ✅ Yes | | +| `deepspeed-5794` | Framework | ❌ No | Invariant relation still under evaluation | +| `lightning-thunder-725` | Framework | ✅ Yes | | +| `mmpretrain-702` | Framework | ✅ Yes | | +| `pytorch-51800` | Framework | ✅ Yes | | +| `pytorch-84803` | HW/Driver | ✅ Yes | | +| `pytorch-96600` | HW/Driver | ✅ Yes | Reuses pytorch-84803 trace | +| `pytorch-104336` | Framework | ✅ Yes | | +| `pytorch-115607` | Compiler | ✅ Yes | | +| `pytorch-forum-84911` | User Code | ✅ Yes | | +| `stackoverflow-60335387` | User Code | ✅ Yes | | +| `stackoverflow-67180955` | Framework | ❌ No | Requires an older Python version that is no longer supported | +| `transformers-17877` | Framework | ✅ Yes | | +| `transformers-23723` | Framework | ✅ Yes | | +| `transformers-33844` | Framework | ✅ Yes | | +| `transformers-34204` | Framework | ❌ No | Invariant support still in progress | +| `x-jxmnop-ddp-out-of-sync`| User Code | ✅ Yes | | -### 3. ⏱️ Early Bug Triggering + Ground Truth Iterations -We modified the buggy scripts to **trigger the silent issue as early as possible** and documented the **exact iteration** when each bug manifests. -You can verify the provided iteration number by inspecting the buggy code or logs as desired. +We currently support **15 out of 18 bugs** for artifact evaluation. +You have already detected `pytorch-forum-84911` in our 5-min tutorial; you will need to detect the remaining 14 bugs. 
-## 📂 Resources & Scripts +Bugs not included in this AE release typically depend on: +- Unsupported or unstable library versions +- Very old Python environments +- Invariant support still in development -> Files described below are all in the [TrainCheck-Evaluation-Workloads](https://github.com/OrderLab/TrainCheck-Evaluation-Workloads/) repo. +Additionally, a few bugs stem from very specific issues such as faulty hardware, which are inherently difficult to reproduce. +For such cases—and for bugs that share the same root cause—we may provide a **shared/simulated trace** and a **shared invariant** that is reused across multiple bug IDs. -- **Automation Scripts** - Scripts for running TrainCheck to check invariants on buggy traces. ## 🧪 Reproducing Silent Issue Detection -- **Pre-collected Traces** - Traces collected from buggy pipelines to avoid complex reproduction steps. > All steps described below assume you are already in the `TrainCheck-Evaluation-Workloads` repo. If not, clone the repository and enter it: > ```bash > git clone https://github.com/OrderLab/TrainCheck-Evaluation-Workloads.git > cd TrainCheck-Evaluation-Workloads > ``` -- **Curated Invariants List** - A small, hand-picked set of invariants that are known to detect each bug effectively. 1. Make sure you have a working TrainCheck installation by following the [TrainCheck Installation Guide](./installation-guide.md). -- ***[Optional]* Reproduction Scripts & Environment Setup** - Provided for each bug (in its respective folder) in case you want to collect your own trace. - This includes environment installation instructions and original buggy training scripts. 2. Execute `ae_detection.py` to automatically apply the curated invariants to the pre-collected traces. This script writes its results to a folder named `checker_output`. -- ***[Optional]* Example Pipelines for Invariant Inference** - Clean training pipelines used for inferring relevant invariants prior to checking. \ No newline at end of file +3. 
Compare the detection result folder against our reference checker results to verify that your run reproduces the claimed detections. + ```bash + diff -r checker_output reference_checker_output/ + ``` + +## Expected Results + +The `diff -r` command should complete with no output, indicating that your results exactly match the reference. \ No newline at end of file diff --git a/docs/ae-eval-s5.4-fp-rate.md b/docs/ae-eval-s5.4-fp-rate.md index 072646fe..37fc07c6 100644 --- a/docs/ae-eval-s5.4-fp-rate.md +++ b/docs/ae-eval-s5.4-fp-rate.md @@ -91,7 +91,7 @@ You should verify that the false positive rates are similar or lower. Since the In our run of the script, we obtained the following results: ```csv setup,fp_rate -1-input,0.0387 -4-input,0.0143 -6-input,0.0119 +1-input,0.039 +4-input,0.021 +6-input,0.015 ``` diff --git a/traincheck/invariant/DistinctArgumentRelation.py b/traincheck/invariant/DistinctArgumentRelation.py index 83b876ff..0bb9b81e 100644 --- a/traincheck/invariant/DistinctArgumentRelation.py +++ b/traincheck/invariant/DistinctArgumentRelation.py @@ -62,12 +62,13 @@ def get_event_data_per_function_per_step(trace: Trace, function_pool: Set[Any]): keep_this_func = False for func_call_id in func_call_ids: event = trace.query_func_call_event(func_call_id) - if ( - "meta_vars.step" not in event.pre_record - or event.pre_record["meta_vars.step"] is None - or "args" not in event.pre_record - ): + if "args" not in event.pre_record: continue + + if "meta_vars.step" not in event.pre_record or event.pre_record["meta_vars.step"] is None: + # a missing or None step is assumed to be in the initialization phase + event.pre_record["meta_vars.step"] = -1 + keep_this_func = True process_id = event.pre_record["process_id"] thread_id = event.pre_record["thread_id"] @@ -438,7 +439,7 @@ def static_check_all( return CheckerResult( trace=[event1, event2], invariant=inv, - check_passed=True, + check_passed=False, triggered=True, ) diff --git a/traincheck/invariant/base_cls.py b/traincheck/invariant/base_cls.py index 9e79c250..6ea4c53b 100644 --- a/traincheck/invariant/base_cls.py +++ b/traincheck/invariant/base_cls.py 
@@ -794,7 +794,7 @@ def __init__( prop_name: str, prop_dtype: type | None, _type: PT, - additional_path: list[str] | None, + additional_path: tuple[str] | None, values: set | None, ): """A class to represent a single clause in a precondition. A clause is a property that should hold for the hypothesis to be valid. @@ -837,14 +837,14 @@ def __str__(self) -> str: return f"Prop: {self.prop_name}, Type: {self.type}" def to_dict(self) -> dict: - clause_dict: dict[str, str | list] = { + clause_dict: dict[str, str | tuple] = { "type": self.type.value, "prop_name": self.prop_name, "additional_path": self.additional_path if self.additional_path else "None", "prop_dtype": self.prop_dtype.__name__ if self.prop_dtype else "None", } if self.type in [PT.CONSTANT]: - clause_dict["values"] = list(self.values) + clause_dict["values"] = tuple(self.values) return clause_dict @staticmethod @@ -853,7 +853,7 @@ def from_dict(clause_dict: dict) -> PreconditionClause: _type = PT(clause_dict["type"]) prop_dtype = eval(clause_dict["prop_dtype"]) additional_path = ( - clause_dict["additional_path"] + tuple(clause_dict["additional_path"]) if clause_dict["additional_path"] != "None" else None ) diff --git a/traincheck/invariant/cover_relation.py b/traincheck/invariant/cover_relation.py index eba11132..d8c2fac0 100644 --- a/traincheck/invariant/cover_relation.py +++ b/traincheck/invariant/cover_relation.py @@ -155,7 +155,7 @@ def generate_hypothesis(trace) -> list[Hypothesis]: listed_events.items(), ascii=True, leave=True, desc="Groups Processed" ): same_level_func[(process_id, thread_id)] = {} - for funcA, funcB in tqdm( + for func_A, func_B in tqdm( permutations(function_pool, 2), ascii=True, leave=True, @@ -163,17 +163,17 @@ def generate_hypothesis(trace) -> list[Hypothesis]: total=len(function_pool) ** 2, ): if check_same_level( - funcA, - funcB, + func_A, + func_B, process_id, thread_id, function_id_map, function_times, ): - if funcA not in same_level_func[(process_id, thread_id)]: - 
same_level_func[(process_id, thread_id)][funcA] = [] - same_level_func[(process_id, thread_id)][funcA].append(funcB) - valid_relations[(funcA, funcB)] = True + if func_A not in same_level_func[(process_id, thread_id)]: + same_level_func[(process_id, thread_id)][func_A] = [] + same_level_func[(process_id, thread_id)][func_A].append(func_B) + valid_relations[(func_A, func_B)] = True trace.same_level_func_cover = same_level_func trace.valid_relations_cover = valid_relations print("End same level checking") @@ -209,10 +209,21 @@ def generate_hypothesis(trace) -> list[Hypothesis]: desc="Function Pair", ): - if func_A not in same_level_func[(process_id, thread_id)]: + if func_B not in same_level_func[(process_id, thread_id)]: continue - if func_B not in same_level_func[(process_id, thread_id)][func_A]: + if func_A not in same_level_func[(process_id, thread_id)][func_B]: + # all B invocations are negative examples + for event in events_list: + if ( + event["type"] == "function_call (pre)" + and event["function"] == func_B + ): + example = Example() + example.add_group(EXP_GROUP_NAME, [event]) + hypothesis_with_examples[ + (func_A, func_B) + ].negative_examples.add_example(example) continue flag_A = None @@ -298,7 +309,7 @@ def collect_examples(trace, hypothesis): listed_events.items(), ascii=True, leave=True, desc="Groups Processed" ): same_level_func[(process_id, thread_id)] = {} - for funcA, funcB in tqdm( + for func_A, func_B in tqdm( permutations(function_pool, 2), ascii=True, leave=True, @@ -306,17 +317,17 @@ def collect_examples(trace, hypothesis): total=len(function_pool) ** 2, ): if check_same_level( - funcA, - funcB, + func_A, + func_B, process_id, thread_id, function_id_map, function_times, ): - if funcA not in same_level_func[(process_id, thread_id)]: - same_level_func[(process_id, thread_id)][funcA] = [] - same_level_func[(process_id, thread_id)][funcA].append(funcB) - valid_relations[(funcA, funcB)] = True + if func_A not in same_level_func[(process_id, 
thread_id)]: + same_level_func[(process_id, thread_id)][func_A] = [] + same_level_func[(process_id, thread_id)][func_A].append(func_B) + valid_relations[(func_A, func_B)] = True trace.same_level_func_cover = same_level_func trace.valid_relations_cover = valid_relations print("End same level checking") @@ -343,21 +354,30 @@ def collect_examples(trace, hypothesis): print("Starting collecting iteration...") for i in tqdm(range(invariant_length - 1)): - func_A = inv.params[i] - func_B = inv.params[i + 1] + param_A = inv.params[i] + param_B = inv.params[i + 1] - assert isinstance(func_A, APIParam) and isinstance( - func_B, APIParam + assert isinstance(param_A, APIParam) and isinstance( + param_B, APIParam ), "Invariant parameters should be string." + func_A = param_A.api_full_name + func_B = param_B.api_full_name for (process_id, thread_id), events_list in listed_events.items(): - funcA = func_A.api_full_name - funcB = func_B.api_full_name - if funcA not in same_level_func[(process_id, thread_id)]: + if func_B not in same_level_func[(process_id, thread_id)]: continue - if funcB not in same_level_func[(process_id, thread_id)][funcA]: + if func_A not in same_level_func[(process_id, thread_id)][func_B]: + # all B invocations are negative examples + for event in events_list: + if ( + event["type"] == "function_call (pre)" + and event["function"] == func_B + ): + example = Example() + example.add_group(EXP_GROUP_NAME, [event]) + hypothesis.negative_examples.add_example(example) continue # check @@ -418,13 +438,13 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]: ] = {} # relation_pool contains all binary relations classified by GroupedPreconditions (key) for hypothesis in all_hypotheses: - param0 = hypothesis.invariant.params[0] - param1 = hypothesis.invariant.params[1] + param_A = hypothesis.invariant.params[0] + param_B = hypothesis.invariant.params[1] - assert isinstance(param0, APIParam) and isinstance(param1, APIParam) + assert 
isinstance(param_A, APIParam) and isinstance(param_B, APIParam) if hypothesis.invariant.precondition not in relation_pool: relation_pool[hypothesis.invariant.precondition] = [] - relation_pool[hypothesis.invariant.precondition].append((param0, param1)) + relation_pool[hypothesis.invariant.precondition].append((param_A, param_B)) merged_relations: Dict[GroupedPreconditions | None, List[List[APIParam]]] = {} @@ -535,7 +555,7 @@ def static_check_all( listed_events.items(), ascii=True, leave=True, desc="Groups Processed" ): same_level_func[(process_id, thread_id)] = {} - for funcA, funcB in tqdm( + for func_A, func_B in tqdm( permutations(function_pool, 2), ascii=True, leave=True, @@ -543,17 +563,17 @@ def static_check_all( total=len(function_pool) ** 2, ): if check_same_level( - funcA, - funcB, + func_A, + func_B, process_id, thread_id, function_id_map, function_times, ): - if funcA not in same_level_func[(process_id, thread_id)]: - same_level_func[(process_id, thread_id)][funcA] = [] - same_level_func[(process_id, thread_id)][funcA].append(funcB) - valid_relations[(funcA, funcB)] = True + if func_A not in same_level_func[(process_id, thread_id)]: + same_level_func[(process_id, thread_id)][func_A] = [] + same_level_func[(process_id, thread_id)][func_A].append(func_B) + valid_relations[(func_A, func_B)] = True trace.same_level_func_cover = same_level_func trace.valid_relations_cover = valid_relations print("End same level checking") @@ -585,21 +605,39 @@ def static_check_all( print("Starting checking iteration...") for i in tqdm(range(invariant_length - 1)): - func_A = inv.params[i] - func_B = inv.params[i + 1] + param_A = inv.params[i] + param_B = inv.params[i + 1] - assert isinstance(func_A, APIParam) and isinstance( - func_B, APIParam + assert isinstance(param_A, APIParam) and isinstance( + param_B, APIParam ), "Invariant parameters should be string." 
- for (process_id, thread_id), events_list in listed_events.items(): - funcA = func_A.api_full_name - funcB = func_B.api_full_name + func_A = param_A.api_full_name + func_B = param_B.api_full_name - if funcA not in same_level_func[(process_id, thread_id)]: + for (process_id, thread_id), events_list in listed_events.items(): + if func_B not in same_level_func[(process_id, thread_id)]: continue - if funcB not in same_level_func[(process_id, thread_id)][funcA]: + if func_A not in same_level_func[(process_id, thread_id)][func_B]: + # no A invoked, all B should be invalid + for event in events_list: + if ( + event["type"] == "function_call (pre)" + and event["function"] == func_B + ): + if not inv.precondition.verify( + [event], EXP_GROUP_NAME, trace + ): + continue + + inv_triggered = True + return CheckerResult( + trace=[event], + invariant=inv, + check_passed=False, + triggered=True, + ) continue # check @@ -608,10 +646,10 @@ def static_check_all( if event["type"] != "function_call (pre)": continue - if funcA == event["function"]: + if func_A == event["function"]: unmatched_A_exist = True - if funcB == event["function"]: + if func_B == event["function"]: if not inv.precondition.verify([event], EXP_GROUP_NAME, trace): continue diff --git a/traincheck/invariant/lead_relation.py b/traincheck/invariant/lead_relation.py index 2fd86863..641f3e2a 100644 --- a/traincheck/invariant/lead_relation.py +++ b/traincheck/invariant/lead_relation.py @@ -24,20 +24,20 @@ def check_same_level( - funcA: str, - funcB: str, + func_A: str, + func_B: str, process_id: str, thread_id: str, function_id_map, function_times, ): - """Check if funcA and funcB are at the same level in the call stack. - By "same level", funcA and funcB are not always nested within each other (no caller-callee relationships). + """Check if func_A and func_B are at the same level in the call stack. + By "same level", func_A and func_B are not always nested within each other (no caller-callee relationships). 
The nested functions are filtered out in the preprocessing step. Args: - funcA (str): function name A - funcB (str): function name B + func_A (str): function name A + func_B (str): function name B process_id (str): process id thread_id (str): thread id function_id_map: a map from (process_id, thread_id) to function name to all function call ids of that function, @@ -46,20 +46,20 @@ def check_same_level( the times should be sorted by the time of the function call Returns: - bool: True if funcA and funcB are at the same level, False otherwise + bool: True if func_A and func_B are at the same level, False otherwise """ - if funcA == funcB: + if func_A == func_B: return False - if funcB not in function_id_map[(process_id, thread_id)]: + if func_B not in function_id_map[(process_id, thread_id)]: return False - if funcA not in function_id_map[(process_id, thread_id)]: + if func_A not in function_id_map[(process_id, thread_id)]: return False - for idA in function_id_map[(process_id, thread_id)][funcA]: - for idB in function_id_map[(process_id, thread_id)][funcB]: + for idA in function_id_map[(process_id, thread_id)][func_A]: + for idB in function_id_map[(process_id, thread_id)][func_B]: preA = function_times[(process_id, thread_id)][idA]["start"] postA = function_times[(process_id, thread_id)][idA]["end"] preB = function_times[(process_id, thread_id)][idB]["start"] @@ -296,7 +296,7 @@ def generate_hypothesis(trace) -> list[Hypothesis]: listed_events.items(), ascii=True, leave=True, desc="Groups Processed" ): same_level_func[(process_id, thread_id)] = {} - for funcA, funcB in tqdm( + for func_A, func_B in tqdm( permutations(function_pool, 2), ascii=True, leave=True, @@ -304,17 +304,17 @@ def generate_hypothesis(trace) -> list[Hypothesis]: total=len(function_pool) ** 2, ): if check_same_level( - funcA, - funcB, + func_A, + func_B, process_id, thread_id, function_id_map, function_times, ): - if funcA not in same_level_func[(process_id, thread_id)]: - 
same_level_func[(process_id, thread_id)][funcA] = [] - same_level_func[(process_id, thread_id)][funcA].append(funcB) - valid_relations[(funcA, funcB)] = True + if func_A not in same_level_func[(process_id, thread_id)]: + same_level_func[(process_id, thread_id)][func_A] = [] + same_level_func[(process_id, thread_id)][func_A].append(func_B) + valid_relations[(func_A, func_B)] = True trace.same_level_func_lead = same_level_func trace.valid_relations_lead = valid_relations print("End same level checking") @@ -354,49 +354,49 @@ def generate_hypothesis(trace) -> list[Hypothesis]: continue if func_B not in same_level_func[(process_id, thread_id)][func_A]: + # no B is invoked in this process and thread. All A invocations are negative examples + for event in events_list: + if ( + event["type"] == "function_call (pre)" + and event["function"] == func_A + ): + example = Example() + example.add_group(EXP_GROUP_NAME, [event]) + hypothesis_with_examples[ + (func_A, func_B) + ].negative_examples.add_example(example) continue time_last_unmatched_A = None - pre_record_A = [] + last_pre_record_A = None + last_example = None + hypothesis = hypothesis_with_examples[(func_A, func_B)] for event in events_list: if event["type"] != "function_call (pre)": continue if func_A == event["function"]: - if time_last_unmatched_A is None: - time_last_unmatched_A = event["time"] - pre_record_A = [event] - continue + if time_last_unmatched_A: + # the last A has not been followed by a B, a negative example: + assert last_example + hypothesis.negative_examples.add_example(last_example) - assert len(pre_record_A) > 0 - example = Example() - example.add_group(EXP_GROUP_NAME, pre_record_A) - hypothesis_with_examples[ - (func_A, func_B) - ].negative_examples.add_example(example) - pre_record_A = [event] - continue + time_last_unmatched_A = event["time"] + last_pre_record_A = event + last_example = Example() + last_example.add_group(EXP_GROUP_NAME, [last_pre_record_A]) if func_B == event["function"]: - if 
time_last_unmatched_A is None: - continue - - assert len(pre_record_A) > 0 - example = Example() - example.add_group(EXP_GROUP_NAME, pre_record_A) - hypothesis_with_examples[ - (func_A, func_B) - ].positive_examples.add_example(example) - - time_last_unmatched_A = None + if time_last_unmatched_A: + assert ( + last_example + ), "Raising an alarm for an A without B, but A's record is None, likely a bug" + hypothesis.positive_examples.add_example(last_example) + time_last_unmatched_A = None if time_last_unmatched_A is not None: - time_last_unmatched_A = None - example = Example() - example.add_group(EXP_GROUP_NAME, pre_record_A) - hypothesis_with_examples[ - (func_A, func_B) - ].negative_examples.add_example(example) + assert last_example + hypothesis.negative_examples.add_example(last_example) print("End adding examples") @@ -462,7 +462,7 @@ def collect_examples(trace, hypothesis): listed_events.items(), ascii=True, leave=True, desc="Groups Processed" ): same_level_func[(process_id, thread_id)] = {} - for funcA, funcB in tqdm( + for func_A, func_B in tqdm( permutations(function_pool, 2), ascii=True, leave=True, @@ -470,17 +470,17 @@ def collect_examples(trace, hypothesis): total=len(function_pool) ** 2, ): if check_same_level( - funcA, - funcB, + func_A, + func_B, process_id, thread_id, function_id_map, function_times, ): - if funcA not in same_level_func[(process_id, thread_id)]: - same_level_func[(process_id, thread_id)][funcA] = [] - same_level_func[(process_id, thread_id)][funcA].append(funcB) - valid_relations[(funcA, funcB)] = True + if func_A not in same_level_func[(process_id, thread_id)]: + same_level_func[(process_id, thread_id)][func_A] = [] + same_level_func[(process_id, thread_id)][func_A].append(func_B) + valid_relations[(func_A, func_B)] = True trace.same_level_func_lead = same_level_func trace.valid_relations_lead = valid_relations print("End same level checking") @@ -507,58 +507,61 @@ def collect_examples(trace, hypothesis): print("Starting 
collecting iteration...") for i in range(invariant_length - 1): - func_A = inv.params[i] - func_B = inv.params[i + 1] + param_A = inv.params[i] + param_B = inv.params[i + 1] - assert isinstance(func_A, APIParam) and isinstance( - func_B, APIParam + assert isinstance(param_A, APIParam) and isinstance( + param_B, APIParam ), "Invariant parameters should be string." + func_A = param_A.api_full_name + func_B = param_B.api_full_name for (process_id, thread_id), events_list in listed_events.items(): - funcA = func_A.api_full_name - funcB = func_B.api_full_name - if funcA not in same_level_func[(process_id, thread_id)]: + if func_A not in same_level_func[(process_id, thread_id)]: + # func_A is not invoked in this process and thread, no need to check continue - if funcB not in same_level_func[(process_id, thread_id)][funcA]: + if func_B not in same_level_func[(process_id, thread_id)][func_A]: + # no B is invoked in this process and thread. All A invocations are negative examples + for event in events_list: + if ( + event["type"] == "function_call (pre)" + and event["function"] == func_A + ): + last_example = Example() + last_example.add_group(EXP_GROUP_NAME, [event]) + hypothesis.negative_examples.add_example(last_example) continue time_last_unmatched_A = None - pre_record_A = [] + last_pre_record_A = None + last_example = None for event in events_list: if event["type"] != "function_call (pre)": continue if func_A == event["function"]: - if time_last_unmatched_A is None: - time_last_unmatched_A = event["time"] - pre_record_A = [event] - continue + if time_last_unmatched_A: + # the last A has not been followed by a B, a negative example: + assert last_example + hypothesis.negative_examples.add_example(last_example) - assert len(pre_record_A) > 0 - example = Example() - example.add_group(EXP_GROUP_NAME, pre_record_A) - hypothesis.negative_examples.add_example(example) - pre_record_A = [event] - continue + time_last_unmatched_A = event["time"] + last_pre_record_A = event + 
last_example = Example() + last_example.add_group(EXP_GROUP_NAME, [last_pre_record_A]) if func_B == event["function"]: - if time_last_unmatched_A is None: - continue - - assert len(pre_record_A) > 0 - example = Example() - example.add_group(EXP_GROUP_NAME, pre_record_A) - hypothesis.positive_examples.add_example(example) - - time_last_unmatched_A = None + if time_last_unmatched_A: + assert ( + last_example + ), "Raising an alarm for an A without B, but A's record is None, likely a bug" + hypothesis.positive_examples.add_example(last_example) + time_last_unmatched_A = None if time_last_unmatched_A is not None: - time_last_unmatched_A = None - example = Example() - example.add_group(EXP_GROUP_NAME, pre_record_A) - hypothesis.negative_examples.add_example(example) + hypothesis.negative_examples.add_example(last_example) @staticmethod def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]: @@ -596,14 +599,14 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]: GroupedPreconditions | None, List[Tuple[APIParam, APIParam]] ] = {} for hypothesis in all_hypotheses: - param0 = hypothesis.invariant.params[0] - param1 = hypothesis.invariant.params[1] + param_A = hypothesis.invariant.params[0] + param_B = hypothesis.invariant.params[1] - assert isinstance(param0, APIParam) and isinstance(param1, APIParam) + assert isinstance(param_A, APIParam) and isinstance(param_B, APIParam) if hypothesis.invariant.precondition not in relation_pool: relation_pool[hypothesis.invariant.precondition] = [] - relation_pool[hypothesis.invariant.precondition].append((param0, param1)) + relation_pool[hypothesis.invariant.precondition].append((param_A, param_B)) merged_relations: Dict[GroupedPreconditions | None, List[List[APIParam]]] = {} @@ -714,7 +717,7 @@ def static_check_all( listed_events.items(), ascii=True, leave=True, desc="Groups Processed" ): same_level_func[(process_id, thread_id)] = {} - for funcA, funcB in tqdm( + for func_A, func_B in tqdm( 
permutations(function_pool, 2), ascii=True, leave=True, @@ -722,17 +725,17 @@ def static_check_all( total=len(function_pool) ** 2, ): if check_same_level( - funcA, - funcB, + func_A, + func_B, process_id, thread_id, function_id_map, function_times, ): - if funcA not in same_level_func[(process_id, thread_id)]: - same_level_func[(process_id, thread_id)][funcA] = [] - same_level_func[(process_id, thread_id)][funcA].append(funcB) - valid_relations[(funcA, funcB)] = True + if func_A not in same_level_func[(process_id, thread_id)]: + same_level_func[(process_id, thread_id)][func_A] = [] + same_level_func[(process_id, thread_id)][func_A].append(func_B) + valid_relations[(func_A, func_B)] = True trace.same_level_func_lead = same_level_func trace.valid_relations_lead = valid_relations print("End same level checking") @@ -764,21 +767,42 @@ def static_check_all( print("Starting checking iteration...") for i in range(invariant_length - 1): - func_A = inv.params[i] - func_B = inv.params[i + 1] + param_A = inv.params[i] + param_B = inv.params[i + 1] - assert isinstance(func_A, APIParam) and isinstance( - func_B, APIParam + assert isinstance(param_A, APIParam) and isinstance( + param_B, APIParam ), "Invariant parameters should be string." 
+ func_A = param_A.api_full_name + func_B = param_B.api_full_name for (process_id, thread_id), events_list in listed_events.items(): - funcA = func_A.api_full_name - funcB = func_B.api_full_name - if funcA not in same_level_func[(process_id, thread_id)]: + if func_A not in same_level_func[(process_id, thread_id)]: + # func_A is not invoked in this process and thread, no need to check continue - if funcB not in same_level_func[(process_id, thread_id)][funcA]: + if func_B not in same_level_func[(process_id, thread_id)][func_A]: + # all A invocations in this process and thread are negative examples + # directly find the first A and return the result + for event in events_list: + if event["type"] != "function_call (pre)": + continue + + if func_A == event["function"]: + if not inv.precondition.verify( + [event], EXP_GROUP_NAME, trace + ): + continue + + inv_triggered = True + return CheckerResult( + trace=[event], + invariant=inv, + check_passed=False, + triggered=True, + ) + # if we have not returned in this branch, lets check the next process and thread continue # check @@ -789,7 +813,7 @@ def static_check_all( if event["type"] != "function_call (pre)": continue - if funcA == event["function"]: + if func_A == event["function"]: if not inv.precondition.verify([event], EXP_GROUP_NAME, trace): continue @@ -809,7 +833,7 @@ def static_check_all( check_passed=False, triggered=True, ) - if funcB == event["function"]: + if func_B == event["function"]: has_B_showup_for_last_A = True return CheckerResult( diff --git a/traincheck/invariant/precondition.py b/traincheck/invariant/precondition.py index 6970fd70..ad507dcb 100644 --- a/traincheck/invariant/precondition.py +++ b/traincheck/invariant/precondition.py @@ -136,10 +136,10 @@ def _find_local_clauses( clauses.append( PreconditionClause( - f"{context_manager_key}.{arg}", + f"{context_manager_key}", type(value), PT.CONSTANT, - [arg], + (arg,), {value}, ) ) @@ -181,10 +181,12 @@ def _merge_clauses( """ # step 1: Grouping the 
clauses by the target - clause_targets_and_exp_ids: dict[str, dict[PreconditionClause, list[int]]] = {} + clause_targets_and_exp_ids: dict[ + tuple[str, tuple[str] | None], dict[PreconditionClause, list[int]] + ] = {} for exp_id, clauses in enumerate(clauses_lists): for clause in clauses: - clause_target = clause.prop_name + clause_target = (clause.prop_name, clause.additional_path) if clause_target not in clause_targets_and_exp_ids: clause_targets_and_exp_ids[clause_target] = {clause: []} elif clause not in clause_targets_and_exp_ids[clause_target]: @@ -248,7 +250,11 @@ def _merge_clauses( and field_dtype is not bool ): consistent_clause = PreconditionClause( - target, field_dtype, PT.CONSISTENT, None, seen_unique_constant_values + target[0], + field_dtype, + PT.CONSISTENT, + target[1], + seen_unique_constant_values, ) merged_clauses_and_exp_ids[consistent_clause] = list( seen_unique_constant_exp_ids @@ -257,7 +263,7 @@ def _merge_clauses( # if the number of values seen is not too large, we should just keep the constant clauses for value in constant_value_to_exp_ids: constant_clause = PreconditionClause( - target, field_dtype, PT.CONSTANT, None, {value} + target[0], field_dtype, PT.CONSTANT, target[1], {value} ) merged_clauses_and_exp_ids[constant_clause] = list( constant_value_to_exp_ids[value]
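
The rewritten example-collection loops in `lead_relation.py` above all follow one pattern: walk each (process, thread) event stream in order, count an `A` invocation that is later followed by a `B` as a positive example, and an `A` never followed by a `B` (including a trailing one) as a negative example. A minimal standalone sketch of that control flow, with illustrative names and event dicts rather than the actual TrainCheck code:

```python
def collect_lead_examples(events, func_A, func_B):
    """Classify func_A invocations: positive if a func_B call follows
    before the next func_A, negative otherwise (sketch only)."""
    positives, negatives = [], []
    last_unmatched_A = None  # most recent A with no B after it yet
    for event in events:
        if event["type"] != "function_call (pre)":
            continue
        if event["function"] == func_A:
            if last_unmatched_A is not None:
                # previous A was never followed by a B -> negative example
                negatives.append(last_unmatched_A)
            last_unmatched_A = event
        elif event["function"] == func_B and last_unmatched_A is not None:
            # the pending A is followed by a B -> positive example
            positives.append(last_unmatched_A)
            last_unmatched_A = None
    if last_unmatched_A is not None:
        # trailing A with no subsequent B -> negative example
        negatives.append(last_unmatched_A)
    return positives, negatives
```

For a stream `A, B, A`, the first `A` becomes a positive example and the trailing `A` a negative one, mirroring the `time_last_unmatched_A` / `last_example` bookkeeping introduced in the diff.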