diff --git a/mldaikon/collect_trace.py b/mldaikon/collect_trace.py index e9a92fd9..715c414d 100644 --- a/mldaikon/collect_trace.py +++ b/mldaikon/collect_trace.py @@ -122,6 +122,13 @@ def get_default_output_folder(args: argparse.Namespace) -> str: action="store_true", help="Only instrument and dump the modified file", ) + parser.add_argument( + "--instr-descriptors", + action="store_true", + help="""Instrument functions that can only be accessed through descriptors, + Set this to true if you want to instrument built-in types like torch.Tensor, + at the cost of larger (5x) instrumentation overhead and more interference with the program""", + ) parser.add_argument( "--profiling", type=str, # @ziming-zh: why is this a string? @@ -321,6 +328,7 @@ def get_default_output_folder(args: argparse.Namespace) -> str: API_dump_stack_trace=args.API_dump_stack_trace, cond_dump=args.cond_dump, output_dir=output_dir, + instr_descriptors=args.instr_descriptors, ) # call into the program runner diff --git a/mldaikon/config/config.py b/mldaikon/config/config.py index 874d4750..cd6e92cb 100644 --- a/mldaikon/config/config.py +++ b/mldaikon/config/config.py @@ -27,7 +27,6 @@ # "torch._VariableFunctionsClass", # "torch.get_default_dtype", ] - ANALYSIS_SKIP_FUNC_NAMES = [ "cuda.is_available", "torch.get_default_dtype", @@ -123,3 +122,4 @@ ] ENABLE_COND_DUMP = False +INSTR_DESCRIPTORS = False diff --git a/mldaikon/instrumentor/dumper.py b/mldaikon/instrumentor/dumper.py index 0a3149c3..ff019b76 100644 --- a/mldaikon/instrumentor/dumper.py +++ b/mldaikon/instrumentor/dumper.py @@ -223,15 +223,17 @@ def convert_var_to_dict(var, include_tensor_data=True) -> dict: if type(attr) in primitive_types: result[attr_name] = attr - elif include_tensor_data and isinstance(attr, torch.Tensor): - result[attr_name] = dump_tensor(attr) + elif isinstance(attr, torch.Tensor): + result[f"_ML_DAIKON_{attr_name}_ID"] = id(attr) + if include_tensor_data: + result[attr_name] = dump_tensor(attr) - elif include_tensor_data and isinstance(attr, torch.nn.parameter.Parameter): - result[attr_name] = attr.__class__.__name__ + "(Parameter)" - result[attr_name] = dump_tensor(attr.data) + elif isinstance(attr, torch.nn.parameter.Parameter): + result[f"_ML_DAIKON_{attr_name}_ID"] = id(attr) + if include_tensor_data: + result[attr_name] = dump_tensor(attr.data) elif include_tensor_data and isinstance(attr, torch.nn.Module): - result[attr_name] = attr.__class__.__name__ + "(nn.Module)" # dump out all tensors inside the nn.Module for name, param in attr.named_parameters(): result[attr_name] += f"\n{name}: {dump_tensor(param)}" # type: ignore diff --git a/mldaikon/instrumentor/source_file.py b/mldaikon/instrumentor/source_file.py index 2bc96601..0d05c7a0 100644 --- a/mldaikon/instrumentor/source_file.py +++ b/mldaikon/instrumentor/source_file.py @@ -329,6 +329,7 @@ def instrument_file( API_dump_stack_trace: bool, cond_dump: bool, output_dir: str, + instr_descriptors: bool, ) -> str: """ Instruments the given file and returns the instrumented source code. @@ -358,6 +359,7 @@ def instrument_file( general_config_update = f""" import mldaikon.config.config as general_config general_config.ENABLE_COND_DUMP = {cond_dump} +general_config.INSTR_DESCRIPTORS = {instr_descriptors} """ if models_to_track: diff --git a/mldaikon/instrumentor/tracer.py b/mldaikon/instrumentor/tracer.py index 9fcab3d7..caccacf7 100644 --- a/mldaikon/instrumentor/tracer.py +++ b/mldaikon/instrumentor/tracer.py @@ -14,6 +14,7 @@ import torch import torch.utils +import mldaikon.config.config as config # needed to allow for change of values after import from mldaikon.config.config import ( INSTR_MODULES_TO_SKIP, TRAIN_STEP_NAMES, @@ -756,6 +757,18 @@ def _should_skip_module_or_cls(self, pymodule: object) -> str | None: def _should_skip_instr_attr(self, attr_name: str, pymodule: object) -> str | None: # 1. skip attrs with no objects (e.g. __abstractmethods__ and C extension functions) attr = pymodule.__dict__.get(attr_name, None) + if attr is None: + # try getting it in case it is a descriptor (almost certainly will be) + try: + attr = getattr(pymodule, attr_name) + if not ( + config.INSTR_DESCRIPTORS and "method_descriptor" in str(type(attr)) + ): + # print("TRIGGERED", attr_name) + attr = None + except Exception: + pass + if attr is None: return "Skipping attribute as it is None" @@ -778,17 +791,22 @@ def _should_skip_instr_attr(self, attr_name: str, pymodule: object) -> str | Non ): return "Skipping magic functions" + attr_full_name = typename(attr) # 4. Skip if the attribute is in INSTR_MODULES_TO_SKIP | MANUAL CONFIG - if typename(attr) in INSTR_MODULES_TO_SKIP: + if attr_full_name in INSTR_MODULES_TO_SKIP: return "Skipping attribute as it is one of INSTR_MODULES_TO_SKIP" # 5. Skip if the attribute is in modules_to_skip_prefix | MANUAL CONFIG for modules_to_skip_prefix in INSTR_MODULES_TO_SKIP: - if typename(attr).startswith(modules_to_skip_prefix): + if attr_full_name.startswith(modules_to_skip_prefix): return "Skipping attribute as it is in INSTR_MODULES_TO_SKIP" # 6. Skip if the attribute does not belong to the target root module - if not typename(attr).startswith(self.root_module): + if not attr_full_name.startswith(self.root_module) and not ( + config.INSTR_DESCRIPTORS + and ("method_descriptor" in attr_full_name or "Tensor" in attr_full_name) + ): + # builtin methods in torch.Tensor's qualname does not start with torch for some reason return "Skipping attribute as it does not belong to the root module" return None @@ -858,10 +876,7 @@ def _instrument_module( if attr is None: # try access this attribute to handle lazy loading try: - _ = getattr(pymodule, attr_name) - attr = pymodule.__dict__.get( - attr_name - ) # this attr should be loaded now + attr = getattr(pymodule, attr_name) except Exception as e: get_instrumentation_logger_for_process().debug( f"Depth: {depth}, lazy loading failed for attribute: {attr_name}, Module: {target_name}: {e}" @@ -882,7 +897,10 @@ def _instrument_module( if isinstance( attr, (types.FunctionType, types.BuiltinFunctionType, _instancemethod_t) - ): + ) or ( + config.INSTR_DESCRIPTORS and "method_descriptor" in str(type(attr)) + ): # instrumented with potential accuracy issues as descriptor-controlled method access might change what to return based on given information, but is needed to get tensor method invocations + assert callable(attr), f"{attr} is not callable" assert not ( recurse_into_sub_module and is_API_instrumented(attr) ), f"{attr} is already instrumented" diff --git a/mldaikon/invariant/DistinctArgumentRelation.py b/mldaikon/invariant/DistinctArgumentRelation.py index a9a12f92..26ee667e 100644 --- a/mldaikon/invariant/DistinctArgumentRelation.py +++ b/mldaikon/invariant/DistinctArgumentRelation.py @@ -89,6 +89,26 @@ def get_event_data_per_function_per_step(trace: Trace, function_pool: Set[Any]): return function_pool, listed_arguments +def get_event_list(trace: Trace, function_pool: Iterable[str]): + listed_events: List[dict[str, Any]]= [] + # for all func_ids, get their corresponding events + for func_name in function_pool: + func_call_ids = trace.get_func_call_ids(func_name) + for func_call_id in func_call_ids: + event = trace.query_func_call_event(func_call_id) + listed_events.extend( + # [event.pre_record, event.post_record] + [event.pre_record] + ) + + # sort the listed_events + # for (process_id, thread_id), events_list in listed_events.items(): + # listed_events[(process_id, thread_id)] = sorted( + # events_list, key=lambda x: x["time"] + # ) + + return listed_events + def compare_argument(value1, value2, IOU_criteria = True): if type(value1) != type(value2): return False @@ -114,7 +134,7 @@ def compare_argument(value1, value2, IOU_criteria = True): return abs(value1 - value2) < 1e-8 return value1 == value2 -def is_arguments_list_distinct(args1: list, args2: list): +def is_arguments_list_same(args1: list, args2: list): if len(args1) != len(args2): return False for index in range(len(args1)): @@ -206,7 +226,7 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]: for PT_pair1, PT_pair2 in combinations(records.keys(), 2): for event1 in records[PT_pair1]: for event2 in records[PT_pair2]: - if(not is_arguments_list_distinct(event1["args"], event2["args"])): + if(not is_arguments_list_same(event1["args"], event2["args"])): flag = True pos = Example() pos.add_group(EXP_GROUP_NAME, [event1, event2]) @@ -218,7 +238,7 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]: for PT_pair in records.keys(): for event1, event2 in combinations(records[PT_pair], 2): - if(not is_arguments_list_distinct(event1["args"], event2["args"])): + if(not is_arguments_list_same(event1["args"], event2["args"])): flag = True pos = Example() pos.add_group(EXP_GROUP_NAME, [event1, event2]) @@ -295,6 +315,34 @@ def static_check_all( """ assert inv.precondition is not None, "Invariant should have a precondition." + # 1. Pre-process all the events + print("Start preprocessing....") + listed_arguments: Dict[str, Dict[int, Dict[Tuple[str, str], List[dict[str, Any]]]]] = {} + function_pool: Set[Any] = set() + func= inv.params[0] + + assert isinstance( + func, APIParam + ), "Invariant parameters should be APIParam." + + func_name = func.api_full_name + function_pool.add(func_name) + + function_pool, listed_arguments = get_event_data_per_function_per_step( + trace, function_pool + ) + + if not function_pool: + return CheckerResult( + trace=None, + invariant=inv, + check_passed=True, + triggered=False, + ) + + events_list = get_event_list(trace, function_pool) + print("End preprocessing") + if not inv.precondition.verify([events_list], EXP_GROUP_NAME): return CheckerResult( trace=None, @@ -303,7 +351,27 @@ def static_check_all( triggered=False, ) - # TODO: checking process + for step, records in listed_arguments[func_name].items(): + for PT_pair1, PT_pair2 in combinations(records.keys(), 2): + for event1 in records[PT_pair1]: + for event2 in records[PT_pair2]: + if(is_arguments_list_same(event1["args"], event2["args"])): + return CheckerResult( + trace=[event1, event2], + invariant=inv, + check_passed=True, + triggered=True, + ) + + for PT_pair in records.keys(): + for event1, event2 in combinations(records[PT_pair], 2): + if(is_arguments_list_same(event1["args"], event2["args"])): + return CheckerResult( + trace=[event1, event2], + invariant=inv, + check_passed=True, + triggered=True, + ) return CheckerResult( trace=None,