diff --git a/mldaikon/config/config.py b/mldaikon/config/config.py
index cd6e92cb..5070bae7 100644
--- a/mldaikon/config/config.py
+++ b/mldaikon/config/config.py
@@ -123,3 +123,15 @@
 ENABLE_COND_DUMP = False
 INSTR_DESCRIPTORS = False
+
+
+ALL_STAGE_NAMES = {
+    "init",
+    "training",
+    "evaluation",
+    "inference",
+    "testing",
+    "checkpointing",
+    "preprocessing",
+    "postprocessing",
+}
diff --git a/mldaikon/developer/annotations.py b/mldaikon/developer/annotations.py
index 8f6c9afb..df2cb33e 100644
--- a/mldaikon/developer/annotations.py
+++ b/mldaikon/developer/annotations.py
@@ -1,4 +1,5 @@
 import mldaikon.instrumentor.tracer as tracer
+from mldaikon.config.config import ALL_STAGE_NAMES
 from mldaikon.instrumentor import meta_vars
@@ -11,20 +12,10 @@ def annotate_stage(stage_name: str):
     Note that it is your responsibility to make sure this function is called
     on all threads that potentially can generate invariant candidates.
     """
-    valid_stage_names = [
-        "init",
-        "training",
-        "evaluation",
-        "inference",
-        "testing",
-        "checkpointing",
-        "preprocessing",
-        "postprocessing",
-    ]
-    assert (
-        stage_name in valid_stage_names
-    ), f"Invalid stage name: {stage_name}, valid ones are {valid_stage_names}"
+    assert (
+        stage_name in ALL_STAGE_NAMES
+    ), f"Invalid stage name: {stage_name}, valid ones are {ALL_STAGE_NAMES}"
 
     meta_vars["stage"] = stage_name
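With the stage names centralized in `ALL_STAGE_NAMES`, `annotate_stage` now validates against the shared set. A minimal usage sketch (the second call only illustrates the failure mode):

    from mldaikon.developer.annotations import annotate_stage

    annotate_stage("training")  # accepted: "training" is in ALL_STAGE_NAMES
    annotate_stage("warmup")    # raises AssertionError: not a known stage name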
diff --git a/mldaikon/infer_engine.py b/mldaikon/infer_engine.py
index e1fb6cf4..fe45023f 100644
--- a/mldaikon/infer_engine.py
+++ b/mldaikon/infer_engine.py
@@ -2,11 +2,18 @@
 import datetime
 import json
 import logging
+import os
 import random
 import time
 
 import mldaikon.config.config as config
-from mldaikon.invariant.base_cls import FailedHypothesis, Invariant, Relation
+from mldaikon.invariant.base_cls import (
+    FailedHypothesis,
+    Hypothesis,
+    Invariant,
+    Relation,
+)
+from mldaikon.invariant.precondition import find_precondition
 from mldaikon.invariant.relation_pool import relation_pool
 from mldaikon.trace import MDNONEJSONEncoder, select_trace_implementation
 from mldaikon.utils import register_custom_excepthook
@@ -46,6 +53,63 @@ def infer(self, disabled_relations: list[Relation]):
             )
         return all_invs, all_failed_hypos
 
+    def infer_multi_trace(self, disabled_relations: list[Relation]):
+        hypotheses = self.generate_hypothesis(disabled_relations)
+        self.collect_examples(hypotheses)
+        invariants, failed_hypos = self.infer_precondition(hypotheses)
+        return invariants, failed_hypos
+
+    def generate_hypothesis(
+        self, disabled_relations: list[Relation]
+    ) -> list[list[Hypothesis]]:
+        # one inner list per trace, so collect_examples can match
+        # hypotheses[i] with self.traces[i]
+        hypotheses: list[list[Hypothesis]] = []
+        for trace in self.traces:
+            trace_hypotheses: list[Hypothesis] = []
+            for relation in relation_pool:
+                if disabled_relations is not None and relation in disabled_relations:
+                    logger.info(
+                        f"Skipping relation {relation.__name__} as it is disabled"
+                    )
+                    continue
+                logger.info(f"Generating hypotheses for relation: {relation.__name__}")
+                new_hypotheses = relation.generate_hypothesis(trace)
+                logger.info(
+                    f"Found {len(new_hypotheses)} hypotheses for relation: {relation.__name__}"
+                )
+                trace_hypotheses.extend(new_hypotheses)
+            hypotheses.append(trace_hypotheses)
+        return hypotheses
+
+    def collect_examples(self, hypotheses: list[list[Hypothesis]]):
+        for i, trace in enumerate(self.traces):
+            for j, trace_hypotheses in enumerate(hypotheses):
+                if j == i:
+                    # examples were already collected on the trace that
+                    # generated these hypotheses
+                    continue
+                for hypothesis in trace_hypotheses:
+                    hypothesis.invariant.relation.collect_examples(trace, hypothesis)
+
+    def infer_precondition(self, hypotheses: list[list[Hypothesis]]):
+        all_hypotheses: list[Hypothesis] = []
+        for trace_hypotheses in hypotheses:
+            for hypothesis in trace_hypotheses:
+                all_hypotheses.append(hypothesis)
+
+        invariants = []
+        failed_hypos = []
+        for hypothesis in all_hypotheses:
+            hypothesis.invariant.num_positive_examples = len(
+                hypothesis.positive_examples
+            )
+            hypothesis.invariant.num_negative_examples = len(
+                hypothesis.negative_examples
+            )
+            precondition = find_precondition(hypothesis, self.traces)
+            if precondition is None:
+                failed_hypos.append(FailedHypothesis(hypothesis))
+            else:
+                hypothesis.invariant.precondition = precondition
+                invariants.append(hypothesis.invariant)
+
+        return invariants, failed_hypos
+
 
 def save_invs(invs: list[Invariant], output_file: str):
     with open(output_file, "w") as f:
@@ -69,9 +133,15 @@ def save_failed_hypos(failed_hypos: list[FailedHypothesis], output_file: str):
         "-t",
         "--traces",
         nargs="+",
-        required=True,
+        required=False,
        help="Traces files to infer invariants on",
     )
+    parser.add_argument(
+        "-f",
+        "--trace-folders",
+        nargs="+",
+        help='Folders containing traces files to infer invariants on. Trace files should start with "trace_" or "proxy_log.json"',
+    )
     parser.add_argument(
         "-d",
         "--debug",
@@ -110,6 +180,14 @@ def save_failed_hypos(failed_hypos: list[FailedHypothesis], output_file: str):
     )
     args = parser.parse_args()
 
+    # check if either traces or trace folders are provided
+    if args.traces is None and args.trace_folders is None:
+        # print help message if neither traces nor trace folders are provided
+        parser.print_help()
+        parser.error(
+            "Please provide either traces or trace folders to infer invariants"
+        )
+
     Trace, read_trace_file = select_trace_implementation(args.backend)
 
     if args.debug:
@@ -138,14 +216,28 @@ def save_failed_hypos(failed_hypos: list[FailedHypothesis], output_file: str):
         config.PRECOND_SAMPLING_THRESHOLD = args.precond_sampling_threshold
 
     time_start = time.time()
-    logger.info("Reading traces from %s", "\n".join(args.traces))
-    traces = [read_trace_file(args.traces)]
+
+    traces = []
+    if args.traces is not None:
+        logger.info("Reading traces from %s", "\n".join(args.traces))
+        traces.append(read_trace_file(args.traces))
+    if args.trace_folders is not None:
+        for trace_folder in args.trace_folders:
+            # file discovery
+            trace_files = [
+                f"{trace_folder}/{file}"
+                for file in os.listdir(trace_folder)
+                if file.startswith("trace_") or file.startswith("proxy_log.json")
+            ]
+            logger.info("Reading traces from %s", "\n".join(trace_files))
+            traces.append(read_trace_file(trace_files))
+
     time_end = time.time()
     logger.info(f"Traces read successfully in {time_end - time_start} seconds.")
 
     time_start = time.time()
     engine = InferEngine(traces)
-    invs, failed_hypos = engine.infer(disabled_relations=disabled_relations)
+    invs, failed_hypos = engine.infer_multi_trace(disabled_relations=disabled_relations)
     time_end = time.time()
     logger.info(f"Inference completed in {time_end - time_start} seconds.")
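The `--trace-folders` handler above discovers trace files by filename prefix. A standalone sketch of the same rule (the helper name is hypothetical):

    import os

    def discover_trace_files(trace_folder: str) -> list[str]:
        # keep files whose basename starts with "trace_" or "proxy_log.json",
        # mirroring the filter used in the argument handling above
        return [
            os.path.join(trace_folder, f)
            for f in os.listdir(trace_folder)
            if f.startswith("trace_") or f.startswith("proxy_log.json")
        ]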
diff --git a/mldaikon/instrumentor/tracer.py b/mldaikon/instrumentor/tracer.py
index caccacf7..1d8a3177 100644
--- a/mldaikon/instrumentor/tracer.py
+++ b/mldaikon/instrumentor/tracer.py
@@ -327,7 +327,7 @@ def find_proxy_in_args(args):
                         "func_call_id": func_call_id,
                         "thread_id": thread_id,
                         "process_id": process_id,
-                        "meta_vars": get_meta_vars(),
+                        "meta_vars": pre_meta_vars,
                         "type": TraceLineType.FUNC_CALL_POST_EXCEPTION,
                         "function": func_name,
                         # "args": [f"{arg}" for arg in args],
@@ -346,7 +346,7 @@ def find_proxy_in_args(args):
                 pre_record.copy()
             )  # copy the pre_record (though we don't actually need to copy anything)
             post_record["type"] = TraceLineType.FUNC_CALL_POST
-            post_record["meta_vars"] = get_meta_vars()
+            post_record["meta_vars"] = pre_meta_vars
 
             result_to_dump = result
@@ -1011,6 +1011,7 @@ def __init__(self, var, dump_tensor_hash: bool = True):
         self.param_versions = {}  # type: ignore
 
         timestamp = datetime.datetime.now().timestamp()
+        curr_meta_vars = get_meta_vars()
         for param in self._get_state_copy():
             attributes = param["attributes"]
             if dump_tensor_hash:
@@ -1026,7 +1027,7 @@ def __init__(self, var, dump_tensor_hash: bool = True):
                     "var_type": param["type"],
                     "process_id": os.getpid(),
                     "thread_id": threading.current_thread().ident,
-                    "meta_vars": get_meta_vars(),
+                    "meta_vars": curr_meta_vars,
                     "type": TraceLineType.STATE_CHANGE,
                     "attributes": attributes,
                     "time": timestamp,
@@ -1054,6 +1055,7 @@ def dump_sample(self):
 
         timestamp = datetime.datetime.now().timestamp()
 
+        curr_meta_vars = get_meta_vars()
         for param in self._get_state_copy():
             dump_trace_VAR(
                 {
@@ -1061,7 +1063,7 @@ def dump_sample(self):
                     "var_type": param["type"],  # FIXME: hardcoding the type for now
                    "process_id": os.getpid(),
                     "thread_id": threading.current_thread().ident,
-                    "meta_vars": get_meta_vars(),
+                    "meta_vars": curr_meta_vars,
                     "type": TraceLineType.STATE_CHANGE,
                     "attributes": param["attributes"],
                     "time": timestamp,
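These tracer hunks reuse one `meta_vars` snapshot per event instead of re-reading thread-local state at every dump, so pre/post records of one call (and all STATE_CHANGE records of one dump pass) carry identical metadata. A hedged sketch of the failure mode avoided (values illustrative):

    pre_meta_vars = {"stage": "training", "step": 7}  # snapshot at call entry
    # ... wrapped function runs; the stage may be re-annotated meanwhile ...
    post_meta_vars = pre_meta_vars                     # reuse, do not re-read
    assert post_meta_vars["stage"] == "training"       # records stay consistent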
diff --git a/mldaikon/invariant/DistinctArgumentRelation.py b/mldaikon/invariant/DistinctArgumentRelation.py
index 26ee667e..79a98ef7 100644
--- a/mldaikon/invariant/DistinctArgumentRelation.py
+++ b/mldaikon/invariant/DistinctArgumentRelation.py
@@ -1,36 +1,34 @@
 import logging
 from itertools import combinations
-from typing import Any, Dict, List, Set, Tuple, Iterable
+from typing import Any, Dict, Iterable, List, Set, Tuple
 
 from tqdm import tqdm
 
-from mldaikon.invariant.base_cls import (
-    Param,
+from mldaikon.invariant.base_cls import (  # GroupedPreconditions,
     APIParam,
     CheckerResult,
     Example,
     ExampleList,
     FailedHypothesis,
-    # GroupedPreconditions,
     Hypothesis,
     Invariant,
     Relation,
 )
-
 from mldaikon.invariant.precondition import find_precondition
 from mldaikon.trace.trace import Trace
 
 EXP_GROUP_NAME = "distinct_arg"
-MAX_FUNC_NUM_CONSECUTIVE_CALL = 6 
-IOU_THRESHHOLD = 0.1 # pre-defined threshhold for IOU
+MAX_FUNC_NUM_CONSECUTIVE_CALL = 6
+IOU_THRESHHOLD = 0.1  # pre-defined threshold for IOU
+
 
 def calculate_IOU(list1, list2):
-  set1 = set(list1)
-  set2 = set(list2)
-  intersection = set1.intersection(set2)
-  union = set1.union(set2)
-  iou = len(intersection) / len(union) if len(union) != 0 else 0
-  return iou
+    set1 = set(list1)
+    set2 = set(list2)
+    intersection = set1.intersection(set2)
+    union = set1.union(set2)
+    iou = len(intersection) / len(union) if len(union) != 0 else 0
+    return iou
 
 
 def get_func_names_to_deal_with(trace: Trace) -> List[str]:
@@ -54,16 +52,20 @@ def get_func_names_to_deal_with(trace: Trace) -> List[str]:
     return list(function_pool)
 
+
 def get_event_data_per_function_per_step(trace: Trace, function_pool: Set[Any]):
-    listed_arguments: Dict[str, Dict[int, Dict[Tuple[str, str], List[dict[str, Any]]]]] = (
-        {}
-    )
+    listed_arguments: Dict[
+        str, Dict[int, Dict[Tuple[str, str], List[dict[str, Any]]]]
+    ] = {}
     for func_name in function_pool.copy():
         func_call_ids = trace.get_func_call_ids(func_name)
         keep_this_func = False
         for func_call_id in func_call_ids:
             event = trace.query_func_call_event(func_call_id)
-            if (event.pre_record["meta_vars.step"] is None or "args" not in event.pre_record):
+            if (
+                event.pre_record["meta_vars.step"] is None
+                or "args" not in event.pre_record
+            ):
                 continue
             keep_this_func = True
             process_id = event.pre_record["process_id"]
@@ -73,16 +75,18 @@ def get_event_data_per_function_per_step(trace: Trace, function_pool: Set[Any]):
                 listed_arguments[func_name] = {}
                 listed_arguments[func_name][step] = {}
                 listed_arguments[func_name][step][(process_id, thread_id)] = []
-            
+
             if step not in listed_arguments[func_name]:
                 listed_arguments[func_name][step] = {}
                 listed_arguments[func_name][step][(process_id, thread_id)] = []
-            
+
             if (process_id, thread_id) not in listed_arguments[func_name][step]:
                 listed_arguments[func_name][step][(process_id, thread_id)] = []
-            
-            listed_arguments[func_name][step][(process_id, thread_id)].append(event.pre_record)
-            
+
+            listed_arguments[func_name][step][(process_id, thread_id)].append(
+                event.pre_record
+            )
+
         if not keep_this_func:
             function_pool.remove(func_name)
 
@@ -90,7 +94,7 @@ def get_event_data_per_function_per_step(trace: Trace, function_pool: Set[Any]):
 
 def get_event_list(trace: Trace, function_pool: Iterable[str]):
-    listed_events: List[dict[str, Any]]= []
+    listed_events: List[dict[str, Any]] = []
     # for all func_ids, get their corresponding events
     for func_name in function_pool:
         func_call_ids = trace.get_func_call_ids(func_name)
@@ -109,11 +113,16 @@ def get_event_list(trace: Trace, function_pool: Iterable[str]):
     return listed_events
 
-def compare_argument(value1, value2, IOU_criteria = True):
+
+def compare_argument(value1, value2, IOU_criteria=True):
     if type(value1) != type(value2):
         return False
     if isinstance(value1, list):
-        if IOU_criteria and all(isinstance(item, int) for item in value1) and all(isinstance(item, int) for item in value2):
+        if (
+            IOU_criteria
+            and all(isinstance(item, int) for item in value1)
+            and all(isinstance(item, int) for item in value2)
+        ):
             return calculate_IOU(value1, value2) >= IOU_THRESHHOLD
         if len(value1) != len(value2):
             return False
@@ -134,6 +143,7 @@ def compare_argument(value1, value2, IOU_criteria=True):
         return abs(value1 - value2) < 1e-8
     return value1 == value2
 
+
 def is_arguments_list_same(args1: list, args2: list):
     if len(args1) != len(args2):
         return False
@@ -144,6 +154,7 @@ def is_arguments_list_same(args1: list, args2: list):
         return False
     return True
 
+
 # class APIArgsParam(Param):
 #     def __init__(
 #         self, api_full_name: str, arg_name: str
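For all-int lists, `compare_argument` falls back to the IOU criterion: two argument lists count as the same once |A∩B| / |A∪B| reaches IOU_THRESHHOLD (0.1). A worked check:

    assert calculate_IOU([1, 2, 3, 4], [4, 5, 6]) == 1 / 6  # 0.167 >= 0.1 -> same
    assert calculate_IOU([1, 2], [3, 4]) == 0               # disjoint -> distinct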
@@ -181,7 +192,9 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]:
         # 1. Pre-process all the events
         print("Start preprocessing....")
-        listed_arguments: Dict[str, Dict[int, Dict[Tuple[str, str], List[dict[str, Any]]]]] = {}
+        listed_arguments: Dict[
+            str, Dict[int, Dict[Tuple[str, str], List[dict[str, Any]]]]
+        ] = {}
         function_pool: Set[Any] = set()
 
         function_pool = set(get_func_names_to_deal_with(trace))
@@ -205,9 +218,7 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]:
             func_name: Hypothesis(
                 invariant=Invariant(
                     relation=DistinctArgumentRelation,
-                    params=[
-                        APIParam(func_name)
-                    ],
+                    params=[APIParam(func_name)],
                     precondition=None,
                     text_description=f"{func_name} has distinct input arguments on difference PT for each step",
                 ),
@@ -226,29 +237,39 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]:
                 for PT_pair1, PT_pair2 in combinations(records.keys(), 2):
                     for event1 in records[PT_pair1]:
                         for event2 in records[PT_pair2]:
-                            if(not is_arguments_list_same(event1["args"], event2["args"])):
+                            if not is_arguments_list_same(
+                                event1["args"], event2["args"]
+                            ):
                                 flag = True
                                 pos = Example()
                                 pos.add_group(EXP_GROUP_NAME, [event1, event2])
-                                hypothesis_with_examples[func_name].positive_examples.add_example(pos)
+                                hypothesis_with_examples[
+                                    func_name
+                                ].positive_examples.add_example(pos)
                             else:
                                 neg = Example()
                                 neg.add_group(EXP_GROUP_NAME, [event1, event2])
-                                hypothesis_with_examples[func_name].negative_examples.add_example(neg)
-                
+                                hypothesis_with_examples[
+                                    func_name
+                                ].negative_examples.add_example(neg)
+
                 for PT_pair in records.keys():
                     for event1, event2 in combinations(records[PT_pair], 2):
-                        if(not is_arguments_list_same(event1["args"], event2["args"])):
+                        if not is_arguments_list_same(event1["args"], event2["args"]):
                             flag = True
                             pos = Example()
                             pos.add_group(EXP_GROUP_NAME, [event1, event2])
-                            hypothesis_with_examples[func_name].positive_examples.add_example(pos)
+                            hypothesis_with_examples[
+                                func_name
+                            ].positive_examples.add_example(pos)
                         else:
                             neg = Example()
                             neg.add_group(EXP_GROUP_NAME, [event1, event2])
-                            hypothesis_with_examples[func_name].negative_examples.add_example(neg)
-            
-            if(not flag):
+                            hypothesis_with_examples[
+                                func_name
+                            ].negative_examples.add_example(neg)
+
+            if not flag:
                 hypothesis_with_examples.pop(func_name)
 
         print("End adding examples")
@@ -261,13 +282,11 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]:
             logger.debug(
                 f"Finding Precondition for {hypo}: {hypothesis_with_examples[hypo].invariant.text_description}"
             )
-            preconditions = find_precondition(hypothesis_with_examples[hypo], trace)
+            preconditions = find_precondition(hypothesis_with_examples[hypo], [trace])
             logger.debug(f"Preconditions for {hypo}:\n{str(preconditions)}")
 
             if preconditions is not None:
-                hypothesis_with_examples[hypo].invariant.precondition = (
-                    preconditions
-                )
+                hypothesis_with_examples[hypo].invariant.precondition = preconditions
             else:
                 logger.debug(f"Precondition not found for {hypo}")
                 failed_hypothesis.append(
@@ -281,13 +300,10 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]:
         print("End precondition inference")
 
         return (
-            list(
-                [hypo.invariant for hypo in hypothesis_with_examples.values()]
-            ),
+            list([hypo.invariant for hypo in hypothesis_with_examples.values()]),
             failed_hypothesis,
         )
 
-
     @staticmethod
     def evaluate(value_group: list) -> bool:
         """Given a group of values, should return a boolean value
         indicating whether the relation holds or not.

         args:
             value_group: list
                 A list of values to evaluate the relation on. The length of the list
                 should be equal to the number of variables in the relation.
         """
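The two example-collection loops above pair events across distinct process/thread (PT) pairs and within a single PT, both via `itertools.combinations`, so each unordered pair at a step is compared exactly once:

    from itertools import combinations

    pts = [("pid0", "tid0"), ("pid1", "tid0"), ("pid1", "tid1")]
    assert len(list(combinations(pts, 2))) == 3  # cross-PT pairs per step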
@@ -317,13 +333,13 @@ def static_check_all(
         # 1. Pre-process all the events
         print("Start preprocessing....")
-        listed_arguments: Dict[str, Dict[int, Dict[Tuple[str, str], List[dict[str, Any]]]]] = {}
+        listed_arguments: Dict[
+            str, Dict[int, Dict[Tuple[str, str], List[dict[str, Any]]]]
+        ] = {}
         function_pool: Set[Any] = set()
-        func= inv.params[0]
+        func = inv.params[0]
 
-        assert isinstance(
-            func, APIParam
-        ), "Invariant parameters should be APIParam."
+        assert isinstance(func, APIParam), "Invariant parameters should be APIParam."
         func_name = func.api_full_name
         function_pool.add(func_name)
 
@@ -352,20 +368,10 @@ def static_check_all(
         )
 
         for step, records in listed_arguments[func_name].items():
-                for PT_pair1, PT_pair2 in combinations(records.keys(), 2):
-                    for event1 in records[PT_pair1]:
-                        for event2 in records[PT_pair2]:
-                            if(is_arguments_list_same(event1["args"], event2["args"])):
-                                return CheckerResult(
-                                    trace=[event1, event2],
-                                    invariant=inv,
-                                    check_passed=True,
-                                    triggered=True,
-                                )
-
-                for PT_pair in records.keys():
-                    for event1, event2 in combinations(records[PT_pair], 2):
-                        if(is_arguments_list_same(event1["args"], event2["args"])):
+            for PT_pair1, PT_pair2 in combinations(records.keys(), 2):
+                for event1 in records[PT_pair1]:
+                    for event2 in records[PT_pair2]:
+                        if is_arguments_list_same(event1["args"], event2["args"]):
                             return CheckerResult(
                                 trace=[event1, event2],
                                 invariant=inv,
                                 check_passed=True,
                                 triggered=True,
                             )
 
+            for PT_pair in records.keys():
+                for event1, event2 in combinations(records[PT_pair], 2):
+                    if is_arguments_list_same(event1["args"], event2["args"]):
+                        return CheckerResult(
+                            trace=[event1, event2],
+                            invariant=inv,
+                            check_passed=True,
+                            triggered=True,
+                        )
+
         return CheckerResult(
             trace=None,
             invariant=inv,
             check_passed=True,
             triggered=True,
         )
+
+    @staticmethod
+    def get_precondition_infer_keys_to_skip(hypothesis: Hypothesis) -> list[str]:
+        return []
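Both this file and the refactored copy below build `listed_arguments` with the same nested shape, keyed by function, then step, then (process_id, thread_id). An illustrative instance (the values are made up):

    listed_arguments = {
        "torch.nn.init.normal_": {
            0: {
                ("1234", "5678"): [{"args": [[0, 1, 2, 3]]}],
                ("1235", "5678"): [{"args": [[4, 5, 6, 7]]}],
            },
        },
    }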
diff --git a/mldaikon/invariant/DistinctArgumentRelation_refactor.py b/mldaikon/invariant/DistinctArgumentRelation_refactor.py
new file mode 100644
index 00000000..c6407c80
--- /dev/null
+++ b/mldaikon/invariant/DistinctArgumentRelation_refactor.py
@@ -0,0 +1,456 @@
+import copy
+import logging
+from itertools import combinations
+from typing import Any, Dict, Iterable, List, Set, Tuple
+
+from tqdm import tqdm
+
+from mldaikon.invariant.base_cls import (  # GroupedPreconditions,
+    APIParam,
+    CheckerResult,
+    Example,
+    ExampleList,
+    FailedHypothesis,
+    Hypothesis,
+    Invariant,
+    Relation,
+)
+from mldaikon.invariant.precondition import find_precondition
+from mldaikon.trace.trace import Trace
+
+EXP_GROUP_NAME = "distinct_arg"
+MAX_FUNC_NUM_CONSECUTIVE_CALL = 6
+IOU_THRESHHOLD = 0.1  # pre-defined threshold for IOU
+
+
+def calculate_IOU(list1, list2):
+    set1 = set(list1)
+    set2 = set(list2)
+    intersection = set1.intersection(set2)
+    union = set1.union(set2)
+    iou = len(intersection) / len(union) if len(union) != 0 else 0
+    return iou
+
+
+def get_func_names_to_deal_with(trace: Trace) -> List[str]:
+    """Get all functions in the trace."""
+    function_pool: Set[str] = set()
+
+    # get all functions in the trace
+    all_func_names = trace.get_func_names()
+
+    # filtering 1: remove private functions
+    for func_name in all_func_names:
+        if "._" in func_name:
+            continue
+        function_pool.add(func_name)
+
+    # filtering 2: remove functions whose max number of consecutive calls
+    # exceeds MAX_FUNC_NUM_CONSECUTIVE_CALL
+    for func_name in function_pool.copy():
+        max_num_consecutive_call = trace.get_max_num_consecutive_call_func(func_name)
+        if max_num_consecutive_call > MAX_FUNC_NUM_CONSECUTIVE_CALL:
+            function_pool.remove(func_name)
+
+    return list(function_pool)
+
+
+def get_event_data_per_function_per_step(trace: Trace, function_pool: Set[Any]):
+    listed_arguments: Dict[
+        str, Dict[int, Dict[Tuple[str, str], List[dict[str, Any]]]]
+    ] = {}
+    for func_name in function_pool.copy():
+        func_call_ids = trace.get_func_call_ids(func_name)
+        keep_this_func = False
+        for func_call_id in func_call_ids:
+            event = trace.query_func_call_event(func_call_id)
+            if (
+                event.pre_record["meta_vars.step"] is None
+                or "args" not in event.pre_record
+            ):
+                continue
+            keep_this_func = True
+            process_id = event.pre_record["process_id"]
+            thread_id = event.pre_record["thread_id"]
+            step = event.pre_record["meta_vars.step"]
+            if func_name not in listed_arguments:
+                listed_arguments[func_name] = {}
+                listed_arguments[func_name][step] = {}
+                listed_arguments[func_name][step][(process_id, thread_id)] = []
+
+            if step not in listed_arguments[func_name]:
+                listed_arguments[func_name][step] = {}
+                listed_arguments[func_name][step][(process_id, thread_id)] = []
+
+            if (process_id, thread_id) not in listed_arguments[func_name][step]:
+                listed_arguments[func_name][step][(process_id, thread_id)] = []
+
+            listed_arguments[func_name][step][(process_id, thread_id)].append(
+                event.pre_record
+            )
+
+        if not keep_this_func:
+            function_pool.remove(func_name)
+
+    return function_pool, listed_arguments
+
+
+def get_event_list(trace: Trace, function_pool: Iterable[str]):
+    listed_events: List[dict[str, Any]] = []
+    # for all func_ids, get their corresponding events
+    for func_name in function_pool:
+        func_call_ids = trace.get_func_call_ids(func_name)
+        for func_call_id in func_call_ids:
+            event = trace.query_func_call_event(func_call_id)
+            listed_events.extend(
+                # [event.pre_record, event.post_record]
+                [event.pre_record]
+            )
+
+    # sort the listed_events
+    # for (process_id, thread_id), events_list in listed_events.items():
+    #     listed_events[(process_id, thread_id)] = sorted(
+    #         events_list, key=lambda x: x["time"]
+    #     )
+
+    return listed_events
+
+
+def compare_argument(value1, value2, IOU_criteria=True):
+    if type(value1) != type(value2):
+        return False
+    if isinstance(value1, list):
+        if (
+            IOU_criteria
+            and all(isinstance(item, int) for item in value1)
+            and all(isinstance(item, int) for item in value2)
+        ):
+            return calculate_IOU(value1, value2) >= IOU_THRESHHOLD
+        if len(value1) != len(value2):
+            return False
+        for idx, val in enumerate(value1):
+            if not compare_argument(val, value2[idx]):
+                return False
+        return True
+    if isinstance(value1, dict):
+        if len(value1) != len(value2):
+            return False
+        for key in value1:
+            if key not in value2:
+                return False
+            if not compare_argument(value1[key], value2[key]):
+                return False
+        return True
+    if isinstance(value1, float):
+        return abs(value1 - value2) < 1e-8
+    return value1 == value2
+
+
+def is_arguments_list_same(args1: list, args2: list):
+    if len(args1) != len(args2):
+        return False
+    for index in range(len(args1)):
+        arg1 = args1[index]
+        arg2 = args2[index]
+        if not compare_argument(arg1, arg2):
+            return False
+    return True
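Quick sanity checks of the recursive comparison above (floats use a 1e-8 tolerance, lists compare element-wise, dicts key-wise, and all-int lists take the IOU shortcut):

    assert compare_argument(1.0, 1.0 + 1e-10)             # within float tolerance
    assert not compare_argument({"lr": 0.1}, {"lr": 0.2})
    assert compare_argument([1, 2, 3], [3, 2, 1])         # IOU = 1.0 >= 0.1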
+
+# class APIArgsParam(Param):
+#     def __init__(
+#         self, api_full_name: str, arg_name: str
+#     ):
+#         self.api_full_name = api_full_name
+#         self.arg_name = arg_name
+
+#     def __eq__(self, other):
+#         if isinstance(other, APIArgsParam):
+#             return self.api_full_name == other.api_full_name and self.arg_name == other.arg_name
+#         return False
+
+#     def __hash__(self):
+#         return hash(self.api_full_name + self.arg_name)
+
+#     def __str__(self):
+#         return f"{self.api_full_name} {self.arg_name}"
+
+#     def __repr__(self):
+#         return self.__str__()
+
+
+class DistinctArgumentRelation(Relation):
+    """DistinctArgumentRelation checks that a function is called with distinct
+    input arguments across processes/threads (PT), and across calls within a
+    single PT, at each step.
+    """
+
+    @staticmethod
+    def generate_hypothesis(trace) -> list[Hypothesis]:
+        """Generate hypothesis for the DistinctArgumentRelation on trace."""
+        # 1. Pre-process all the events
+        print("Start preprocessing....")
+        listed_arguments: Dict[
+            str, Dict[int, Dict[Tuple[str, str], List[dict[str, Any]]]]
+        ] = {}
+        function_pool: Set[Any] = set()
+
+        function_pool = set(get_func_names_to_deal_with(trace))
+
+        function_pool, listed_arguments = get_event_data_per_function_per_step(
+            trace, function_pool
+        )
+        print("End preprocessing")
+
+        # If there is no filtered function, return []
+        if not function_pool:
+            return []
+
+        # This is just for test.
+        # function_pool = set()
+        # function_pool.add("torch.nn.init.normal_")
+
+        # 2. Generating hypothesis
+        print("Start generating hypo...")
+        hypothesis_with_examples = {
+            func_name: Hypothesis(
+                invariant=Invariant(
+                    relation=DistinctArgumentRelation,
+                    params=[APIParam(func_name)],
+                    precondition=None,
+                    text_description=f"{func_name} has distinct input arguments on different PT for each step",
+                ),
+                positive_examples=ExampleList({EXP_GROUP_NAME}),
+                negative_examples=ExampleList({EXP_GROUP_NAME}),
+            )
+            for func_name in function_pool
+        }
+        print("End generating hypo")
+
+        # 3. Add positive and negative examples
+        print("Start adding examples...")
+        for func_name in tqdm(function_pool):
+            flag = False
+            for step, records in listed_arguments[func_name].items():
+                for PT_pair1, PT_pair2 in combinations(records.keys(), 2):
+                    for event1 in records[PT_pair1]:
+                        for event2 in records[PT_pair2]:
+                            if not is_arguments_list_same(
+                                event1["args"], event2["args"]
+                            ):
+                                flag = True
+                                pos = Example()
+                                pos.add_group(EXP_GROUP_NAME, [event1, event2])
+                                hypothesis_with_examples[
+                                    func_name
+                                ].positive_examples.add_example(pos)
+                            else:
+                                neg = Example()
+                                neg.add_group(EXP_GROUP_NAME, [event1, event2])
+                                hypothesis_with_examples[
+                                    func_name
+                                ].negative_examples.add_example(neg)
+
+                for PT_pair in records.keys():
+                    for event1, event2 in combinations(records[PT_pair], 2):
+                        if not is_arguments_list_same(event1["args"], event2["args"]):
+                            flag = True
+                            pos = Example()
+                            pos.add_group(EXP_GROUP_NAME, [event1, event2])
+                            hypothesis_with_examples[
+                                func_name
+                            ].positive_examples.add_example(pos)
+                        else:
+                            neg = Example()
+                            neg.add_group(EXP_GROUP_NAME, [event1, event2])
+                            hypothesis_with_examples[
+                                func_name
+                            ].negative_examples.add_example(neg)
+
+            if not flag:
+                hypothesis_with_examples.pop(func_name)
+
+        print("End adding examples")
+
+        return list(hypothesis_with_examples.values())
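`generate_hypothesis` and `collect_examples` together form the two-phase contract consumed by `InferEngine.infer_multi_trace`: hypotheses are generated (with examples) on one trace, then enriched with examples from every other trace. A hedged driver sketch (names follow this diff; the free function itself is illustrative):

    def infer_multi_trace(traces, relation):
        per_trace = [relation.generate_hypothesis(t) for t in traces]
        for i, trace in enumerate(traces):
            for j, hypos in enumerate(per_trace):
                if i == j:
                    continue  # examples already collected on the generating trace
                for h in hypos:
                    relation.collect_examples(trace, h)
        return [h for hypos in per_trace for h in hypos]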
+
+    @staticmethod
+    def collect_examples(trace, hypothesis):
+        """Generate examples for a hypothesis on trace."""
+        inv = hypothesis.invariant
+
+        # 1. Pre-process all the events
+        print("Start preprocessing....")
+        listed_arguments: Dict[
+            str, Dict[int, Dict[Tuple[str, str], List[dict[str, Any]]]]
+        ] = {}
+        function_pool: Set[Any] = set()
+        func = inv.params[0]
+
+        assert isinstance(func, APIParam), "Invariant parameters should be APIParam."
+
+        func_name = func.api_full_name
+        function_pool.add(func_name)
+
+        function_pool, listed_arguments = get_event_data_per_function_per_step(
+            trace, function_pool
+        )
+
+        print("End preprocessing")
+
+        if not function_pool:
+            return
+
+        for step, records in listed_arguments[func_name].items():
+            for PT_pair1, PT_pair2 in combinations(records.keys(), 2):
+                for event1 in records[PT_pair1]:
+                    for event2 in records[PT_pair2]:
+                        if not is_arguments_list_same(
+                            event1["args"], event2["args"]
+                        ):
+                            pos = Example()
+                            pos.add_group(EXP_GROUP_NAME, [event1, event2])
+                            hypothesis.positive_examples.add_example(pos)
+                        else:
+                            neg = Example()
+                            neg.add_group(EXP_GROUP_NAME, [event1, event2])
+                            hypothesis.negative_examples.add_example(neg)
+
+            for PT_pair in records.keys():
+                for event1, event2 in combinations(records[PT_pair], 2):
+                    if not is_arguments_list_same(event1["args"], event2["args"]):
+                        pos = Example()
+                        pos.add_group(EXP_GROUP_NAME, [event1, event2])
+                        hypothesis.positive_examples.add_example(pos)
+                    else:
+                        neg = Example()
+                        neg.add_group(EXP_GROUP_NAME, [event1, event2])
+                        hypothesis.negative_examples.add_example(neg)
+
+    @staticmethod
+    def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]:
+        """Infer Invariants for the DistinctArgumentRelation."""
+        all_hypotheses = DistinctArgumentRelation.generate_hypothesis(trace)
+
+        # for hypothesis in all_hypotheses:
+        #     DistinctArgumentRelation.collect_examples(trace, hypothesis)
+
+        # 4. Precondition inference
+        print("Start precondition inference...")
+        failed_hypothesis = []
+        for hypothesis in all_hypotheses.copy():
+            preconditions = find_precondition(hypothesis, [trace])
+            if preconditions is not None:
+                hypothesis.invariant.precondition = preconditions
+            else:
+                failed_hypothesis.append(FailedHypothesis(hypothesis))
+                all_hypotheses.remove(hypothesis)
+        print("End precondition inference")
+
+        return (
+            list([hypo.invariant for hypo in all_hypotheses]),
+            failed_hypothesis,
+        )
+
+    @staticmethod
+    def evaluate(value_group: list) -> bool:
+        """Given a group of values, should return a boolean value
+        indicating whether the relation holds or not.
+
+        args:
+            value_group: list
+                A list of values to evaluate the relation on. The length of the list
+                should be equal to the number of variables in the relation.
+        """
+        return True
+
+    @staticmethod
+    def static_check_all(
+        trace: Trace, inv: Invariant, check_relation_first: bool
+    ) -> CheckerResult:
+        """Given a trace and an invariant, should return a boolean value
+        indicating whether the invariant holds on the trace.
+
+        args:
+            trace: Trace
+                A trace to check the invariant on.
+            inv: Invariant
+                The invariant to check on the trace.
+        """
+        assert inv.precondition is not None, "Invariant should have a precondition."
+
+        # 1. Pre-process all the events
+        print("Start preprocessing....")
+        listed_arguments: Dict[
+            str, Dict[int, Dict[Tuple[str, str], List[dict[str, Any]]]]
+        ] = {}
+        function_pool: Set[Any] = set()
+        func = inv.params[0]
+
+        assert isinstance(func, APIParam), "Invariant parameters should be APIParam."
+ + func_name = func.api_full_name + function_pool.add(func_name) + + function_pool, listed_arguments = get_event_data_per_function_per_step( + trace, function_pool + ) + + if not function_pool: + return CheckerResult( + trace=None, + invariant=inv, + check_passed=True, + triggered=False, + ) + + events_list = get_event_list(trace, function_pool) + print("End preprocessing") + + if not inv.precondition.verify([events_list], EXP_GROUP_NAME): + return CheckerResult( + trace=None, + invariant=inv, + check_passed=True, + triggered=False, + ) + + for step, records in listed_arguments[func_name].items(): + for PT_pair1, PT_pair2 in combinations(records.keys(), 2): + for event1 in records[PT_pair1]: + for event2 in records[PT_pair2]: + if is_arguments_list_same(event1["args"], event2["args"]): + return CheckerResult( + trace=[event1, event2], + invariant=inv, + check_passed=True, + triggered=True, + ) + + for PT_pair in records.keys(): + for event1, event2 in combinations(records[PT_pair], 2): + if is_arguments_list_same(event1["args"], event2["args"]): + return CheckerResult( + trace=[event1, event2], + invariant=inv, + check_passed=True, + triggered=True, + ) + + return CheckerResult( + trace=None, + invariant=inv, + check_passed=True, + triggered=True, + ) + + @staticmethod + def get_precondition_infer_keys_to_skip(hypothesis: Hypothesis) -> list[str]: + return [] diff --git a/mldaikon/invariant/base_cls.py b/mldaikon/invariant/base_cls.py index dccf9b1d..6eeab08d 100644 --- a/mldaikon/invariant/base_cls.py +++ b/mldaikon/invariant/base_cls.py @@ -1068,6 +1068,16 @@ def add_stage_info(self, valid_stages: set[str]): values=valid_stages, ) + invalid_stages = set(config.ALL_STAGE_NAMES) - valid_stages + + inverted_state_clause = PreconditionClause( + prop_name=STAGE_KEY, + prop_dtype=str, + _type=PT.CONSTANT, + additional_path=None, + values=invalid_stages, + ) + # add the stage clause to all the preconditions, if UNCONDITIONAL, then swap it with the stage clause for group_name, preconditions in self.grouped_preconditions.items(): if preconditions.is_unconditional(): @@ -1075,11 +1085,12 @@ def add_stage_info(self, valid_stages: set[str]): [Precondition([stage_clause])], inverted=False ) else: - assert ( - not preconditions.inverted - ), "Adding clause to inverted preconditions is not supported yet" - for precondition in preconditions: - precondition.add_clause(stage_clause) + if not preconditions.inverted: + for precondition in preconditions: + precondition.add_clause(stage_clause) + else: + for precondition in preconditions: + precondition.add_clause(inverted_state_clause) def __eq__(self, other) -> bool: if not isinstance(other, GroupedPreconditions): @@ -1158,9 +1169,10 @@ def to_dict(self, _dumping_for_failed_cases=False) -> dict: "num_negative_examples": self.num_negative_examples, } else: - assert ( - self.precondition is None - ), "Precondition should be None for failed cases" + # assert ( + # self.precondition is None + # ), "Precondition should be None for failed cases" + # HACK: NEED TO CHECK ALL FAILED INVARIANTS THAT HAVE PRECONDITIONS return { "text_description": self.text_description, "relation": self.relation.__name__, @@ -1428,6 +1440,26 @@ def to_dict(self): class Relation(abc.ABC): + @staticmethod + def generate_hypothesis(trace: Trace) -> list[Hypothesis]: + """Given a trace, should return a list of hypothesis with positive/negative examples collected on the current trace + args: + trace: Trace + A trace to generate the hypothesis on. 
+ """ + raise NotImplementedError("generate_hypothesis method is not implemented yet.") + + @staticmethod + def collect_examples(trace: Trace, hypothesis: Hypothesis): + """Given a trace and a hypothesis, should collect positive and negative examples for the hypothesis. + args: + trace: Trace + A trace to collect the examples on. + hypothesis: Hypothesis + The hypothesis to collect the examples for. + """ + raise NotImplementedError("collect_examples method is not implemented yet.") + @staticmethod @abc.abstractmethod def infer(trace) -> tuple[list[Invariant], list[FailedHypothesis]]: @@ -1453,6 +1485,17 @@ def evaluate(value_group: list) -> bool: """ pass + @staticmethod + @abc.abstractmethod + def get_precondition_infer_keys_to_skip(hypothesis: Hypothesis) -> list[str]: + """Given a hypothesis, should return a list of keys to skip in the infer process. + + args: + hypothesis: Hypothesis + The hypothesis to get the keys to skip for. + """ + pass + @staticmethod def from_name(relation_name: str) -> Type[Relation]: """Given a relation name, should return the relation class. diff --git a/mldaikon/invariant/consistency_relation.py b/mldaikon/invariant/consistency_relation.py index ae6486f4..f3077d70 100644 --- a/mldaikon/invariant/consistency_relation.py +++ b/mldaikon/invariant/consistency_relation.py @@ -111,7 +111,43 @@ class ConsistencyRelation(Relation): @staticmethod def infer(trace: Trace) -> tuple[list[Invariant], list[FailedHypothesis]]: """Infer Invariants for the ConsistencyRelation.""" + logger = logging.getLogger(__name__) + + hypotheses = ConsistencyRelation.generate_hypothesis(trace) + + failed_hypothesis = [] + passed_hypothesis = [] + for hypo in hypotheses: + param1 = hypo.invariant.params[0] + param2 = hypo.invariant.params[1] + + assert isinstance(param1, VarTypeParam) and isinstance( + param2, VarTypeParam + ), "Invariant parameters should be VarTypeParam." + logger.debug(f"Finding Precondition for: {hypo.invariant.text_description}") + preconditions = find_precondition(hypo, [trace]) + logger.debug(f"Preconditions for {hypo}:\n{str(preconditions)}") + + if preconditions is not None: + hypo.invariant.precondition = preconditions + hypo.invariant.num_positive_examples = len(hypo.positive_examples) + hypo.invariant.num_negative_examples = len(hypo.negative_examples) + passed_hypothesis.append(hypo) + else: + logger.debug( + f"Precondition not found for {hypo.invariant.text_description}" + ) + failed_hypothesis.append(FailedHypothesis(hypo)) + + return ( + list([hypothesis.invariant for hypothesis in passed_hypothesis]), + failed_hypothesis, + ) + + @staticmethod + def generate_hypothesis(trace: Trace) -> list[Hypothesis]: + """Generate Hypothesis for the ConsistencyRelation.""" logger = logging.getLogger(__name__) ## 1. Pre-scanning: Collecting variable instances and their values from the trace @@ -119,7 +155,7 @@ def infer(trace: Trace) -> tuple[list[Invariant], list[FailedHypothesis]]: var_insts = trace.get_var_insts() if len(var_insts) == 0: logger.warning("No variables found in the trace.") - return [], [] + return [] def is_hypo_already_in_hypothesis(hypo: tuple, hypothesis: set) -> bool: return ( @@ -206,7 +242,6 @@ def skip_attrs_with_different_dtypes(attr1, attr2): # filtered_hypothesis = [("torch.nn.Parameter", "data", "torch.nn.Parameter", "data")] ## 4. 
Positive Examples and Negative Examples Collection - group_name = VAR_GROUP_NAME hypothesis_with_examples = { hypo: Hypothesis( invariant=Invariant( @@ -218,121 +253,99 @@ def skip_attrs_with_different_dtypes(attr1, attr2): precondition=None, text_description=f"Consistency Relation between {hypo[0]}.{hypo[1]} and {hypo[2]}.{hypo[3]}", ), - positive_examples=ExampleList({group_name}), - negative_examples=ExampleList({group_name}), + positive_examples=ExampleList({VAR_GROUP_NAME}), + negative_examples=ExampleList({VAR_GROUP_NAME}), ) for hypo in filtered_hypothesis } for hypo in hypothesis_with_examples: - var_type1 = hypo[0] - attr1 = hypo[1] - var_type2 = hypo[2] - attr2 = hypo[3] - - is_skipping_init_values = False - if skip_init_values(var_type1) or skip_init_values(var_type2): - is_skipping_init_values = True - - # collect all variables that have the same types as var_type1 and var_type2 - var_type1_vars = [ - var_inst for var_inst in var_insts if var_inst.var_type == var_type1 - ] - var_type2_vars = [ - var_inst for var_inst in var_insts if var_inst.var_type == var_type2 - ] - - for idx1, var_inst1 in enumerate( - tqdm(var_type1_vars, desc=f"Collecting Examples for Hypo: {hypo}") - ): - for idx2, var_inst2 in enumerate(var_type2_vars): - if var_type1 == var_type2 and attr1 == attr2 and idx1 >= idx2: - continue - if var_inst1 == var_inst2: - continue - for _, value1 in enumerate( - var_insts[var_inst1][attr1][int(is_skipping_init_values) :] - ): - for _, value2 in enumerate( - var_insts[var_inst2][attr2][int(is_skipping_init_values) :] - ): - saw_overlap = False - overlap = calc_liveness_overlap( - value1.liveness, value2.liveness - ) - if overlap > config.LIVENESS_OVERLAP_THRESHOLD: - if compare_with_fp_tolerance( - value1.value, - value2.value, - ): - hypothesis_with_examples[ - hypo - ].positive_examples.add_example( - Example( - { - group_name: [ - value1.traces[0], - value2.traces[0], - ] - } ## HACK to make preconditions inference work for `step` - ) - ) - else: - hypothesis_with_examples[ - hypo - ].negative_examples.add_example( - Example( - { - group_name: [ - value1.traces[0], - value2.traces[0], - ] - } ## HACK to make preconditions inference work for `step` - ) - ) - else: - if saw_overlap: - # there won't be any more overlap, so we can break - break + ConsistencyRelation.collect_examples(trace, hypothesis_with_examples[hypo]) - ## 5. 
Precondition Inference TODO: this can be abstracted into a separate function that takes a list of hypothesis and returns those with preconditions - failed_hypothesis = [] - passed_hypothesis = [] - for hypo in hypothesis_with_examples: - logger.debug( - f"Finding Precondition for {hypo}: {hypothesis_with_examples[hypo].invariant.text_description}" - ) - preconditions = find_precondition( - hypothesis_with_examples[hypo], - trace=trace, - keys_to_skip=[ - f"attributes.{hypo[1]}", - f"attributes.{hypo[3]}", - "attributes.data", - "attributes.grad", - ], - ) - logger.debug(f"Preconditions for {hypo}:\n{str(preconditions)}") + return list(hypothesis_with_examples.values()) - if preconditions is not None: - hypothesis_with_examples[hypo].invariant.precondition = preconditions - hypothesis_with_examples[hypo].invariant.num_positive_examples = len( - hypothesis_with_examples[hypo].positive_examples - ) - hypothesis_with_examples[hypo].invariant.num_negative_examples = len( - hypothesis_with_examples[hypo].negative_examples - ) - passed_hypothesis.append(hypothesis_with_examples[hypo]) - else: - logger.debug(f"Precondition not found for {hypo}") - failed_hypothesis.append( - FailedHypothesis(hypothesis_with_examples[hypo]) - ) + @staticmethod + def collect_examples(trace: Trace, hypothesis: Hypothesis): + """Collect Examples for the ConsistencyRelation. + The modification of the hypothesis is done in-place. + """ + inv = hypothesis.invariant + assert ( + inv.relation == ConsistencyRelation + ), "Invariant should be ConsistencyRelation." + assert len(inv.params) == 2, "Invariant should have exactly two parameters." - return ( - list([hypothesis.invariant for hypothesis in passed_hypothesis]), - failed_hypothesis, - ) + param1 = inv.params[0] + param2 = inv.params[1] + + assert isinstance(param1, VarTypeParam) and isinstance( + param2, VarTypeParam + ), "Invariant parameters should be VarTypeParam." 
+ var_type1, attr1 = param1.var_type, param1.attr_name + var_type2, attr2 = param2.var_type, param2.attr_name + + is_skipping_init_values = False + if skip_init_values(var_type1) or skip_init_values(var_type2): + is_skipping_init_values = True + + var_insts = trace.get_var_insts() + + # collect all variables that have the same types as var_type1 and var_type2 + var_type1_vars = [ + var_inst for var_inst in var_insts if var_inst.var_type == var_type1 + ] + var_type2_vars = [ + var_inst for var_inst in var_insts if var_inst.var_type == var_type2 + ] + + for idx1, var_inst1 in enumerate( + tqdm(var_type1_vars, desc=f"Collecting Examples for Hypo: {hypothesis}") + ): + for idx2, var_inst2 in enumerate(var_type2_vars): + if var_type1 == var_type2 and attr1 == attr2 and idx1 >= idx2: + continue + if var_inst1 == var_inst2: + continue + for _, value1 in enumerate( + var_insts[var_inst1][attr1][int(is_skipping_init_values) :] + ): + for _, value2 in enumerate( + var_insts[var_inst2][attr2][int(is_skipping_init_values) :] + ): + saw_overlap = False + overlap = calc_liveness_overlap( + value1.liveness, value2.liveness + ) + if overlap > config.LIVENESS_OVERLAP_THRESHOLD: + if compare_with_fp_tolerance( + value1.value, + value2.value, + ): + hypothesis.positive_examples.add_example( + Example( + { + VAR_GROUP_NAME: [ + value1.traces[0], + value2.traces[0], + ] + } ## HACK to make preconditions inference work for `step` + ) + ) + else: + hypothesis.negative_examples.add_example( + Example( + { + VAR_GROUP_NAME: [ + value1.traces[0], + value2.traces[0], + ] + } ## HACK to make preconditions inference work for `step` + ) + ) + else: + if saw_overlap: + # there won't be any more overlap, so we can break + break @staticmethod def evaluate(value_group: list) -> bool: @@ -505,3 +518,24 @@ def static_check_all( check_passed=True, triggered=inv_triggered, ) + + @staticmethod + def get_precondition_infer_keys_to_skip(hypothesis: Hypothesis) -> list[str]: + assert ( + len(hypothesis.invariant.params) == 2 + ), "Invariant should have exactly two parameters." + param1 = hypothesis.invariant.params[0] + param2 = hypothesis.invariant.params[1] + + assert isinstance(param1, VarTypeParam) and isinstance( + param2, VarTypeParam + ), "Invariant parameters should be VarTypeParam." 
+        attr1 = param1.attr_name
+        attr2 = param2.attr_name
+
+        return [
+            f"attributes.{attr1}",
+            f"attributes.{attr2}",
+            "attributes.data",
+            "attributes.grad",
+        ]
diff --git a/mldaikon/invariant/consistency_transient_vars.py b/mldaikon/invariant/consistency_transient_vars.py
index f066ed0e..6829511b 100644
--- a/mldaikon/invariant/consistency_transient_vars.py
+++ b/mldaikon/invariant/consistency_transient_vars.py
@@ -28,6 +28,7 @@
 TENSOR_PATTERN = r"torch\..*Tensor"
 PARAMETER_KEYWORD = "Parameter"
+ATTR_SKIP = "_ML_DAIKON_data_ID"
 
 # _CACHE_PATH = "func_with_tensors.pkl"
@@ -156,6 +157,9 @@ def get_returned_tensors(
         type_value = list(return_value.keys())[0]
         attributes = return_value[type_value]
         if re.match(TENSOR_PATTERN, type_value) or PARAMETER_KEYWORD in type_value:
+            # let's pop the ATTR_SKIP attribute
+            if ATTR_SKIP in attributes:
+                attributes.pop(ATTR_SKIP)
             returned_tensors.append(attributes)
     return returned_tensors
@@ -265,12 +269,10 @@ class ConsistentOutputRelation(Relation):
     """
 
     @staticmethod
-    def infer(trace: Trace) -> tuple[list[Invariant], list[FailedHypothesis]]:
-
+    def generate_hypothesis(trace) -> list[Hypothesis]:
         logger = logging.getLogger(__name__)
 
         all_func_names = trace.get_func_names()
-
         relevant_func_call_events = get_events_of_funcs_with_tensors(
             all_func_names, trace, output_has_tensors=True, input_has_tensors=False
         )
@@ -376,67 +378,71 @@ def generate_hypothesis(trace) -> list[Hypothesis]:
                 hypotheses_for_func.append(hypothesis)
             all_hypotheses[func_name] = hypotheses_for_func
 
-        # positive example is the function calls corresponding to the hypothesis
-        # negative example is the function calls that do not correspond to the hypothesis (i.e. func calls that returned different values)
-
-        # now that we have the hypotheses for each function, we can return them
-        # we can also return the failed hypotheses if any
-        # we can also return the properties that are consistent across the function calls
-
-        # infer precondition for these hypotheses
-        print(all_hypotheses)
-
-        invariants = []
-        failed_hypotheses = []
-        for func_name, hypotheses in all_hypotheses.items():
-            for hypothesis in hypotheses:
-                precondition = find_precondition(hypothesis, trace)
-                print(precondition)
-                if precondition is not None:
-                    hypothesis.invariant.precondition = precondition
-                    invariants.append(hypothesis.invariant)
-                else:
-                    print(f"Could not find precondition for {hypothesis}")
-                    failed_hypotheses.append(FailedHypothesis(hypothesis))
-
-        print("done")
+        # return all_hypotheses
+        return sum(all_hypotheses.values(), [])
 
-        # precondition inference: function name can be a precondition
-
-        # can we let relation tell the precondition inference algorithm about what is already assumed?
-        # then we solve the step issue.
- - # now let's reason about the input and output properties of these function calls' args and return values - - # let's make the assumption that we are only interested in the functions that have tensors as args or return values - - # functions that do not have tensors as input but have them as output --> factory functions - # functions that have tensors as both input and output --> mathematical operations - - # we need an abstraction over the trace to get a specific function call's args and return values - # for now let's get them from the raw json trace + @staticmethod + def collect_examples(trace, hypothesis): + inv = hypothesis.invariant + # get all the function calls + assert len(inv.params) == 2 + assert isinstance(inv.params[0], APIParam) + assert isinstance(inv.params[1], VarTypeParam) - # get all the function calls. + func_name = inv.params[0].api_full_name + # get all the function calls for the function + func_call_ids = trace.get_func_call_ids(func_name) - # find properties inside. + for func_call_id in tqdm( + func_call_ids, desc=f"Adding examples for {inv.text_description}" + ): + func_call_event = trace.query_func_call_event(func_call_id) + if isinstance( + func_call_event, (FuncCallExceptionEvent, IncompleteFuncCallEvent) + ): + continue - # instead of doing this, will it just make more sense to differentiate between stages and only infer in one single step if - # we are in training or testing stage? + returned_tensors = get_returned_tensors(func_call_event) + if len(returned_tensors) == 0: + # add negative example + example = Example({"pre_event": [func_call_event.pre_record]}) + hypothesis.negative_examples.add_example(example) + continue - # Get all the function args and return values, and try to find the properties that - # are consistent across the function calls. + # TODO: this might be wrong due to make hashable used in infer, proceed with caution + for returned_tensor in returned_tensors: + prop = inv.params[1].attr_name + prop_val = inv.params[1].const_value + if prop not in returned_tensor or returned_tensor[prop] != prop_val: + # add negative example + example = Example({"pre_event": [func_call_event.pre_record]}) + hypothesis.negative_examples.add_example(example) + else: + # add positive example + example = Example({"pre_event": [func_call_event.pre_record]}) + hypothesis.positive_examples.add_example(example) - # 1. group by properties if some of them show great statistical consistency + hypothesis.invariant.num_positive_examples = len(hypothesis.positive_examples) + hypothesis.invariant.num_negative_examples = len(hypothesis.negative_examples) - # Question: are u only targeting function args and return values? What about consistency relationships between transient - # variables and the long-term variables (e.g. model, optimizers?) + @staticmethod + def infer(trace: Trace) -> tuple[list[Invariant], list[FailedHypothesis]]: - # Insight: internal, transient variables are kinda separate from long-term variables like the model and optimizer. - # I think it will be fine to treat them separately. 
+ all_hypotheses = ConsistentOutputRelation.generate_hypothesis(trace) - # the simplest case: only matmul is called multiple times and you have them both inside and outside the autocast regions + invariants = [] + failed_hypotheses = [] + for hypothesis in all_hypotheses: + precondition = find_precondition(hypothesis, [trace]) + print(precondition) + if precondition is not None: + hypothesis.invariant.precondition = precondition + invariants.append(hypothesis.invariant) + else: + print(f"Could not find precondition for {hypothesis}") + failed_hypotheses.append(FailedHypothesis(hypothesis)) - # need additional properties about these functions in precondition inference + print("done") return invariants, failed_hypotheses @@ -507,6 +513,10 @@ def static_check_all( ) # raise NotImplementedError + @staticmethod + def get_precondition_infer_keys_to_skip(hypothesis: Hypothesis) -> list[str]: + return [] + class ConsistentInputOutputRelation(Relation): """Infer common properties that should be enforced across the input and output of a function call. @@ -515,22 +525,7 @@ class ConsistentInputOutputRelation(Relation): """ @staticmethod - def infer(trace: Trace) -> tuple[list[Invariant], list[FailedHypothesis]]: - # get the function calls that have tensors or nn.Modules as both input and output - - """ - TODO: find all function calls that have tensors or nn.Modules as both input and output - We can extend this to other types of variables as well, but for now let's keep it simple. - - We need instrumentor to dump the trace in a format that we can easily query for this information. - - @Beijie: can we use static analysis for the above stuff? - - For now let's dump every possible information in the trace and then query it. - - # tensor - """ - + def generate_hypothesis(trace: Trace) -> list[Hypothesis]: logger = logging.getLogger(__name__) all_func_names = trace.get_func_names() @@ -545,12 +540,6 @@ def infer(trace: Trace) -> tuple[list[Invariant], list[FailedHypothesis]]: str, dict[tuple[InputOutputParam, InputOutputParam], Hypothesis] ] = {} - # func_name: [ - # - func_call_id (str): func_call_event - # - func_call_id (str): func_call_event - # - ... 
- # ] - for func_name in tqdm( relevant_func_call_events, desc="Infer hypotheses for consistent input output relation on functions", @@ -682,17 +671,77 @@ def infer(trace: Trace) -> tuple[list[Invariant], list[FailedHypothesis]]: hypothesis.negative_examples ) + hypos_to_return: list[Hypothesis] = [] + for hypotheses in all_hypotheses.values(): + hypos_to_return.extend(hypotheses.values()) + + return hypos_to_return + + @staticmethod + def collect_examples(trace, hypothesis): + inv = hypothesis.invariant + + assert len(inv.params) == 3 + + input_param, api_param, output_param = inv.params + + assert isinstance(input_param, InputOutputParam) + assert isinstance(api_param, APIParam) + assert isinstance(output_param, InputOutputParam) + assert inv.params[0].is_input + assert not inv.params[2].is_input + + logger = logging.getLogger(__name__) + + # get all the function calls + func_name = api_param.api_full_name + func_call_ids = trace.get_func_call_ids(func_name) + + for func_call_id in tqdm( + func_call_ids, desc=f"Checking invariant {inv.text_description}" + ): + func_call_event = trace.query_func_call_event(func_call_id) + if isinstance( + func_call_event, (FuncCallExceptionEvent, IncompleteFuncCallEvent) + ): + continue + + input_tensors = get_input_tensors(func_call_event) + output_tensors = get_returned_tensors(func_call_event) + try: + input_value = input_param.get_value_from_list_of_tensors(input_tensors) + output_value = output_param.get_value_from_list_of_tensors( + output_tensors + ) + except (IndexError, KeyError): + logger.warning( + f"Could not find the value to be used in input or output tensors for the hypothesis {inv}, skipping this function call." + ) + continue + + if input_value != output_value: + # add negative example + example = Example({"pre_event": [func_call_event.pre_record]}) + hypothesis.negative_examples.add_example(example) + else: + # add positive example + example = Example({"pre_event": [func_call_event.pre_record]}) + hypothesis.positive_examples.add_example(example) + + @staticmethod + def infer(trace: Trace) -> tuple[list[Invariant], list[FailedHypothesis]]: + all_hypotheses = ConsistentInputOutputRelation.generate_hypothesis(trace) + # now that we have the hypotheses for each function, we can start precondition inference invariants = [] failed_hypotheses = [] - for func_name, hypotheses in all_hypotheses.items(): - for hypothesis in hypotheses.values(): - precondition = find_precondition(hypothesis, trace) - if precondition is not None: - hypothesis.invariant.precondition = precondition - invariants.append(hypothesis.invariant) - else: - failed_hypotheses.append(FailedHypothesis(hypothesis)) + for hypothesis in all_hypotheses: + precondition = find_precondition(hypothesis, [trace]) + if precondition is not None: + hypothesis.invariant.precondition = precondition + invariants.append(hypothesis.invariant) + else: + failed_hypotheses.append(FailedHypothesis(hypothesis)) return invariants, failed_hypotheses @@ -765,6 +814,10 @@ def static_check_all(trace, inv, check_relation_first): triggered=triggered, ) + @staticmethod + def get_precondition_infer_keys_to_skip(hypothesis: Hypothesis) -> list[str]: + return [] + class ThresholdRelation(Relation): """Infer common properties that should be enforced across the input and output of a function call. 
@@ -773,7 +826,7 @@ class ThresholdRelation(Relation):
     """
 
     @staticmethod
-    def infer(trace: Trace) -> tuple[list[Invariant], list[FailedHypothesis]]:
+    def generate_hypothesis(trace: Trace) -> list[Hypothesis]:
         # get the function calls that have tensors or nn.Modules as both input and output
         logger = logging.getLogger(__name__)
         all_func_names = trace.get_func_names()
@@ -970,26 +1023,90 @@ def generate_hypothesis(trace: Trace) -> list[Hypothesis]:
                 else:
                     hypothesis.negative_examples.add_example(example)
 
+        hypos_to_return: list[Hypothesis] = []
+        for hypotheses in min_hypotheses.values():
+            hypos_to_return.extend(hypotheses.values())
+        for hypotheses in max_hypotheses.values():
+            hypos_to_return.extend(hypotheses.values())
+        return hypos_to_return
+
+    @staticmethod
+    def collect_examples(trace, hypothesis):
+        inv = hypothesis.invariant
+
+        assert len(inv.params) == 3
+        max_param, api_param, min_param = inv.params
+        assert isinstance(max_param, InputOutputParam)
+        assert isinstance(api_param, APIParam)
+        assert isinstance(min_param, InputOutputParam)
+
+        if max_param.is_input:
+            assert not min_param.is_input
+            is_threshold_min = False
+            input_param = max_param
+            output_param = min_param
+        else:
+            assert min_param.is_input
+            is_threshold_min = True
+            input_param = min_param
+            output_param = max_param
+
+        func_name = api_param.api_full_name
+        # get all function calls for the function
+        func_call_ids = trace.get_func_call_ids(func_name)
+        for func_call_id in tqdm(
+            func_call_ids, desc=f"Checking invariant {inv.text_description}"
+        ):
+            func_call_event = trace.query_func_call_event(func_call_id)
+            if isinstance(
+                func_call_event, (FuncCallExceptionEvent, IncompleteFuncCallEvent)
+            ):
+                continue
+
+            threshold_value = input_param.get_value_from_arguments(
+                Arguments(
+                    func_call_event.args,
+                    func_call_event.kwargs,
+                    func_call_event.func_name,
+                    consider_default_values=True,
+                )
+            )
+            output_value = output_param.get_value_from_list_of_tensors(
+                get_returned_tensors(func_call_event)
+            )
+
+            example = Example({"pre_event": [func_call_event.pre_record]})
+            if is_threshold_min:
+                if output_value >= threshold_value:
+                    # add positive example
+                    hypothesis.positive_examples.add_example(example)
+                else:
+                    # add negative example
+                    hypothesis.negative_examples.add_example(example)
+
+            else:
+                if output_value <= threshold_value:
+                    # add positive example
+                    hypothesis.positive_examples.add_example(example)
+                else:
+                    # add negative example
+                    hypothesis.negative_examples.add_example(example)
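`collect_examples` above encodes the threshold semantics symmetrically: under a min-threshold hypothesis every output must satisfy output >= threshold, under a max-threshold output <= threshold. A compact restatement (the helper is illustrative):

    def holds(is_threshold_min: bool, threshold: float, output: float) -> bool:
        return output >= threshold if is_threshold_min else output <= threshold

    assert holds(True, 0.0, 0.5)       # output above a min threshold -> positive
    assert not holds(False, 1.0, 2.0)  # output above a max threshold -> negative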
find_precondition(hypothesis, [trace]) + if precondition is not None: + hypothesis.invariant.precondition = precondition + invariants.append(hypothesis.invariant) + else: + failed_hypotheses.append(FailedHypothesis(hypothesis)) + return invariants, failed_hypotheses @staticmethod @@ -1074,3 +1191,7 @@ def static_check_all(trace, inv, check_relation_first): check_passed=True, triggered=triggered, ) + + @staticmethod + def get_precondition_infer_keys_to_skip(hypothesis: Hypothesis) -> list[str]: + return [] diff --git a/mldaikon/invariant/contain_relation.py b/mldaikon/invariant/contain_relation.py index b135793f..a8b28606 100644 --- a/mldaikon/invariant/contain_relation.py +++ b/mldaikon/invariant/contain_relation.py @@ -308,25 +308,150 @@ class APIContainRelation(Relation): """ @staticmethod - def infer(trace: Trace) -> tuple[list[Invariant], list[FailedHypothesis]]: + def generate_hypothesis(trace) -> list[Hypothesis]: + # let's play it dumb here, + _, _, hypotheses = APIContainRelation.infer( + trace, return_successful_hypotheses=True + ) # type: ignore + + # set all precond to None + for hypo in hypotheses: + hypo.invariant.precondition = None + + return hypotheses + + @staticmethod + def collect_examples(trace, hypothesis): + inv = hypothesis.invariant + assert ( + len(inv.params) == 2 + ), "Expected 2 parameters for APIContainRelation, one for the parent function name, and one for the child event name" + parent_param, child_param = inv.params[0], inv.params[1] + assert isinstance( + parent_param, APIParam + ), "Expected the first parameter to be an APIParam" + assert isinstance( + child_param, (APIParam, VarTypeParam, VarNameParam) + ), "Expected the second parameter to be an APIParam or VarTypeParam (VarNameParam not supported yet)" + + parent_func_name = parent_param.api_full_name + + parent_func_call_ids = trace.get_func_call_ids( + parent_func_name + ) # should be sorted by time to reflect timeliness + + check_for_unchanged_vars = False + if not isinstance(child_param, APIParam): + if VAR_GROUP_NAME in hypothesis.negative_examples.get_group_names(): + check_for_unchanged_vars = True + + for parent_func_call_id in tqdm( + parent_func_call_ids, desc=f"Collecting examples for {inv.text_description}" + ): + contained_events = events_scanner(trace, parent_func_call_id) + grouped_events = _group_events_by_type(contained_events) + if isinstance(child_param, APIParam): + contained_events = ( + grouped_events.get(FuncCallEvent, []) + + grouped_events.get(FuncCallExceptionEvent, []) + + grouped_events.get(IncompleteFuncCallEvent, []) + ) + for event in contained_events: + if child_param.check_event_match(event): + # add a positive example + parent_pre_record = trace.get_pre_func_call_record( + parent_func_call_id + ) + example = Example() + example.add_group(PARENT_GROUP_NAME, [parent_pre_record]) + hypothesis.positive_examples.add_example(example) + break + else: + # add a negative example + parent_pre_record = trace.get_pre_func_call_record( + parent_func_call_id + ) + example = Example() + example.add_group(PARENT_GROUP_NAME, [parent_pre_record]) + hypothesis.negative_examples.add_example(example) + else: + contained_events = grouped_events.get(VarChangeEvent, []) + for event in contained_events: + found = False + if child_param.check_event_match(event): + # add a positive example + parent_pre_record = trace.get_pre_func_call_record( + parent_func_call_id + ) + example = Example() + example.add_group(PARENT_GROUP_NAME, [parent_pre_record]) + example.add_group(VAR_GROUP_NAME, 
event.get_traces())
+                    hypothesis.positive_examples.add_example(example)
+                    found = True
+                    if check_for_unchanged_vars:
+                        assert (
+                            found
+                        ), "Expected the positive example to be found in the case of dynamic analysis"
+                        unchanged_var_ids = (
+                            trace.get_var_ids_unchanged_but_causally_related(
+                                parent_func_call_id,
+                                child_param.var_type,
+                                child_param.attr_name,
+                            )
+                        )
+                        # add these unchanged vars as negative examples
+                        for var_id in unchanged_var_ids:
+                            example = Example()
+                            example.add_group(PARENT_GROUP_NAME, [parent_pre_record])
+                            example.add_group(
+                                VAR_GROUP_NAME,
+                                trace.get_var_raw_event_before_time(
+                                    var_id, parent_pre_record["time"]
+                                ),
+                            )
+                            hypothesis.negative_examples.add_example(example)
+
+            if not found:
+                # add a negative example
+                parent_pre_record = trace.get_pre_func_call_record(
+                    parent_func_call_id
+                )
+                example = Example()
+                example.add_group(PARENT_GROUP_NAME, [parent_pre_record])
+                hypothesis.negative_examples.add_example(example)
+
     @staticmethod
-    def infer(trace: Trace) -> tuple[list[Invariant], list[FailedHypothesis]]:
+    def infer(
+        trace: Trace, return_successful_hypotheses: bool = False
+    ) -> tuple[list[Invariant], list[FailedHypothesis]]:
         """Infer Invariants without Preconditions"""
         # enable stage-based inference for API Contain Relation
         logger = logging.getLogger(__name__)
 
         if not trace.is_stage_annotated():
-            return APIContainRelation._infer(trace)
+            invs, failed_hypos, succ_hypos = APIContainRelation._infer(trace)
+            if return_successful_hypotheses:
+                return invs, failed_hypos, succ_hypos  # type: ignore
+            return invs, failed_hypos
 
         invariants_by_stage: dict[Invariant, list[str]] = {}
+        invariants_to_hypos: dict[Invariant, list[Hypothesis]] = {}
         failed_hypothesis_by_stage: dict[FailedHypothesis, list[str]] = {}
         for stage, stage_trace in trace.get_traces_for_stage().items():
             logger.info(
                 "Stage annotation detected in trace, enabling stage-based inference"
             )
-            invariants, failed_hypotheses = APIContainRelation._infer(stage_trace)
-            for invariant in invariants:
+            invariants, failed_hypotheses, successful_hypotheses = (
+                APIContainRelation._infer(stage_trace)
+            )
+            for invariant, hypo in zip(invariants, successful_hypotheses):
                 if invariant not in invariants_by_stage:
                     invariants_by_stage[invariant] = []
+                if invariant not in invariants_to_hypos:
+                    invariants_to_hypos[invariant] = []
                 invariants_by_stage[invariant].append(stage)
+                invariants_to_hypos[invariant].append(hypo)
+
             for failed_hypothesis in failed_hypotheses:
                 if failed_hypothesis not in failed_hypothesis_by_stage:
                     failed_hypothesis_by_stage[failed_hypothesis] = []
@@ -346,6 +471,19 @@ def infer(trace: Trace) -> tuple[list[Invariant], list[FailedHypothesis]]:
             invariant.precondition.add_stage_info(set(supported_stages))
             merged_invariants.append(invariant)
 
+        # for all the invariants' hypotheses, merge into one hypothesis by adding the examples
+        merged_successful_hypotheses: list[Hypothesis] = []
+        for hypotheses in invariants_to_hypos.values():
+            hypothesis = hypotheses[0]
+            for other_hypothesis in hypotheses[1:]:
+                hypothesis.positive_examples.examples.extend(
+                    other_hypothesis.positive_examples.examples
+                )
+                hypothesis.negative_examples.examples.extend(
+                    other_hypothesis.negative_examples.examples
+                )
+            merged_successful_hypotheses.append(hypothesis)
+
         # HACK: add the stage info to the failed hypotheses through the text description as well NEED TO CHANGE IF WE SUPPORT INVARIANT REFINEMENT
         for failed_hypothesis in failed_hypothesis_by_stage:
             not_supported_stages = failed_hypothesis_by_stage[failed_hypothesis]
@@ -355,10 +492,14 @@ def infer(trace: Trace) -> tuple[list[Invariant], list[FailedHypothesis]]:
                 f" (FAILED 
in stages: {not_supported_stages})" ) + if return_successful_hypotheses: + return merged_invariants, list(failed_hypothesis_by_stage.keys()), merged_successful_hypotheses # type: ignore return merged_invariants, list(failed_hypothesis_by_stage.keys()) @staticmethod - def _infer(trace: Trace) -> tuple[list[Invariant], list[FailedHypothesis]]: + def _infer( + trace: Trace, + ) -> tuple[list[Invariant], list[FailedHypothesis], list[Hypothesis]]: """Infer Invariants with Preconditions""" logger = logging.getLogger(__name__) @@ -382,7 +523,7 @@ def _infer(trace: Trace) -> tuple[list[Invariant], list[FailedHypothesis]]: logger.warning( "No function calls found in the trace, skipping the analysis" ) - return [], [] + return [], [], [] for parent in tqdm( func_names, desc="Scanning through function calls to generate hypotheses" @@ -516,7 +657,6 @@ def _infer(trace: Trace) -> tuple[list[Invariant], list[FailedHypothesis]]: positive_examples=ExampleList({PARENT_GROUP_NAME}), negative_examples=ExampleList({PARENT_GROUP_NAME}), ) - # else: # MARK: EVENT MERGING TO CREATE FINE-GRAINED CHILD PARAMS for parent_param in passing_events_for_API_contain_API: @@ -683,9 +823,11 @@ def _merge_child_API_events( # for each key in all_mergeable_hypotheses, invoke the hypotheses merging process. for hypotheses_to_be_merged in all_mergeable_hypotheses.values(): - merged_hypotheses = _merge_hypotheses(hypotheses_to_be_merged) + merged_successful_hypotheses = _merge_hypotheses( + hypotheses_to_be_merged + ) # delete original hypotheses in the original `hypotheses` structure - for hypo in merged_hypotheses: + for hypo in merged_successful_hypotheses: new_child_param = hypo.invariant.params[1] assert isinstance(new_child_param, (VarNameParam | VarTypeParam)) hypotheses[parent_param][ @@ -704,6 +846,7 @@ def _merge_child_API_events( ) all_invariants: list[Invariant] = [] failed_hypotheses = [] + successful_hypotheses = [] for parent_param in hypotheses: for child_param in hypotheses[parent_param]: hypo = hypotheses[parent_param][child_param] @@ -711,7 +854,7 @@ def _merge_child_API_events( logger.debug( f"Starting the inference of precondition for the hypotheses: {hypo.invariant.text_description}" ) - found_precondition = find_precondition(hypo, trace) + found_precondition = find_precondition(hypo, [trace]) if ( found_precondition is not None ): # TODO: abstract this precondition inference part to a function @@ -719,11 +862,12 @@ def _merge_child_API_events( hypo.invariant.num_positive_examples = len(hypo.positive_examples) hypo.invariant.num_negative_examples = len(hypo.negative_examples) all_invariants.append(hypo.invariant) + successful_hypotheses.append(hypo) else: logger.debug(f"Precondition not found for the hypotheses: {hypo}") failed_hypotheses.append(FailedHypothesis(hypo)) - return all_invariants, failed_hypotheses + return all_invariants, failed_hypotheses, successful_hypotheses @staticmethod def evaluate(value_group: list) -> bool: @@ -932,3 +1076,7 @@ def static_check_all( check_passed=True, triggered=inv_triggered, ) + + @staticmethod + def get_precondition_infer_keys_to_skip(hypothesis: Hypothesis) -> list[str]: + return [] diff --git a/mldaikon/invariant/cover_relation.py b/mldaikon/invariant/cover_relation.py index d77ed912..0273a40a 100644 --- a/mldaikon/invariant/cover_relation.py +++ b/mldaikon/invariant/cover_relation.py @@ -63,11 +63,11 @@ def is_subset(path1: List[APIParam], path2: List[APIParam]) -> bool: def add_path(new_path: List[APIParam]) -> None: nonlocal paths - for existing_path in 
paths[:]: - if is_subset(existing_path, new_path): - paths.remove(existing_path) - if is_subset(new_path, existing_path): - return + # for existing_path in paths[:]: + # if is_subset(existing_path, new_path): + # paths.remove(existing_path) + # if is_subset(new_path, existing_path): + # return paths.append(new_path) def dfs(node: APIParam, path: List[APIParam], visited: Set[APIParam]) -> None: @@ -263,7 +263,9 @@ def check_same_level(funcA: str, funcB: str, process_id: str, thread_id: str): logger.debug( f"Finding Precondition for {hypo}: {hypothesis_with_examples[hypo].invariant.text_description}" ) - preconditions = find_precondition(hypothesis_with_examples[hypo], trace) + preconditions = find_precondition( + hypothesis_with_examples[hypo], [trace] + ) logger.debug(f"Preconditions for {hypo}:\n{str(preconditions)}") if preconditions is not None: @@ -350,7 +352,7 @@ def dp_merge( if pair not in precondition_cache: precondition_cache[pair] = find_precondition( hypothesis_with_examples[(a.api_full_name, b.api_full_name)], - trace, + [trace], ) current_precondition = precondition_cache[pair] @@ -374,7 +376,7 @@ def dp_merge( next_pair[1].api_full_name, ) ], - trace, + [trace], ) next_precondition = precondition_cache[next_pair] @@ -597,3 +599,7 @@ def check_same_level(funcA: str, funcB: str, process_id, thread_id): check_passed=True, triggered=inv_triggered, ) + + @staticmethod + def get_precondition_infer_keys_to_skip(hypothesis: Hypothesis) -> list[str]: + return [] diff --git a/mldaikon/invariant/cover_relation_refactor.py b/mldaikon/invariant/cover_relation_refactor.py new file mode 100644 index 00000000..5aeb1ced --- /dev/null +++ b/mldaikon/invariant/cover_relation_refactor.py @@ -0,0 +1,607 @@ +import logging +from itertools import permutations +from typing import Any, Dict, List, Set, Tuple + +from tqdm import tqdm + +from mldaikon.invariant.base_cls import ( + APIParam, + CheckerResult, + Example, + ExampleList, + FailedHypothesis, + GroupedPreconditions, + Hypothesis, + Invariant, + Relation, +) +from mldaikon.invariant.lead_relation import ( + get_func_data_per_PT, + get_func_names_to_deal_with, +) +from mldaikon.invariant.precondition import find_precondition +from mldaikon.trace.trace import Trace + +EXP_GROUP_NAME = "func_cover" + + +def is_complete_subgraph( + path: List[APIParam], new_node: APIParam, graph: Dict[APIParam, List[APIParam]] +) -> bool: + """Check if adding new_node to path forms a complete (directed) graph.""" + for node in path: + if new_node not in graph[node]: + return False + return True + + +def merge_relations(pairs: List[Tuple[APIParam, APIParam]]) -> List[List[APIParam]]: + graph: Dict[APIParam, List[APIParam]] = {} + indegree: Dict[APIParam, int] = {} + for a, b in pairs: + if a not in graph: + graph[a] = [] + if b not in graph: + graph[b] = [] + graph[a].append(b) + + if b in indegree: + indegree[b] += 1 + else: + indegree[b] = 1 + + if a not in indegree: + indegree[a] = 0 + + start_nodes: List[APIParam] = [node for node in indegree if indegree[node] == 0] + paths: List[List[APIParam]] = [] + + def is_subset(path1: List[APIParam], path2: List[APIParam]) -> bool: + return set(path1).issubset(set(path2)) + + def add_path(new_path: List[APIParam]) -> None: + nonlocal paths + # for existing_path in paths[:]: + # if is_subset(existing_path, new_path): + # paths.remove(existing_path) + # if is_subset(new_path, existing_path): + # return + paths.append(new_path) + + def dfs(node: APIParam, path: List[APIParam], visited: Set[APIParam]) -> None: + 
path.append(node) + visited.add(node) + if node in graph: + for neighbor in graph[node]: + if neighbor not in visited and is_complete_subgraph( + path, neighbor, graph + ): + dfs(neighbor, path, visited) + if not graph.get(node): + add_path(path.copy()) + path.pop() + visited.remove(node) + + for start_node in start_nodes: + dfs(start_node, [], set()) + + return paths + + +class FunctionCoverRelation(Relation): + """FunctionCoverRelation is a relation that checks if one function covers another function. + + say function A and function B are two functions in the trace, we say function A covers function B when + every time function B is called, a function A invocation exists before it. + """ + + @staticmethod + def generate_hypothesis(trace) -> list[Hypothesis]: + """Generate hypothesis for the FunctionCoverRelation on trace.""" + + logger = logging.getLogger(__name__) + + # 1. Pre-process all the events + print("Start preprocessing....") + function_times: Dict[Tuple[str, str], Dict[str, Dict[str, Any]]] = {} + function_id_map: Dict[Tuple[str, str], Dict[str, List[str]]] = {} + listed_events: Dict[Tuple[str, str], List[dict[str, Any]]] = {} + function_pool: Set[Any] = set() + + # If the trace contains no function, return [] + function_pool = set(get_func_names_to_deal_with(trace)) + if len(function_pool) == 0: + logger.warning( + "No relevant function calls found in the trace, skipping the analysis" + ) + return [] + + function_times, function_id_map, listed_events = get_func_data_per_PT( + trace, function_pool + ) + print("End preprocessing") + + # 2. Check if two function on the same level for each thread and process + def check_same_level(funcA: str, funcB: str, process_id: str, thread_id: str): + if funcA == funcB: + return False + + if funcB not in function_id_map[(process_id, thread_id)]: + return False + + if funcA not in function_id_map[(process_id, thread_id)]: + return True + + for idA in function_id_map[(process_id, thread_id)][funcA]: + for idB in function_id_map[(process_id, thread_id)][funcB]: + preA = function_times[(process_id, thread_id)][idA]["start"] + postA = function_times[(process_id, thread_id)][idA]["end"] + preB = function_times[(process_id, thread_id)][idB]["start"] + postB = function_times[(process_id, thread_id)][idB]["end"] + if preB >= postA: + break + if postB <= preA: + continue + return False + return True + + print("Start same level checking...") + same_level_func: Dict[Tuple[str, str], Dict[str, Any]] = {} + valid_relations: Dict[Tuple[str, str], bool] = {} + + for (process_id, thread_id), _ in tqdm( + listed_events.items(), ascii=True, leave=True, desc="Groups Processed" + ): + same_level_func[(process_id, thread_id)] = {} + for funcA, funcB in tqdm( + permutations(function_pool, 2), + ascii=True, + leave=True, + desc="Combinations Checked", + ): + if check_same_level(funcA, funcB, process_id, thread_id): + if funcA not in same_level_func[(process_id, thread_id)]: + same_level_func[(process_id, thread_id)][funcA] = [] + same_level_func[(process_id, thread_id)][funcA].append(funcB) + valid_relations[(funcA, funcB)] = True + print("End same level checking") + + # 3. 
Generating hypothesis
+        print("Start generating hypo...")
+        hypothesis_with_examples = {
+            (func_A, func_B): Hypothesis(
+                invariant=Invariant(
+                    relation=FunctionCoverRelation,
+                    params=[
+                        APIParam(func_A),
+                        APIParam(func_B),
+                    ],
+                    precondition=None,
+                    text_description=f"FunctionCoverRelation between {func_A} and {func_B}",
+                ),
+                positive_examples=ExampleList({EXP_GROUP_NAME}),
+                negative_examples=ExampleList({EXP_GROUP_NAME}),
+            )
+            for (func_A, func_B), _ in valid_relations.items()
+        }
+        print("End generating hypo")
+
+        # 4. Add positive and negative examples
+        print("Start adding examples...")
+        for (process_id, thread_id), events_list in tqdm(
+            listed_events.items(), ascii=True, leave=True, desc="Group"
+        ):
+
+            for (func_A, func_B), _ in tqdm(
+                valid_relations.items(),
+                desc="Function Pair",
+            ):
+
+                if func_A not in same_level_func[(process_id, thread_id)]:
+                    continue
+
+                if func_B not in same_level_func[(process_id, thread_id)][func_A]:
+                    continue
+
+                flag_A = None
+                flag_B = None
+                pre_record_A = []
+                pre_record_B = []
+
+                for event in events_list:
+                    if event["type"] != "function_call (pre)":
+                        continue
+
+                    if func_A == event["function"]:
+                        flag_A = event["time"]
+                        flag_B = None
+                        pre_record_A = [event]
+
+                    if func_B == event["function"]:
+                        if flag_B is not None:
+                            valid_relations[(func_A, func_B)] = False
+                            neg = Example()
+                            neg.add_group(EXP_GROUP_NAME, pre_record_B)
+                            hypothesis_with_examples[
+                                (func_A, func_B)
+                            ].negative_examples.add_example(neg)
+                            pre_record_B = [event]
+                            flag_B = event["time"]
+                            continue
+
+                        flag_B = event["time"]
+                        if flag_A is None:
+                            valid_relations[(func_A, func_B)] = False
+                            neg = Example()
+                            neg.add_group(EXP_GROUP_NAME, [event])
+                            hypothesis_with_examples[
+                                (func_A, func_B)
+                            ].negative_examples.add_example(neg)
+                        else:
+                            pos = Example()
+                            pos.add_group(EXP_GROUP_NAME, pre_record_A)
+                            hypothesis_with_examples[
+                                (func_A, func_B)
+                            ].positive_examples.add_example(pos)
+
+                        pre_record_B = [event]
+        print("End adding examples")
+
+        return list(hypothesis_with_examples.values())
+
+    @staticmethod
+    def collect_examples(trace, hypothesis):
+        """Generate examples for a hypothesis on trace."""
+        inv = hypothesis.invariant
+
+        function_times: Dict[Tuple[str, str], Dict[str, Dict[str, Any]]] = {}
+        function_id_map: Dict[Tuple[str, str], Dict[str, List[str]]] = {}
+        listed_events: Dict[Tuple[str, str], List[dict[str, Any]]] = {}
+
+        # If the trace contains no function, leave the original hypothesis unchanged
+        func_names = trace.get_func_names()
+        if len(func_names) == 0:
+            print("No function calls found in the trace, skipping the collecting")
+            return
+
+        function_pool = (
+            []
+        )  # Here function_pool only contains functions existing in the given invariant
+
+        invariant_length = len(inv.params)
+        for i in range(invariant_length):
+            func = inv.params[i]
+            assert isinstance(
+                func, APIParam
+            ), "Invariant parameters should be APIParam."
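+            # e.g. an invariant over APIParam("A") and APIParam("B") (hypothetical
+            # names) contributes ["A", "B"]; the pool is then intersected with the
+            # functions actually present in this trace below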
+            function_pool.append(func.api_full_name)
+
+        function_pool = list(set(function_pool).intersection(func_names))
+
+        if len(function_pool) == 0:
+            print(
+                "No relevant function calls found in the trace, skipping the collecting"
+            )
+            return
+
+        print("Start fetching data for collecting...")
+        function_times, function_id_map, listed_events = get_func_data_per_PT(
+            trace, function_pool
+        )
+        print("End fetching data for collecting...")
+
+        def check_same_level(funcA: str, funcB: str, process_id, thread_id):
+            if funcA == funcB:
+                return False
+
+            if funcB not in function_id_map[(process_id, thread_id)]:
+                return False
+
+            if funcA not in function_id_map[(process_id, thread_id)]:
+                return True
+
+            for idA in function_id_map[(process_id, thread_id)][funcA]:
+                for idB in function_id_map[(process_id, thread_id)][funcB]:
+                    preA = function_times[(process_id, thread_id)][idA]["start"]
+                    postA = function_times[(process_id, thread_id)][idA]["end"]
+                    preB = function_times[(process_id, thread_id)][idB]["start"]
+                    postB = function_times[(process_id, thread_id)][idB]["end"]
+                    if preB >= postA:
+                        break
+                    if postB <= preA:
+                        continue
+                    return False
+            return True
+
+        print("Starting collecting iteration...")
+        for i in tqdm(range(invariant_length - 1)):
+            func_A = inv.params[i]
+            func_B = inv.params[i + 1]
+
+            assert isinstance(func_A, APIParam) and isinstance(
+                func_B, APIParam
+            ), "Invariant parameters should be APIParam."
+
+            for (process_id, thread_id), events_list in listed_events.items():
+                funcA = func_A.api_full_name
+                funcB = func_B.api_full_name
+
+                if not check_same_level(funcA, funcB, process_id, thread_id):
+                    continue
+
+                # check
+                flag_A = None
+                flag_B = None
+                pre_record_A = []
+                pre_record_B = []
+
+                for event in events_list:
+                    if event["type"] != "function_call (pre)":
+                        continue
+
+                    if funcA == event["function"]:
+                        flag_A = event["time"]
+                        flag_B = None
+                        pre_record_A = [event]
+
+                    if funcB == event["function"]:
+                        if flag_B is not None:
+                            neg = Example()
+                            neg.add_group(EXP_GROUP_NAME, pre_record_B)
+                            hypothesis.negative_examples.add_example(neg)
+                            pre_record_B = [event]
+                            flag_B = event["time"]
+                            continue
+
+                        flag_B = event["time"]
+                        if flag_A is None:
+                            neg = Example()
+                            neg.add_group(EXP_GROUP_NAME, [event])
+                            hypothesis.negative_examples.add_example(neg)
+                        else:
+                            pos = Example()
+                            pos.add_group(EXP_GROUP_NAME, pre_record_A)
+                            hypothesis.positive_examples.add_example(pos)
+
+                        pre_record_B = [event]
+
+        print("End collecting iteration...")
+
+    @staticmethod
+    def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]:
+        """Infer Invariants for the FunctionCoverRelation."""
+
+        all_hypotheses = FunctionCoverRelation.generate_hypothesis(trace)
+
+        # for hypothesis in all_hypotheses:
+        #     FunctionCoverRelation.collect_examples(trace, hypothesis)
+
+        if_merge = True
+
+        print("Start precondition inference...")
+        failed_hypothesis = []
+        for hypothesis in all_hypotheses.copy():
+            preconditions = find_precondition(hypothesis, [trace])
+            if preconditions is not None:
+                hypothesis.invariant.precondition = preconditions
+            else:
+                failed_hypothesis.append(FailedHypothesis(hypothesis))
+                all_hypotheses.remove(hypothesis)
+        print("End precondition inference")
+
+        if not if_merge:
+            return (
+                [hypo.invariant for hypo in all_hypotheses],
+                failed_hypothesis,
+            )
+
+        # 6. 
Merge invariants
+        print("Start merging invariants...")
+        relation_pool: Dict[
+            GroupedPreconditions | None, List[Tuple[APIParam, APIParam]]
+        ] = {}
+        # relation_pool contains all binary relations classified by GroupedPreconditions (key)
+        for hypothesis in all_hypotheses:
+            param0 = hypothesis.invariant.params[0]
+            param1 = hypothesis.invariant.params[1]
+
+            assert isinstance(param0, APIParam) and isinstance(param1, APIParam)
+            if hypothesis.invariant.precondition not in relation_pool:
+                relation_pool[hypothesis.invariant.precondition] = []
+            relation_pool[hypothesis.invariant.precondition].append((param0, param1))
+
+        merged_relations: Dict[
+            GroupedPreconditions | None, List[List[APIParam]]
+        ] = {}
+
+        for key, values in tqdm(relation_pool.items(), desc="Merging Invariants"):
+            merged_relations[key] = merge_relations(values)
+
+        merged_invariants = []
+
+        for key, merged_values in merged_relations.items():
+            for merged_value in merged_values:
+                new_invariant = Invariant(
+                    relation=FunctionCoverRelation,
+                    params=[param for param in merged_value],
+                    precondition=key,
+                    text_description="Merged FunctionCoverRelation in Ordered List",
+                )
+                merged_invariants.append(new_invariant)
+        print("End merging invariants")
+
+        return merged_invariants, failed_hypothesis
+
+    @staticmethod
+    def evaluate(value_group: list) -> bool:
+        """Given a group of values, should return a boolean value
+        indicating whether the relation holds or not.
+
+        args:
+            value_group: list
+                A list of values to evaluate the relation on. The length of the list
+                should be equal to the number of variables in the relation.
+        """
+        return True
+
+    @staticmethod
+    def static_check_all(
+        trace: Trace, inv: Invariant, check_relation_first: bool
+    ) -> CheckerResult:
+        """Given a trace and an invariant, should return a CheckerResult
+        indicating whether the invariant holds on the trace.
+
+        args:
+            trace: Trace
+                A trace to check the invariant on.
+            inv: Invariant
+                The invariant to check on the trace.
+        """
+        assert inv.precondition is not None, "Invariant should have a precondition."
+
+        function_times: Dict[Tuple[str, str], Dict[str, Dict[str, Any]]] = {}
+        function_id_map: Dict[Tuple[str, str], Dict[str, List[str]]] = {}
+        listed_events: Dict[Tuple[str, str], List[dict[str, Any]]] = {}
+
+        inv_triggered = False
+        # If the trace contains no function, return vacuous true result
+        func_names = trace.get_func_names()
+        if len(func_names) == 0:
+            print("No function calls found in the trace, skipping the checking")
+            return CheckerResult(
+                trace=None,
+                invariant=inv,
+                check_passed=True,
+                triggered=False,
+            )
+
+        function_pool = (
+            []
+        )  # Here function_pool only contains functions existing in the given invariant
+
+        invariant_length = len(inv.params)
+        for i in range(invariant_length):
+            func = inv.params[i]
+            assert isinstance(
+                func, APIParam
+            ), "Invariant parameters should be APIParam."
+            function_pool.append(func.api_full_name)
+
+        function_pool = list(set(function_pool).intersection(func_names))
+
+        # YUXUAN ASK: if function_pool is not strictly a subset of func_names, should we directly return false? 
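+        # NOTE: an empty pool after the intersection means none of the invariant's
+        # APIs occur in this trace, so the check below passes vacuously (not triggered)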
+ + if len(function_pool) == 0: + print( + "No relevant function calls found in the trace, skipping the checking" + ) + return CheckerResult( + trace=None, + invariant=inv, + check_passed=True, + triggered=False, + ) + + print("Start fetching data for checking...") + function_times, function_id_map, listed_events = get_func_data_per_PT( + trace, function_pool + ) + print("End fetching data for checking...") + + def check_same_level(funcA: str, funcB: str, process_id, thread_id): + if funcA == funcB: + return False + + if funcB not in function_id_map[(process_id, thread_id)]: + return False + + if funcA not in function_id_map[(process_id, thread_id)]: + return True + + for idA in function_id_map[(process_id, thread_id)][funcA]: + for idB in function_id_map[(process_id, thread_id)][funcB]: + preA = function_times[(process_id, thread_id)][idA]["start"] + postA = function_times[(process_id, thread_id)][idA]["end"] + preB = function_times[(process_id, thread_id)][idB]["start"] + postB = function_times[(process_id, thread_id)][idB]["end"] + if preB >= postA: + break + if postB <= preA: + continue + return False + return True + + print("Starting checking iteration...") + for i in tqdm(range(invariant_length - 1)): + func_A = inv.params[i] + func_B = inv.params[i + 1] + + assert isinstance(func_A, APIParam) and isinstance( + func_B, APIParam + ), "Invariant parameters should be string." + + for (process_id, thread_id), events_list in listed_events.items(): + funcA = func_A.api_full_name + funcB = func_B.api_full_name + + if not check_same_level(funcA, funcB, process_id, thread_id): + continue + + # check + flag_B = None + pre_recordB = None + for event in events_list: + if event["type"] != "function_call (pre)": + continue + + if funcA == event["function"]: + flag_B = None + pre_recordB = None + + if funcB == event["function"]: + if flag_B is not None: + if inv.precondition.verify([events_list], EXP_GROUP_NAME): + inv_triggered = True + print( + "The relation " + + funcA + + " covers " + + funcB + + " is violated!\n" + ) + return CheckerResult( + trace=[pre_recordB, event], + invariant=inv, + check_passed=False, + triggered=True, + ) + flag_B = event["time"] + pre_recordB = event + + # FIXME: triggered is always False for passing invariants + return CheckerResult( + trace=None, + invariant=inv, + check_passed=True, + triggered=inv_triggered, + ) + + @staticmethod + def get_precondition_infer_keys_to_skip(hypothesis: Hypothesis) -> list[str]: + return [] diff --git a/mldaikon/invariant/lead_relation.py b/mldaikon/invariant/lead_relation.py index 60504c7b..b2384990 100644 --- a/mldaikon/invariant/lead_relation.py +++ b/mldaikon/invariant/lead_relation.py @@ -146,11 +146,11 @@ def is_subset(path1: List[APIParam], path2: List[APIParam]) -> bool: def add_path(new_path: List[APIParam]) -> None: nonlocal paths - for existing_path in paths[:]: - if is_subset(existing_path, new_path): - paths.remove(existing_path) - if is_subset(new_path, existing_path): - return + # for existing_path in paths[:]: + # if is_subset(existing_path, new_path): + # paths.remove(existing_path) + # if is_subset(new_path, existing_path): + # return paths.append(new_path) def dfs(node: APIParam, path: List[APIParam], visited: Set[APIParam]) -> None: @@ -352,7 +352,9 @@ def check_same_level(funcA: str, funcB: str, process_id: str, thread_id: str): logger.debug( f"Finding Precondition for {hypo}: {hypothesis_with_examples[hypo].invariant.text_description}" ) - preconditions = find_precondition(hypothesis_with_examples[hypo], trace) + 
preconditions = find_precondition(
+                hypothesis_with_examples[hypo], [trace]
+            )
             logger.debug(f"Preconditions for {hypo}:\n{str(preconditions)}")
 
             if preconditions is not None:
@@ -438,7 +440,7 @@ def dp_merge(
                 if pair not in precondition_cache:
                     precondition_cache[pair] = find_precondition(
                         hypothesis_with_examples[(a.api_full_name, b.api_full_name)],
-                        trace,
+                        [trace],
                     )
 
                 current_precondition = precondition_cache[pair]
@@ -462,7 +464,7 @@ def dp_merge(
                             next_pair[1].api_full_name,
                         )
                     ],
-                    trace,
+                    [trace],
                 )
 
                 next_precondition = precondition_cache[next_pair]
@@ -686,3 +688,7 @@ def check_same_level(funcA: str, funcB: str, process_id, thread_id):
         check_passed=True,
         triggered=inv_triggered,
     )
+
+    @staticmethod
+    def get_precondition_infer_keys_to_skip(hypothesis: Hypothesis) -> list[str]:
+        return []
diff --git a/mldaikon/invariant/lead_relation_refactor.py b/mldaikon/invariant/lead_relation_refactor.py
new file mode 100644
index 00000000..a19bb6ec
--- /dev/null
+++ b/mldaikon/invariant/lead_relation_refactor.py
@@ -0,0 +1,694 @@
+import logging
+from itertools import permutations
+from typing import Any, Dict, Iterable, List, Set, Tuple
+
+from tqdm import tqdm
+
+from mldaikon.invariant.base_cls import (
+    APIParam,
+    CheckerResult,
+    Example,
+    ExampleList,
+    FailedHypothesis,
+    GroupedPreconditions,
+    Hypothesis,
+    IncompleteFuncCallEvent,
+    Invariant,
+    Relation,
+)
+from mldaikon.invariant.precondition import find_precondition
+from mldaikon.trace.trace import Trace
+
+EXP_GROUP_NAME = "func_lead"
+MAX_FUNC_NUM_CONSECUTIVE_CALL = 6  # ideally this should be proportional to the number of training and testing iterations in the trace
+
+
+def get_func_names_to_deal_with(trace: Trace) -> List[str]:
+    """Get the function names in the trace that are worth hypothesizing on."""
+    function_pool: Set[str] = set()
+
+    # get all functions in the trace
+    all_func_names = trace.get_func_names()
+
+    # filtering 1: remove private functions
+    for func_name in all_func_names:
+        if "._" in func_name:
+            continue
+        function_pool.add(func_name)
+
+    # filtering 2: remove functions whose maximum number of consecutive calls exceeds MAX_FUNC_NUM_CONSECUTIVE_CALL
+    for func_name in function_pool.copy():
+        max_num_consecutive_call = trace.get_max_num_consecutive_call_func(func_name)
+        if max_num_consecutive_call > MAX_FUNC_NUM_CONSECUTIVE_CALL:
+            function_pool.remove(func_name)
+
+    return list(function_pool)
+
+
+def get_func_data_per_PT(trace: Trace, function_pool: Iterable[str]):
+    """
+    Get
+    1. all function timestamps per process and thread.
+    2. all function ids per process and thread.
+    3. all events per process and thread.
+
+    # see the code below for the structure of the return values
+
+    """
+    function_times: Dict[Tuple[str, str], Dict[str, Dict[str, Any]]] = (
+        {}
+    )  # map from (process_id, thread_id) to function call id to start and end time and function name
+    function_id_map: Dict[Tuple[str, str], Dict[str, List[str]]] = (
+        {}
+    )  # map from (process_id, thread_id) to function name to function call ids
+    listed_events: Dict[Tuple[str, str], List[dict[str, Any]]] = (
+        {}
+    )  # map from (process_id, thread_id) to all events
+    # for all func_ids, get their corresponding events
+    for func_name in function_pool:
+        func_call_ids = trace.get_func_call_ids(func_name)
+        for func_call_id in func_call_ids:
+            event = trace.query_func_call_event(func_call_id)
+            assert not isinstance(
+                event, IncompleteFuncCallEvent
+            ), "why would we hypothesize on incomplete events (incomplete func calls are typically outermost functions)?"
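+            # each completed call contributes a pre/post record pair, grouped by
+            # (process_id, thread_id) and sorted by time into listed_events below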
+            process_id = event.pre_record["process_id"]
+            thread_id = event.pre_record["thread_id"]
+
+            # populate the function_times
+            if (process_id, thread_id) not in function_times:
+                function_times[(process_id, thread_id)] = {}
+
+            function_times[(process_id, thread_id)][func_call_id] = {
+                "start": event.pre_record["time"],
+                "end": event.post_record["time"],
+                "function": func_name,
+            }
+
+            # populate the function_id_map
+            if (process_id, thread_id) not in function_id_map:
+                function_id_map[(process_id, thread_id)] = {}
+            if func_name not in function_id_map[(process_id, thread_id)]:
+                function_id_map[(process_id, thread_id)][func_name] = []
+            function_id_map[(process_id, thread_id)][func_name].append(func_call_id)
+
+            # populate the listed_events
+            if (process_id, thread_id) not in listed_events:
+                listed_events[(process_id, thread_id)] = []
+            listed_events[(process_id, thread_id)].extend(
+                [event.pre_record, event.post_record]
+            )
+
+    # sort the listed_events
+    for (process_id, thread_id), events_list in listed_events.items():
+        listed_events[(process_id, thread_id)] = sorted(
+            events_list, key=lambda x: x["time"]
+        )
+
+    return function_times, function_id_map, listed_events
+
+
+def is_complete_subgraph(
+    path: List[APIParam], new_node: APIParam, graph: Dict[APIParam, List[APIParam]]
+) -> bool:
+    """Check if adding new_node to path forms a complete (directed) graph."""
+    for node in path:
+        if new_node not in graph[node]:
+            return False
+    return True
+
+
+def merge_relations(pairs: List[Tuple[APIParam, APIParam]]) -> List[List[APIParam]]:
+    graph: Dict[APIParam, List[APIParam]] = {}
+    indegree: Dict[APIParam, int] = {}
+
+    for a, b in pairs:
+        if a not in graph:
+            graph[a] = []
+        if b not in graph:
+            graph[b] = []
+        graph[a].append(b)
+
+        if b in indegree:
+            indegree[b] += 1
+        else:
+            indegree[b] = 1
+
+        if a not in indegree:
+            indegree[a] = 0
+
+    start_nodes: List[APIParam] = [node for node in indegree if indegree[node] == 0]
+
+    paths: List[List[APIParam]] = []
+
+    def is_subset(path1: List[APIParam], path2: List[APIParam]) -> bool:
+        return set(path1).issubset(set(path2))
+
+    def add_path(new_path: List[APIParam]) -> None:
+        nonlocal paths
+        # for existing_path in paths[:]:
+        #     if is_subset(existing_path, new_path):
+        #         paths.remove(existing_path)
+        #     if is_subset(new_path, existing_path):
+        #         return
+        paths.append(new_path)
+
+    def dfs(node: APIParam, path: List[APIParam], visited: Set[APIParam]) -> None:
+        path.append(node)
+        visited.add(node)
+        if node in graph:
+            for neighbor in graph[node]:
+                if neighbor not in visited and is_complete_subgraph(
+                    path, neighbor, graph
+                ):
+                    dfs(neighbor, path, visited)
+        if not graph.get(node):
+            add_path(path.copy())
+        path.pop()
+        visited.remove(node)
+
+    for start_node in start_nodes:
+        dfs(start_node, [], set())
+
+    return paths
+
+
+class FunctionLeadRelation(Relation):
+    """FunctionLeadRelation is a relation that checks if one function leads another function.
+
+    say function A and function B are two functions in the trace, we say function A leads function B when
+    every time function A is called, a function B invocation follows.
+    """
+    @staticmethod
+    def generate_hypothesis(trace) -> list[Hypothesis]:
+        """Generate hypothesis for the FunctionLeadRelation on trace."""
+        logger = logging.getLogger(__name__)
+
+        # 1. Pre-process all the events
+        print("Start preprocessing....")
+        function_times: Dict[Tuple[str, str], Dict[str, Dict[str, Any]]] = {}
+        function_id_map: Dict[Tuple[str, str], Dict[str, List[str]]] = {}
+        listed_events: Dict[Tuple[str, str], List[dict[str, Any]]] = {}
+        function_pool: Set[Any] = set()
+
+        # If the trace contains no function, safely exit the inference
+        function_pool = set(get_func_names_to_deal_with(trace))
+        if len(function_pool) == 0:
+            logger.warning(
+                "No relevant function calls found in the trace, skipping the analysis"
+            )
+            return []
+
+        function_times, function_id_map, listed_events = get_func_data_per_PT(
+            trace, function_pool
+        )
+        print("End preprocessing")
+
+        # 2. Check if two functions are on the same level for each thread and process
+        def check_same_level(funcA: str, funcB: str, process_id: str, thread_id: str):
+            if funcA == funcB:
+                return False
+
+            if funcA not in function_id_map[(process_id, thread_id)]:
+                return False
+
+            if funcB not in function_id_map[(process_id, thread_id)]:
+                return True
+
+            for idA in function_id_map[(process_id, thread_id)][funcA]:
+                for idB in function_id_map[(process_id, thread_id)][funcB]:
+                    preA = function_times[(process_id, thread_id)][idA]["start"]
+                    postA = function_times[(process_id, thread_id)][idA]["end"]
+                    preB = function_times[(process_id, thread_id)][idB]["start"]
+                    postB = function_times[(process_id, thread_id)][idB]["end"]
+                    if preB >= postA:
+                        break
+                    if postB <= preA:
+                        continue
+                    return False
+            return True
+
+        print("Start same level checking...")
+        same_level_func: Dict[Tuple[str, str], Dict[str, Any]] = {}
+        valid_relations: Dict[Tuple[str, str], bool] = {}
+
+        for (process_id, thread_id), _ in tqdm(
+            listed_events.items(), ascii=True, leave=True, desc="Groups Processed"
+        ):
+            same_level_func[(process_id, thread_id)] = {}
+            for funcA, funcB in tqdm(
+                permutations(function_pool, 2),
+                ascii=True,
+                leave=True,
+                desc="Combinations Checked",
+            ):
+                if check_same_level(funcA, funcB, process_id, thread_id):
+                    if funcA not in same_level_func[(process_id, thread_id)]:
+                        same_level_func[(process_id, thread_id)][funcA] = []
+                    same_level_func[(process_id, thread_id)][funcA].append(funcB)
+                    valid_relations[(funcA, funcB)] = True
+        print("End same level checking")
+
+        # 3. Generating hypothesis
+        print("Start generating hypo...")
+        hypothesis_with_examples = {
+            (func_A, func_B): Hypothesis(
+                invariant=Invariant(
+                    relation=FunctionLeadRelation,
+                    params=[
+                        APIParam(func_A),
+                        APIParam(func_B),
+                    ],
+                    precondition=None,
+                    text_description=f"FunctionLeadRelation between {func_A} and {func_B}",
+                ),
+                positive_examples=ExampleList({EXP_GROUP_NAME}),
+                negative_examples=ExampleList({EXP_GROUP_NAME}),
+            )
+            for (func_A, func_B), _ in valid_relations.items()
+        }
+        print("End generating hypo")
+
+        # 4. 
Add positive and negative examples + print("Start adding examples...") + for (process_id, thread_id), events_list in tqdm( + listed_events.items(), ascii=True, leave=True, desc="Group" + ): + + for (func_A, func_B), _ in tqdm( + valid_relations.items(), + desc="Function Pair", + ): + + if func_A not in same_level_func[(process_id, thread_id)]: + continue + + if func_B not in same_level_func[(process_id, thread_id)][func_A]: + continue + + flag_A = None + pre_record_A = [] + + for event in events_list: + if event["type"] != "function_call (pre)": + continue + + if func_A == event["function"]: + if flag_A is None: + flag_A = event["time"] + pre_record_A = [event] + continue + + valid_relations[(func_A, func_B)] = False + neg = Example() + neg.add_group(EXP_GROUP_NAME, pre_record_A) + hypothesis_with_examples[ + (func_A, func_B) + ].negative_examples.add_example(neg) + pre_record_A = [event] + continue + + if func_B == event["function"]: + if flag_A is None: + continue + + pos = Example() + pos.add_group(EXP_GROUP_NAME, pre_record_A) + hypothesis_with_examples[ + (func_A, func_B) + ].positive_examples.add_example(pos) + + flag_A = None + pre_record_A = [] + + if flag_A is not None: + flag_A = None + neg = Example() + neg.add_group(EXP_GROUP_NAME, pre_record_A) + hypothesis_with_examples[ + (func_A, func_B) + ].negative_examples.add_example(neg) + pre_record_A = [] + + print("End adding examples") + + return list(hypothesis_with_examples.values()) + + @staticmethod + def collect_examples(trace, hypothesis): + """Generate examples for a hypothesis on trace.""" + inv = hypothesis.invariant + + function_times: Dict[Tuple[str, str], Dict[str, Dict[str, Any]]] = {} + function_id_map: Dict[Tuple[str, str], Dict[str, List[str]]] = {} + listed_events: Dict[Tuple[str, str], List[dict[str, Any]]] = {} + + # If the trace contains no function, return original hypothesis + func_names = trace.get_func_names() + if len(func_names) == 0: + print("No function calls found in the trace, skipping the collecting") + return + + function_pool = ( + [] + ) # Here function_pool only contains functions existing in given invariant + + invariant_length = len(inv.params) + for i in range(invariant_length): + func = inv.params[i] + assert isinstance( + func, APIParam + ), "Invariant parameters should be APIParam." 
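+            # e.g. a merged invariant with params [APIParam("A"), APIParam("B"),
+            # APIParam("C")] (hypothetical names) walks consecutive pairs (A, B)
+            # and (B, C) when collecting examples further below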
+            function_pool.append(func.api_full_name)
+
+        function_pool = list(set(function_pool).intersection(func_names))
+
+        if len(function_pool) == 0:
+            print(
+                "No relevant function calls found in the trace, skipping the collecting"
+            )
+            return
+
+        print("Start fetching data for collecting...")
+        function_times, function_id_map, listed_events = get_func_data_per_PT(
+            trace, function_pool
+        )
+        print("End fetching data for collecting...")
+
+        def check_same_level(funcA: str, funcB: str, process_id, thread_id):
+            if funcA == funcB:
+                return False
+
+            if funcA not in function_id_map[(process_id, thread_id)]:
+                return False
+
+            if funcB not in function_id_map[(process_id, thread_id)]:
+                return True
+
+            for idA in function_id_map[(process_id, thread_id)][funcA]:
+                for idB in function_id_map[(process_id, thread_id)][funcB]:
+                    preA = function_times[(process_id, thread_id)][idA]["start"]
+                    postA = function_times[(process_id, thread_id)][idA]["end"]
+                    preB = function_times[(process_id, thread_id)][idB]["start"]
+                    postB = function_times[(process_id, thread_id)][idB]["end"]
+                    if preB >= postA:
+                        break
+                    if postB <= preA:
+                        continue
+                    return False
+            return True
+
+        print("Starting collecting iteration...")
+        for i in range(invariant_length - 1):
+            func_A = inv.params[i]
+            func_B = inv.params[i + 1]
+
+            assert isinstance(func_A, APIParam) and isinstance(
+                func_B, APIParam
+            ), "Invariant parameters should be APIParam."
+
+            for (process_id, thread_id), events_list in listed_events.items():
+                funcA = func_A.api_full_name
+                funcB = func_B.api_full_name
+
+                if not check_same_level(funcA, funcB, process_id, thread_id):
+                    continue
+
+                # check
+                flag_A = None
+                pre_record_A = None
+                for event in events_list:
+
+                    if event["type"] != "function_call (pre)":
+                        continue
+
+                    if funcA == event["function"]:
+                        if flag_A is None:
+                            flag_A = event["time"]
+                            pre_record_A = [event]
+                            continue
+
+                        neg = Example()
+                        neg.add_group(EXP_GROUP_NAME, pre_record_A)
+                        hypothesis.negative_examples.add_example(neg)
+                        pre_record_A = [event]
+                        continue
+
+                    if funcB == event["function"]:
+                        if flag_A is None:
+                            continue
+
+                        pos = Example()
+                        pos.add_group(EXP_GROUP_NAME, pre_record_A)
+                        hypothesis.positive_examples.add_example(pos)
+                        flag_A = None
+                        pre_record_A = []
+
+                if flag_A is not None:
+                    flag_A = None
+                    neg = Example()
+                    neg.add_group(EXP_GROUP_NAME, pre_record_A)
+                    hypothesis.negative_examples.add_example(neg)
+                    pre_record_A = []
+
+    @staticmethod
+    def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]:
+        """Infer Invariants for the FunctionLeadRelation."""
+
+        all_hypotheses = FunctionLeadRelation.generate_hypothesis(trace)
+
+        # for hypothesis in all_hypotheses:
+        #     FunctionLeadRelation.collect_examples(trace, hypothesis)
+
+        print("Start precondition inference...")
+        failed_hypothesis = []
+        for hypothesis in all_hypotheses.copy():
+            preconditions = find_precondition(hypothesis, [trace])
+            if preconditions is not None:
+                hypothesis.invariant.precondition = preconditions
+            else:
+                failed_hypothesis.append(FailedHypothesis(hypothesis))
+                all_hypotheses.remove(hypothesis)
+        print("End precondition inference")
+
+        if_merge = True
+
+        if not if_merge:
+            return (
+                [hypo.invariant for hypo in all_hypotheses],
+                failed_hypothesis,
+            )
+
+        # 6. 
Merge invariants
+        print("Start merging invariants...")
+        relation_pool: Dict[
+            GroupedPreconditions | None, List[Tuple[APIParam, APIParam]]
+        ] = {}
+        for hypothesis in all_hypotheses:
+            param0 = hypothesis.invariant.params[0]
+            param1 = hypothesis.invariant.params[1]
+
+            assert isinstance(param0, APIParam) and isinstance(param1, APIParam)
+
+            if hypothesis.invariant.precondition not in relation_pool:
+                relation_pool[hypothesis.invariant.precondition] = []
+            relation_pool[hypothesis.invariant.precondition].append((param0, param1))
+
+        merged_relations: Dict[
+            GroupedPreconditions | None, List[List[APIParam]]
+        ] = {}
+
+        for key, values in tqdm(relation_pool.items(), desc="Merging Invariants"):
+            merged_relations[key] = merge_relations(values)
+
+        merged_invariants = []
+
+        for key, merged_values in merged_relations.items():
+            for merged_value in merged_values:
+                new_invariant = Invariant(
+                    relation=FunctionLeadRelation,
+                    params=[param for param in merged_value],
+                    precondition=key,
+                    text_description="Merged FunctionLeadRelation in Ordered List",
+                )
+                merged_invariants.append(new_invariant)
+        print("End merging invariants")
+
+        return merged_invariants, failed_hypothesis
+
+    @staticmethod
+    def evaluate(value_group: list) -> bool:
+        """Given a group of values, should return a boolean value
+        indicating whether the relation holds or not.
+
+        args:
+            value_group: list
+                A list of values to evaluate the relation on. The length of the list
+                should be equal to the number of variables in the relation.
+        """
+        return True
+
+    @staticmethod
+    def static_check_all(
+        trace: Trace, inv: Invariant, check_relation_first: bool
+    ) -> CheckerResult:
+        """Given a trace and an invariant, should return a CheckerResult
+        indicating whether the invariant holds on the trace.
+
+        args:
+            trace: Trace
+                A trace to check the invariant on.
+            inv: Invariant
+                The invariant to check on the trace.
+        """
+
+        assert inv.precondition is not None, "Invariant should have a precondition."
+
+        function_times: Dict[Tuple[str, str], Dict[str, Dict[str, Any]]] = {}
+        function_id_map: Dict[Tuple[str, str], Dict[str, List[str]]] = {}
+        listed_events: Dict[Tuple[str, str], List[dict[str, Any]]] = {}
+
+        inv_triggered = False
+        # If the trace contains no function, return vacuous true result
+        func_names = trace.get_func_names()
+        if len(func_names) == 0:
+            print("No function calls found in the trace, skipping the checking")
+            return CheckerResult(
+                trace=None,
+                invariant=inv,
+                check_passed=True,
+                triggered=False,
+            )
+
+        function_pool = (
+            []
+        )  # Here function_pool only contains functions existing in the given invariant
+
+        invariant_length = len(inv.params)
+        for i in range(invariant_length):
+            func = inv.params[i]
+            assert isinstance(
+                func, APIParam
+            ), "Invariant parameters should be APIParam."
+            function_pool.append(func.api_full_name)
+
+        function_pool = list(set(function_pool).intersection(func_names))
+
+        # YUXUAN ASK: if function_pool is not strictly a subset of func_names, should we directly return false? 
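+        # NOTE: if the intersection is empty the invariant cannot be exercised by
+        # this trace; the early return below reports a vacuous, untriggered pass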
+ + if len(function_pool) == 0: + print( + "No relevant function calls found in the trace, skipping the checking" + ) + return CheckerResult( + trace=None, + invariant=inv, + check_passed=True, + triggered=False, + ) + + print("Start fetching data for checking...") + function_times, function_id_map, listed_events = get_func_data_per_PT( + trace, function_pool + ) + print("End fetching data for checking...") + + def check_same_level(funcA: str, funcB: str, process_id, thread_id): + if funcA == funcB: + return False + + if funcA not in function_id_map[(process_id, thread_id)]: + return False + + if funcB not in function_id_map[(process_id, thread_id)]: + return True + + for idA in function_id_map[(process_id, thread_id)][funcA]: + for idB in function_id_map[(process_id, thread_id)][funcB]: + preA = function_times[(process_id, thread_id)][idA]["start"] + postA = function_times[(process_id, thread_id)][idA]["end"] + preB = function_times[(process_id, thread_id)][idB]["start"] + postB = function_times[(process_id, thread_id)][idB]["end"] + if preB >= postA: + break + if postB <= preA: + continue + return False + return True + + print("Starting checking iteration...") + for i in range(invariant_length - 1): + func_A = inv.params[i] + func_B = inv.params[i + 1] + + assert isinstance(func_A, APIParam) and isinstance( + func_B, APIParam + ), "Invariant parameters should be string." + + for (process_id, thread_id), events_list in listed_events.items(): + funcA = func_A.api_full_name + funcB = func_B.api_full_name + + if not check_same_level(funcA, funcB, process_id, thread_id): + continue + + # check + flag_A = None + pre_recordA = None + for event in events_list: + + if event["type"] != "function_call (pre)": + continue + + if funcA == event["function"]: + if flag_A is None: + flag_A = event["time"] + pre_recordA = event + continue + if inv.precondition.verify([events_list], EXP_GROUP_NAME): + inv_triggered = True + print( + "The relation " + + funcA + + " leads " + + funcB + + " is violated!\n" + ) + return CheckerResult( + trace=[pre_recordA, event], + invariant=inv, + check_passed=False, + triggered=True, + ) + if funcB == event["function"]: + flag_A = None + pre_recordA = None + + return CheckerResult( + trace=None, + invariant=inv, + check_passed=True, + triggered=inv_triggered, + ) + + @staticmethod + def get_precondition_infer_keys_to_skip(hypothesis: Hypothesis) -> list[str]: + return [] diff --git a/mldaikon/invariant/precondition.py b/mldaikon/invariant/precondition.py index e56f579e..bbd8f7f4 100644 --- a/mldaikon/invariant/precondition.py +++ b/mldaikon/invariant/precondition.py @@ -263,11 +263,13 @@ def _merge_clauses( def find_precondition( hypothesis: Hypothesis, - trace: Trace, - keys_to_skip: list[str] = [], + traces: list[Trace], ) -> GroupedPreconditions | None: """When None is returned, it means that we cannot find a precondition that is safe to use for the hypothesis.""" + keys_to_skip = hypothesis.invariant.relation.get_precondition_infer_keys_to_skip( + hypothesis + ) # postive examples and negative examples should have the same group names group_names = hypothesis.positive_examples.group_names # assert group_names == hypothesis.negative_examples.group_names @@ -294,7 +296,7 @@ def find_precondition( grouped_preconditions[group_name] = Preconditions( find_precondition_from_single_group( - positive_examples, negative_examples, trace, keys_to_skip + positive_examples, negative_examples, traces, keys_to_skip ) ) @@ -307,7 +309,7 @@ def find_precondition( # introducing this for 
inferring invariants related to context managers (e.g. input/output dtype should be the same when no autocast context is used)
             grouped_preconditions[group_name] = Preconditions(
                 find_precondition_from_single_group(
-                    negative_examples, positive_examples, trace, keys_to_skip
+                    negative_examples, positive_examples, traces, keys_to_skip
                 ),
                 inverted=True,
             )
@@ -397,7 +399,7 @@ def _group_examples_by_stage(
 def find_precondition_from_single_group(
     positive_examples: list[list[dict]],
     negative_examples: list[list[dict]],
-    trace: Trace,
+    traces: list[Trace],
     keys_to_skip: list[str] = [],
     _pruned_clauses: set[PreconditionClause] = set(),
     _skip_pruning: bool = False,
@@ -467,7 +469,7 @@ def find_precondition_from_single_group(
                     if stage in grouped_negative_examples
                     else []
                 ),
-                trace,
+                traces,
                 keys_to_skip,
                 _pruned_clauses,
                 _skip_pruning,
@@ -536,9 +538,17 @@ def find_precondition_from_single_group(
             thread_id = example[0]["thread_id"]
 
             if _current_depth == 0:
-                meta_vars = trace.get_meta_vars(
-                    earliest_time, process_id=process_id, thread_id=thread_id
-                )  # HACK: get the context at the earliest time, ideally we should find the context that coverred the entire example duration
+                for trace in traces:
+                    meta_vars = trace.get_meta_vars(  # FIXME: add meta_vars to the examples prior to finding preconditions
+                        earliest_time, process_id=process_id, thread_id=thread_id
+                    )  # HACK: get the context at the earliest time, ideally we should find the context that covered the entire example duration
+                    if meta_vars is not None:
+                        break
+                if meta_vars is None:
+                    logger.critical(
+                        "Meta_vars not found for this positive example; this should never happen, but we skip the meta_vars update for it so that inference can continue"
+                    )
+                    meta_vars = {}
 
                 # update every trace with the meta_vars
                 for key in meta_vars:
@@ -585,9 +595,17 @@ def find_precondition_from_single_group(
             earliest_time = neg_example[0]["time"]
             process_id = neg_example[0]["process_id"]
             thread_id = neg_example[0]["thread_id"]
-            meta_vars = trace.get_meta_vars(
-                earliest_time, process_id=process_id, thread_id=thread_id
-            )  # HACK: get the context at the earliest time, ideally we should find the context that coverred the entire example duration
+            for trace in traces:
+                meta_vars = trace.get_meta_vars(
+                    earliest_time, process_id=process_id, thread_id=thread_id
+                )  # HACK: get the context at the earliest time, ideally we should find the context that covered the entire example duration
+                if meta_vars is not None:
+                    break
+            if meta_vars is None:
+                logger.critical(
+                    "Meta_vars not found for this negative example; this should never happen, but we skip the meta_vars update for it so that inference can continue"
+                )
+                meta_vars = {}
 
             # update every trace with the meta_vars
             for key in meta_vars:
@@ -731,7 +749,7 @@ def find_precondition_from_single_group(
                 sub_preconditions = find_precondition_from_single_group(
                     sub_positive_examples,
                     passing_neg_exps,
-                    trace=trace,
+                    traces=traces,
                     keys_to_skip=keys_to_skip,
                     _pruned_clauses=_pruned_clauses,
                     _skip_pruning=True,
diff --git a/mldaikon/invariant/relation_pool.py b/mldaikon/invariant/relation_pool.py
index 62319ab6..8dbe8d63 100644
--- a/mldaikon/invariant/relation_pool.py
+++ b/mldaikon/invariant/relation_pool.py
@@ -17,7 +17,7 @@
     ConsistentOutputRelation,
     ConsistentInputOutputRelation,
     VarPeriodicChangeRelation,
-    FunctionCoverRelation,
-    FunctionLeadRelation,
+    # FunctionCoverRelation,
+    # FunctionLeadRelation,
     ThresholdRelation,
 ]
diff --git 
a/mldaikon/invariant/symbolic_value.py b/mldaikon/invariant/symbolic_value.py index 27492ac5..30534e46 100644 --- a/mldaikon/invariant/symbolic_value.py +++ b/mldaikon/invariant/symbolic_value.py @@ -78,6 +78,8 @@ def generalize_values(values: list[type]) -> MD_NONE | type | str: """Given a list of values, should return a generalized value.""" assert values, "Values should not be empty" + values = [tuple(v) if isinstance(v, list) else v for v in values] # type: ignore + if len(set(values)) == 1: # no need to generalize return values[0] diff --git a/mldaikon/invariant/var_periodic_change_relation.py b/mldaikon/invariant/var_periodic_change_relation.py index dbf5e5bc..58ea37e9 100644 --- a/mldaikon/invariant/var_periodic_change_relation.py +++ b/mldaikon/invariant/var_periodic_change_relation.py @@ -42,7 +42,7 @@ def calculate_hypo_value(value) -> str: class VarPeriodicChangeRelation(Relation): @staticmethod - def infer(trace: Trace) -> tuple[list[Invariant], list[FailedHypothesis]]: + def generate_hypothesis(trace): """Infer Invariants for the VariableChangeRelation.""" logger = logging.getLogger(__name__) ## 1. Pre-scanning: Collecting variable instances and their values from the trace @@ -50,7 +50,7 @@ def infer(trace: Trace) -> tuple[list[Invariant], list[FailedHypothesis]]: var_insts = trace.get_var_insts() if len(var_insts) == 0: logger.warning("No variables found in the trace.") - return [], [] + return [] ## 2.Counting: count the number of each value of every variable attribute # TODO: record the intervals between occurrencess # TODO: improve time and memory efficiency @@ -142,11 +142,63 @@ def infer(trace: Trace) -> tuple[list[Invariant], list[FailedHypothesis]]: ) ) + return list(all_hypothesis.values()) + + @staticmethod + def collect_examples(trace, hypothesis): + var_group_name = "var" + var_insts = trace.get_var_insts() + + inv = hypothesis.invariant + + # a simple check to see if the variable changing to a specific value occur for the matched variables + assert ( + len(inv.params) == 1 + ), "The number of parameters should be 1 for VarPeriodicChangeRelation." + var_param = inv.params[0] + assert isinstance( + var_param, (VarNameParam, VarTypeParam) + ), "The parameter should be VarTypeParam for VarPeriodicChangeRelation." + + # for every variable instance, check whether it is interesting to the invariant + to_check_var_ids = [] + for var_id in var_insts: + if var_param.check_var_id_match(var_id): + to_check_var_ids.append(var_id) + + # for every to check variable instance, check the precondition match + # HACK: we use all the traces of the variable instances to check the precondition + for var_id in to_check_var_ids: + example_trace_records = [ + attr_inst.traces[-1] + for attr_inst in var_insts[var_id][var_param.attr_name] + ] + example = Example({var_group_name: example_trace_records}) + + # check if the value is periodically set to the hypo value + hypo_value = var_param.const_value + occur_count = 0 + for attr_inst in var_insts[var_id][var_param.attr_name]: + if calculate_hypo_value(attr_inst.value) == hypo_value: + occur_count += 1 + + if not count_num_justification(occur_count): + # add negative example + hypothesis.negative_examples.add_example(example) + else: + # add positive example + hypothesis.positive_examples.add_example(example) + + @staticmethod + def infer(trace: Trace) -> tuple[list[Invariant], list[FailedHypothesis]]: + logger = logging.getLogger(__name__) + + all_hypothesis = VarPeriodicChangeRelation.generate_hypothesis(trace) # 4. 
diff --git a/mldaikon/invariant/var_periodic_change_relation.py b/mldaikon/invariant/var_periodic_change_relation.py
index dbf5e5bc..58ea37e9 100644
--- a/mldaikon/invariant/var_periodic_change_relation.py
+++ b/mldaikon/invariant/var_periodic_change_relation.py
@@ -42,7 +42,7 @@ def calculate_hypo_value(value) -> str:

 class VarPeriodicChangeRelation(Relation):
     @staticmethod
-    def infer(trace: Trace) -> tuple[list[Invariant], list[FailedHypothesis]]:
+    def generate_hypothesis(trace):
         """Infer Invariants for the VariableChangeRelation."""
         logger = logging.getLogger(__name__)
         ## 1. Pre-scanning: collecting variable instances and their values from the trace
@@ -50,7 +50,7 @@ def infer(trace: Trace) -> tuple[list[Invariant], list[FailedHypothesis]]:
         var_insts = trace.get_var_insts()
         if len(var_insts) == 0:
             logger.warning("No variables found in the trace.")
-            return [], []
+            return []
         ## 2. Counting: count the number of each value of every variable attribute
         # TODO: record the intervals between occurrences
         # TODO: improve time and memory efficiency
@@ -142,11 +142,63 @@ def infer(trace: Trace) -> tuple[list[Invariant], list[FailedHypothesis]]:
                 )
             )

+        return list(all_hypothesis.values())
+
+    @staticmethod
+    def collect_examples(trace, hypothesis):
+        var_group_name = "var"
+        var_insts = trace.get_var_insts()
+
+        inv = hypothesis.invariant
+
+        # a simple check to see whether the variable changing to a specific value occurs for the matched variables
+        assert (
+            len(inv.params) == 1
+        ), "The number of parameters should be 1 for VarPeriodicChangeRelation."
+        var_param = inv.params[0]
+        assert isinstance(
+            var_param, (VarNameParam, VarTypeParam)
+        ), "The parameter should be a VarNameParam or VarTypeParam for VarPeriodicChangeRelation."
+
+        # for every variable instance, check whether it is relevant to the invariant
+        to_check_var_ids = []
+        for var_id in var_insts:
+            if var_param.check_var_id_match(var_id):
+                to_check_var_ids.append(var_id)
+
+        # for every variable instance to check, verify the precondition match
+        # HACK: we use all the traces of the variable instances to check the precondition
+        for var_id in to_check_var_ids:
+            example_trace_records = [
+                attr_inst.traces[-1]
+                for attr_inst in var_insts[var_id][var_param.attr_name]
+            ]
+            example = Example({var_group_name: example_trace_records})
+
+            # check if the value is periodically set to the hypo value
+            hypo_value = var_param.const_value
+            occur_count = 0
+            for attr_inst in var_insts[var_id][var_param.attr_name]:
+                if calculate_hypo_value(attr_inst.value) == hypo_value:
+                    occur_count += 1
+
+            if not count_num_justification(occur_count):
+                # add negative example
+                hypothesis.negative_examples.add_example(example)
+            else:
+                # add positive example
+                hypothesis.positive_examples.add_example(example)
+
+    @staticmethod
+    def infer(trace: Trace) -> tuple[list[Invariant], list[FailedHypothesis]]:
+        logger = logging.getLogger(__name__)
+
+        all_hypothesis = VarPeriodicChangeRelation.generate_hypothesis(trace)

         # 4. find preconditions
         valid_invariants = []
         failed_hypothesis = []
-        for hypothesis in all_hypothesis.values():
-            preconditions = find_precondition(hypothesis, trace)
+        for hypothesis in all_hypothesis:  # generate_hypothesis returns a list, not a dict
+            preconditions = find_precondition(hypothesis, [trace])
             if preconditions:
                 hypothesis.invariant.precondition = preconditions
                 valid_invariants.append(hypothesis.invariant)
@@ -239,3 +291,7 @@ def static_check_all(
         return CheckerResult(example_trace_records, inv, False, True)

     return CheckerResult(None, inv, True, triggered)
+
+    @staticmethod
+    def get_precondition_infer_keys_to_skip(hypothesis: Hypothesis) -> list[str]:
+        return []
diff --git a/mldaikon/trace/trace.py b/mldaikon/trace/trace.py
index 6be3aa74..4b9e41c5 100644
--- a/mldaikon/trace/trace.py
+++ b/mldaikon/trace/trace.py
@@ -156,7 +156,7 @@ def query_active_context_managers(
             "This function should be implemented in the child class."
        )

-    def get_meta_vars(self, time, process_id, thread_id) -> dict:
+    def get_meta_vars(self, time, process_id, thread_id) -> dict | None:
         """Get the meta variables at a specific time, process and thread."""
         raise NotImplementedError(
             "This function should be implemented in the child class."
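Because get_meta_vars may now return None, every call site has to branch on the result rather than indexing into it. A hedged sketch of the new contract (DummyTrace and its storage layout are illustrative stand-ins, not the repo's real subclasses):

from typing import Any

class DummyTrace:
    # Toy stand-in for a Trace subclass; real implementations query event tables.
    def __init__(self, meta: dict[tuple[int, int], dict[str, Any]]):
        self._meta = meta  # (process_id, thread_id) -> meta_vars

    def get_meta_vars(
        self, time: float, process_id: int, thread_id: int
    ) -> dict[str, Any] | None:
        # Mirror the new contract: None when the ids are unknown to this trace.
        return self._meta.get((process_id, thread_id))

trace = DummyTrace({(1, 1): {"stage": "training"}})
meta = trace.get_meta_vars(0.0, process_id=2, thread_id=7)
if meta is None:  # callers must branch instead of assuming a dict
    meta = {}
print(meta.get("stage"))  # None for the unknown (2, 7) pair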
diff --git a/mldaikon/trace/trace_pandas.py b/mldaikon/trace/trace_pandas.py
index 7e5e0756..037b8298 100644
--- a/mldaikon/trace/trace_pandas.py
+++ b/mldaikon/trace/trace_pandas.py
@@ -43,6 +43,10 @@ def get_attr_name(col_name: str) -> str:
     return col_name[len(config.VAR_ATTR_PREFIX) :]

+def safe_isnan(value: Any) -> bool:
+    return isinstance(value, float) and pd.isna(value)
+
+
 class TracePandas(Trace):
     def __init__(self, events, truncate_incomplete_func_calls=True):
         self.events = events
@@ -400,7 +404,7 @@ def query_active_context_managers(

     def get_meta_vars(
         self, time: float, process_id: int, thread_id: int
-    ) -> dict[str, Any]:
+    ) -> dict[str, Any] | None:
         """Get the meta_vars at a given time.

         Return value:
@@ -409,6 +413,14 @@ def get_meta_vars(

         NOTE: CHANGING THE RETURN FORMAT WILL INTERFERE WITH THE PRECONDITION INFERENCE
         """
+
+        # if the process or thread id does not exist in the trace, return None
+        if (
+            process_id not in self.get_process_ids()
+            or thread_id not in self.get_thread_ids()
+        ):
+            return None
+
         meta_vars = {}
         active_context_managers = self.query_active_context_managers(
             time, process_id, thread_id
@@ -652,9 +664,7 @@ def get_causally_related_vars(self, func_call_id) -> set[VarInstId]:
             assert (
                 related_func_call_pre_event["thread_id"] == thread_id
             ), "Related function call is on a different thread."
-            if isinstance(
-                related_func_call_pre_event["proxy_obj_names"], float
-            ) and pd.isna(related_func_call_pre_event["proxy_obj_names"]):
+            if safe_isnan(related_func_call_pre_event["proxy_obj_names"]):
                 continue
             for var_name, var_type in related_func_call_pre_event["proxy_obj_names"]:
                 if var_name == "" and var_type == "":
@@ -779,7 +789,12 @@ def get_var_insts(self) -> dict[VarInstId, dict[str, list[AttrState]]]:
             attr_values = {}
             for _, state_change in state_changes.iterrows():
                 for col in state_change.index:
-                    if pd.isna(state_change[col]):
+
+                    if "_ML_DAIKON" in col:
+                        # these ID columns are reserved for DistinctArgumentRelation
+                        continue
+
+                    if safe_isnan(state_change[col]):
                         # skip NaN values as NaNs indicate that the attribute is not present in the state
                         continue
@@ -803,7 +818,12 @@ def get_var_insts(self) -> dict[VarInstId, dict[str, list[AttrState]]]:
                         )
                     ]
                 else:
-                    if attr_values[attr_name][-1].value != state_change[col]:
+                    if attr_values[attr_name][-1].value != state_change[
+                        col
+                    ] and not (
+                        safe_isnan(attr_values[attr_name][-1].value)
+                        and safe_isnan(state_change[col])
+                    ):
                         attr_values[attr_name][-1].liveness.end_time = (
                             state_change["time"]
                         )
@@ -869,14 +889,15 @@ def get_var_changes(self) -> list[VarChangeEvent]:
                     new_state = var_insts[var_id][attr][i]

                     # for debugging
-                    import pandas as pd
-
-                    assert not pd.isna(
-                        old_state.value
-                    ), f"Old state is NaN for {var_id} {attr}"
-                    assert not pd.isna(
-                        new_state.value
-                    ), f"New state is NaN for {var_id} {attr}"
+                    # import pandas as pd
+
+                    # if not isinstance(old_state.value, Iterable) and not isinstance(new_state.value, Iterable) and not "_ML_DAIKON" in attr:
+                    #     assert not pd.isna(  # AssertionError: Old state is NaN for VarInstId(process_id=374887, var_name='gc1.kernel', var_type='torch.nn.Parameter') _ML_DAIKON_grad_ID (why are those IDs propagated to var states?)
+                    #         old_state.value
+                    #     ), f"Old state is NaN for {var_id} {attr}"
+                    #     assert not pd.isna(
+                    #         new_state.value
+                    #     ), f"New state is NaN for {var_id} {attr}"
                     assert (
                         change_time is not None
                     ), f"Start time not found for {var_id} {attr} {var_insts[var_id][attr][i].value}"
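The safe_isnan helper guards two pandas pitfalls that the replaced call sites hit: pd.isna on a list-valued cell returns an elementwise array rather than a bool, and NaN != NaN makes a naive equality check report a "change" between two missing values. A short illustrative repro (assumes pandas is installed, as in the repo):

import pandas as pd

def safe_isnan(value) -> bool:
    # Only float cells can be NaN markers; the isinstance guard keeps
    # pd.isna from returning an elementwise (ambiguous) result on lists.
    return isinstance(value, float) and pd.isna(value)

cell = ["w", "Tensor"]            # a proxy_obj_names-style list cell
print(pd.isna(cell))              # [False False] -- an array, not a usable bool
print(safe_isnan(cell))           # False
print(safe_isnan(float("nan")))   # True

# NaN != NaN, so naive comparison flags a change between two missing values;
# the diff suppresses exactly this case in get_var_insts.
print(float("nan") != float("nan"))  # True (spurious change)
print(not (safe_isnan(float("nan")) and safe_isnan(float("nan"))))  # False -> suppressed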