diff --git a/eval_scripts/perf_benchmark/overhead-micro/workload.py b/eval_scripts/perf_benchmark/overhead-micro/workload.py
index d1171437..34ef22a8 100644
--- a/eval_scripts/perf_benchmark/overhead-micro/workload.py
+++ b/eval_scripts/perf_benchmark/overhead-micro/workload.py
@@ -9,7 +9,7 @@
     torch,
     scan_proxy_in_args=True,
     use_full_instr=False,
-    funcs_of_inv_interest=[
+    funcs_to_instr=[
         "torch.cuda.is_available",
         "torch._VariableFunctionsClass.matmul",
         "torch._VariableFunctionsClass.bmm",
diff --git a/mldaikon/collect_trace.py b/mldaikon/collect_trace.py
index a0988e3e..ebcf1e9b 100644
--- a/mldaikon/collect_trace.py
+++ b/mldaikon/collect_trace.py
@@ -1,5 +1,6 @@
 import argparse
 import datetime
+import json
 import logging
 import os
 
@@ -9,7 +10,22 @@
 import mldaikon.instrumentor as instrumentor
 import mldaikon.proxy_wrapper.proxy_config as proxy_config
 import mldaikon.runner as runner
-from mldaikon.invariant.base_cls import APIParam, Invariant, read_inv_file
+from mldaikon.invariant.base_cls import (
+    APIParam,
+    Arguments,
+    Invariant,
+    VarNameParam,
+    VarTypeParam,
+    read_inv_file,
+)
+from mldaikon.invariant.consistency_relation import ConsistencyRelation
+from mldaikon.invariant.consistency_transient_vars import (
+    ConsistentInputOutputRelation,
+    ConsistentOutputRelation,
+    ThresholdRelation,
+)
+from mldaikon.invariant.contain_relation import VAR_GROUP_NAME, APIContainRelation
+from mldaikon.invariant.DistinctArgumentRelation import DistinctArgumentRelation
 
 
 def get_list_of_funcs_from_invariants(invariants: list[Invariant]) -> list[str]:
@@ -24,6 +40,92 @@ def get_list_of_funcs_from_invariants(invariants: list[Invariant]) -> list[str]:
     return sorted(list(funcs))
 
 
+def get_per_func_instr_opts(invariants: list[Invariant]) -> dict[str, dict[str, bool]]:
+    """
+    Derive per-function instrumentation options from the invariants to be deployed.
+    """
+    func_instr_opts = {}
+    for inv in invariants:
+        for param in inv.params:
+            if isinstance(param, APIParam):
+                if param.api_full_name not in func_instr_opts:
+                    func_instr_opts[param.api_full_name] = {
+                        "scan_proxy_in_args": False,
+                        "dump_args": False,
+                        "dump_ret": False,
+                    }
+                # configure whether the arguments or return values need to be dumped
+                if inv.relation in (ConsistentInputOutputRelation, ThresholdRelation):
+                    func_instr_opts[param.api_full_name]["dump_args"] = True
+                    func_instr_opts[param.api_full_name]["dump_ret"] = True
+                elif inv.relation == ConsistentOutputRelation:
+                    func_instr_opts[param.api_full_name]["dump_ret"] = True
+                elif inv.relation == DistinctArgumentRelation:
+                    func_instr_opts[param.api_full_name]["dump_args"] = True
+
+                if inv.relation == APIContainRelation:
+                    # if this param carries arguments, dump them
+                    if isinstance(param.arguments, Arguments):
+                        func_instr_opts[param.api_full_name]["dump_args"] = True
+
+        if inv.relation == APIContainRelation:
+            assert isinstance(inv.params[0], APIParam)
+            assert inv.precondition is not None
+            if (
+                isinstance(inv.params[1], (VarNameParam, VarTypeParam))
+                and VAR_GROUP_NAME in inv.precondition.get_group_names()
+                and not inv.precondition.get_group(VAR_GROUP_NAME).is_unconditional()
+            ):
+                # if the APIContain invariant describes a variable and the precondition is not unconditional on the variable, scan the arguments of the function
+                func_instr_opts[inv.params[0].api_full_name][
+                    "scan_proxy_in_args"
+                ] = True
+            else:
+                func_instr_opts[inv.params[0].api_full_name][
+                    "scan_proxy_in_args"
+                ] = False
+
+    return func_instr_opts
+
+
+class InstrOpt:
+    def __init__(self, invariants: list[Invariant]):
+        self.funcs_instr_opts: dict[str, dict[str, bool]] = {}
+        self.model_tracker_style = None
+
+        # determine model_tracker_style:
+        # if any of the invariants to be deployed is an APIContain invariant with a param describing a variable, then use proxy
+        # if any of the invariants to be deployed is a Consistency invariant, then use sampler (if not already set to proxy)
+        for inv in invariants:
+            if inv.relation == APIContainRelation:
+                for param in inv.params:
+                    if isinstance(param, (VarNameParam, VarTypeParam)):
+                        self.model_tracker_style = "proxy"
+                        break
+            if inv.relation == ConsistencyRelation:
+                if self.model_tracker_style is None:
+                    self.model_tracker_style = "sampler"
+
+            if self.model_tracker_style == "proxy":
+                break
+
+        # determine funcs_instr_opts
+        self.funcs_instr_opts = get_per_func_instr_opts(invariants)
+
+    def to_json(self) -> str:
+        return json.dumps(
+            {
+                "funcs_instr_opts": self.funcs_instr_opts,
+                "model_tracker_style": self.model_tracker_style,
+            }
+        )
+
+    def from_json(self, instr_opt_json_str: str):
+        instr_opt_dict = json.loads(instr_opt_json_str)
+        self.funcs_instr_opts = instr_opt_dict["funcs_instr_opts"]
+        self.model_tracker_style = instr_opt_dict["model_tracker_style"]
+
+
 def dump_env(output_dir: str):
     with open(os.path.join(output_dir, "env_dump.txt"), "w") as f:
         f.write("Arguments:\n")
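For intuition, here is the shape of the mapping the new get_per_func_instr_opts helper produces. This is a minimal sketch assuming a hypothetical invariant set (one ConsistentOutputRelation on matmul, one DistinctArgumentRelation on bmm); the API names mirror the workload above, but the invariant mix is assumed for illustration:

# Hypothetical result of get_per_func_instr_opts(invariants):
expected_opts = {
    "torch._VariableFunctionsClass.matmul": {
        "scan_proxy_in_args": False,
        "dump_args": False,  # ConsistentOutputRelation only needs return values
        "dump_ret": True,
    },
    "torch._VariableFunctionsClass.bmm": {
        "scan_proxy_in_args": False,
        "dump_args": True,   # DistinctArgumentRelation only needs arguments
        "dump_ret": False,
    },
}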
@@ -188,13 +290,6 @@ def get_default_output_folder(args: argparse.Namespace) -> str:
     )
 
     ## variable tracker configs
-    # parser.add_argument(
-    #     "--proxy-module",
-    #     type=str,
-    #     default="",
-    #     help="The module to be traced by the proxy wrapper",
-    # )
-
     parser.add_argument(
         "--models-to-track",
         nargs="*",
@@ -277,12 +372,6 @@ def get_default_output_folder(args: argparse.Namespace) -> str:
         os.makedirs(output_dir)
     dump_env(output_dir)
 
-    # selective instrumentation if invariants are provided, only funcs_of_inv_interest will be instrumented with trace collection
-    funcs_of_inv_interest = None
-    if args.invariants is not None:
-        invariants = read_inv_file(args.invariants)
-        funcs_of_inv_interest = get_list_of_funcs_from_invariants(invariants)
-
     # set up adjusted proxy_config
     proxy_basic_config: dict[str, int | bool | str] = {}
     for configs in [
@@ -330,20 +419,44 @@ def get_default_output_folder(args: argparse.Namespace) -> str:
     if not args.models_to_track or args.model_tracker_style != "proxy":
         scan_proxy_in_args = False
 
-    source_code = instrumentor.instrument_file(
-        path=args.pyscript,
-        modules_to_instr=args.modules_to_instr,
-        scan_proxy_in_args=scan_proxy_in_args,
-        use_full_instr=args.use_full_instr,
-        funcs_of_inv_interest=funcs_of_inv_interest,
-        models_to_track=args.models_to_track,
-        model_tracker_style=args.model_tracker_style,
-        adjusted_proxy_config=adjusted_proxy_config,  # type: ignore
-        API_dump_stack_trace=args.API_dump_stack_trace,
-        cond_dump=args.cond_dump,
-        output_dir=output_dir,
-        instr_descriptors=args.instr_descriptors,
-    )
+    if args.invariants:
+        # selective instrumentation: derive per-function instrumentation options from the provided invariants and write them to config.INSTR_OPTS_FILE for the instrumentor to load
+        invariants = read_inv_file(args.invariants)
+        instr_opts = InstrOpt(invariants)
+        with open(os.path.join(output_dir, config.INSTR_OPTS_FILE), "w") as f:
+            f.write(instr_opts.to_json())
+        models_to_track = (
+            args.models_to_track if instr_opts.model_tracker_style else None
+        )
+        source_code = instrumentor.instrument_file(
+            path=args.pyscript,
+            modules_to_instr=args.modules_to_instr,
+            scan_proxy_in_args=scan_proxy_in_args,
+            use_full_instr=args.use_full_instr,
+            funcs_to_instr=None,
+            models_to_track=models_to_track,
+            model_tracker_style=instr_opts.model_tracker_style,
+            adjusted_proxy_config=adjusted_proxy_config,  # type: ignore
+            API_dump_stack_trace=args.API_dump_stack_trace,
+            cond_dump=args.cond_dump,
+            output_dir=output_dir,
+            instr_descriptors=args.instr_descriptors,
+        )
+    else:
+        source_code = instrumentor.instrument_file(
+            path=args.pyscript,
+            modules_to_instr=args.modules_to_instr,
+            scan_proxy_in_args=scan_proxy_in_args,
+            use_full_instr=args.use_full_instr,
+            funcs_to_instr=None,
+            models_to_track=args.models_to_track,
+            model_tracker_style=args.model_tracker_style,
+            adjusted_proxy_config=adjusted_proxy_config,  # type: ignore
+            API_dump_stack_trace=args.API_dump_stack_trace,
+            cond_dump=args.cond_dump,
+            output_dir=output_dir,
+            instr_descriptors=args.instr_descriptors,
+        )
 
     if args.copy_all_files:
         # copy all files in the same directory as the pyscript to the output directory
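The instr_opts.json written above is the handoff between trace collection and the instrumentor. Under the same hypothetical invariant mix as the earlier sketch, the file would hold something like the following (pretty-printed here; to_json emits a single line):

{
    "funcs_instr_opts": {
        "torch._VariableFunctionsClass.matmul": {
            "scan_proxy_in_args": false,
            "dump_args": false,
            "dump_ret": true
        }
    },
    "model_tracker_style": "sampler"
}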
diff --git a/mldaikon/config/config.py b/mldaikon/config/config.py
index ce172876..92beebf3 100644
--- a/mldaikon/config/config.py
+++ b/mldaikon/config/config.py
@@ -2,6 +2,7 @@
 # tracer + instrumentor configs
 TMP_FILE_PREFIX = "_ml_daikon_"
+INSTR_OPTS_FILE = "instr_opts.json"
 INSTR_MODULES_TO_INSTR = ["torch"]
 INSTR_MODULES_TO_SKIP = [
     "torch.fx",
@@ -151,7 +152,6 @@
 ENABLE_COND_DUMP = False
 INSTR_DESCRIPTORS = False
 
-
 ALL_STAGE_NAMES = {
     "init",
     "training",
diff --git a/mldaikon/e2e_runner.py b/mldaikon/e2e_runner.py
index f22b302a..47c6a6d0 100644
--- a/mldaikon/e2e_runner.py
+++ b/mldaikon/e2e_runner.py
@@ -161,7 +161,7 @@ def parse_input_env(input_env: str) -> dict[str, str]:
         "proxy_module": args.proxy_module,
         # "scan_proxy_in_args": False,
         # "allow_disable_dump": False,
-        # "funcs_of_inv_interest": None,
+        # "funcs_to_instr": None,
         "output_dir": output_dir,
         "API_LOG_DIR": api_log_dir,
         "profiling": str(args.profiling),
diff --git a/mldaikon/instrumentor/source_file.py b/mldaikon/instrumentor/source_file.py
index 0d05c7a0..dba3d6b7 100644
--- a/mldaikon/instrumentor/source_file.py
+++ b/mldaikon/instrumentor/source_file.py
@@ -17,7 +17,7 @@ def __init__(
         modules_to_instr: list[str],
         scan_proxy_in_args: bool,
         use_full_instr: bool,
-        funcs_of_inv_interest: list[str] | None,
+        funcs_to_instr: list[str] | None,
         API_dump_stack_trace: bool,
         cond_dump: bool,
     ):
@@ -29,13 +29,13 @@ def __init__(
         self.modules_to_instr = modules_to_instr
         self.scan_proxy_in_args = scan_proxy_in_args
         self.use_full_instr = use_full_instr
-        self.funcs_of_inv_interest = funcs_of_inv_interest
+        self.funcs_to_instr = funcs_to_instr
         self.API_dump_stack_trace = API_dump_stack_trace
         self.cond_dump = cond_dump
 
     def get_instrument_node(self, module_name: str):
         return ast.parse(
-            f"from mldaikon.instrumentor.tracer import Instrumentor; Instrumentor({module_name}, scan_proxy_in_args={self.scan_proxy_in_args}, use_full_instr={self.use_full_instr}, funcs_of_inv_interest={str(self.funcs_of_inv_interest)}, API_dump_stack_trace={self.API_dump_stack_trace}, cond_dump={self.cond_dump}).instrument()"
+            f"from mldaikon.instrumentor.tracer import Instrumentor; Instrumentor({module_name}, scan_proxy_in_args={self.scan_proxy_in_args}, use_full_instr={self.use_full_instr}, funcs_to_instr={str(self.funcs_to_instr)}, API_dump_stack_trace={self.API_dump_stack_trace}, cond_dump={self.cond_dump}).instrument()"
         ).body
 
     def visit_Import(self, node):
@@ -81,7 +81,7 @@ def instrument_library(
     modules_to_instr: list[str],
     scan_proxy_in_args: bool,
    use_full_instr: bool,
-    funcs_of_inv_interest: list[str] | None,
+    funcs_to_instr: list[str] | None,
    API_dump_stack_trace: bool,
    cond_dump: bool,
 ) -> str:
@@ -103,7 +103,7 @@ def instrument_library(
         modules_to_instr,
         scan_proxy_in_args,
         use_full_instr,
-        funcs_of_inv_interest,
+        funcs_to_instr,
         API_dump_stack_trace,
         cond_dump,
     )
@@ -322,9 +322,9 @@ def instrument_file(
     modules_to_instr: list[str],
     scan_proxy_in_args: bool,
     use_full_instr: bool,
-    funcs_of_inv_interest: list[str] | None,
+    funcs_to_instr: list[str] | None,
     models_to_track: list[str] | None,
-    model_tracker_style: str,
+    model_tracker_style: str | None,
     adjusted_proxy_config: list[dict[str, int | bool | str]],
     API_dump_stack_trace: bool,
     cond_dump: bool,
@@ -344,7 +344,7 @@ def instrument_file(
         modules_to_instr,
         scan_proxy_in_args,
         use_full_instr,
-        funcs_of_inv_interest,
+        funcs_to_instr,
         API_dump_stack_trace,
         cond_dump,
     )
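Concretely, after the rename the statement that get_instrument_node injects after each instrumented import renders as the following one-liner (shown for torch with illustrative flag values):

from mldaikon.instrumentor.tracer import Instrumentor; Instrumentor(torch, scan_proxy_in_args=False, use_full_instr=False, funcs_to_instr=None, API_dump_stack_trace=False, cond_dump=False).instrument()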
diff --git a/mldaikon/instrumentor/tracer.py b/mldaikon/instrumentor/tracer.py
index 49dd29b3..ff54f202 100644
--- a/mldaikon/instrumentor/tracer.py
+++ b/mldaikon/instrumentor/tracer.py
@@ -217,6 +217,8 @@ def global_wrapper(
     scan_proxy_in_args,
     dump_stack_trace,
     cond_dump,
+    dump_args,
+    dump_ret,
     *args,
     **kwargs,
 ):
@@ -305,10 +307,10 @@ def find_proxy_in_args(args):
             pre_record["proxy_obj_names"].append(
                 [proxy.__dict__["var_name"], type(proxy._obj).__name__]
             )
-
-    dict_args_kwargs = to_dict_args_kwargs(args, kwargs)
-    pre_record["args"] = dict_args_kwargs["args"]
-    pre_record["kwargs"] = dict_args_kwargs["kwargs"]
+    if dump_args:
+        dict_args_kwargs = to_dict_args_kwargs(args, kwargs)
+        pre_record["args"] = dict_args_kwargs["args"]
+        pre_record["kwargs"] = dict_args_kwargs["kwargs"]
     dump_trace_API(pre_record)
     if enable_C_level_observer and is_builtin:
         from mldaikon.proxy_wrapper.proxy_observer import (
@@ -356,8 +358,8 @@ def find_proxy_in_args(args):
             else None
         )
         raise e
-    pre_record.pop("args")
-    pre_record.pop("kwargs")
+    pre_record.pop("args", None)
+    pre_record.pop("kwargs", None)
     post_record = (
         pre_record.copy()
     )  # copy the pre_record (though we don't actually need to copy anything)
@@ -425,8 +427,8 @@ def find_proxy_in_args(args):
             print(response_starting_indices)
             print(response_lengths)
 
-
-    post_record["return_values"] = to_dict_return_value(result_to_dump)
+    if dump_ret:
+        post_record["return_values"] = to_dict_return_value(result_to_dump)
     dump_trace_API(post_record)
 
     EXIT_PERF_TIME = time.perf_counter()
@@ -592,6 +594,8 @@ def wrapper(
     dump_stack_trace,
     cond_dump,
     disable_dump=False,
+    dump_args=True,
+    dump_ret=True,
 ):
     is_builtin = is_c_level_function(original_function)
@@ -608,6 +612,8 @@ def wrapped(*args, **kwargs):
             scan_proxy_in_args,
             dump_stack_trace,
             cond_dump,
+            dump_args,
+            dump_ret,
             *args,
             **kwargs,
         )
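The new dump_args/dump_ret flags thread from wrapper() through wrapped() into global_wrapper(), where they gate what each trace record carries. A stripped-down sketch of that gating (toy code, not the tracer itself; a plain list stands in for dump_trace_API):

import functools

TRACE: list[dict] = []  # stand-in for dump_trace_API

def toy_wrapper(fn, dump_args=True, dump_ret=True):
    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        pre = {"func": fn.__name__, "phase": "pre"}
        if dump_args:  # arguments are serialized only when an invariant needs them
            pre["args"], pre["kwargs"] = args, kwargs
        TRACE.append(pre)
        result = fn(*args, **kwargs)
        post = {"func": fn.__name__, "phase": "post"}
        if dump_ret:  # likewise, return values only on request
            post["return_values"] = result
        TRACE.append(post)
        return result
    return wrapped

mul = toy_wrapper(lambda a, b: a * b, dump_args=False, dump_ret=True)
mul(3, 4)  # TRACE: a pre record without args, a post record with return_values == 12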
@@ -732,7 +738,7 @@ def __init__(
         ),
         scan_proxy_in_args: bool,
         use_full_instr: bool,
-        funcs_of_inv_interest: Optional[list[str]] = None,
+        funcs_to_instr: Optional[list[str]] = None,
         API_dump_stack_trace: bool = False,
         cond_dump: bool = False,
     ):
@@ -751,7 +757,7 @@ def __init__(
             use_full_instr (bool):
                 Whether to dump trace for all APIs. If False, APIs in certain modules deemed to be not important (e.g. `jit` in `torch`) will not have trace being dumped.
                 Refer to WRAP_WITHOUT_DUMP in config.py for the list of functions/modules that will have the dump disabled.
-            funcs_of_inv_interest (Optional[List[Callable]]):
+            funcs_to_instr (Optional[List[Callable]]):
                 An optional list of functions that are of interest for invariant inference. If provided, all functions not in this list will be instrumented with dump disabled, and the functions in this list will be instrumented with dump enabled.
                 NOTE: If this list is provided, use_full_instr must be set to False. WRAP_WITHOUT_DUMP will be ignored.
@@ -760,6 +766,11 @@
             cond_dump (bool):
                 Whether to dump the trace conditionally. If True, the trace will only be dumped if meta_vars have changed since the last call of this particular function.
                 This might cause additional overhead (cpu and memory) as the meta_vars will be compared with the previous call, and meta_vars will have to be cached in memory.
+
+        In addition, at initialization the instrumentor will load the instr_opts.json file if it exists.
+        This file is automatically generated by the `collect_trace` script when `--invariants` is provided.
+        The user should not need to interact with this file directly.
+
         """
 
         self.instrumenting = True
@@ -786,11 +797,12 @@ def __init__(
         self.target = target
         self.scan_proxy_in_args = scan_proxy_in_args
         self.use_full_instr = use_full_instr
-        self.funcs_of_inv_interest = funcs_of_inv_interest
+        self.funcs_to_instr = funcs_to_instr
         self.API_dump_stack_trace = API_dump_stack_trace
         self.cond_dump = cond_dump
+        self.instr_opts: None | dict[str, dict[str, dict[str, bool]]] = None
 
-        if self.funcs_of_inv_interest is not None and self.use_full_instr:
+        if self.funcs_to_instr is not None and self.use_full_instr:
             get_instrumentation_logger_for_process().fatal(
                 "Invariants are provided but use_full_instr is True. Selective instrumentation cannot be done. Please remove the `--use-full-instr` flag or remove the invariants"
             )
@@ -798,9 +810,24 @@
             raise RuntimeError(
                 "Invariants are provided but use_full_instr is True. Selective instrumentation cannot be done. Please remove the `--use-full-instr` flag or remove the invariants"
             )
-        if self.funcs_of_inv_interest is not None:
+        if self.funcs_to_instr is not None:
             get_instrumentation_logger_for_process().info(
-                f"Functions of interest for invariant inference: {self.funcs_of_inv_interest}"
+                f"Functions of interest for invariant inference: {self.funcs_to_instr}"
+            )
+
+        # discover if instr_opts.json is present
+        instr_opts_path = config.INSTR_OPTS_FILE
+        print("instr_opts_path: ", instr_opts_path)
+        get_instrumentation_logger_for_process().info(
+            f"Checking instr_opts at {instr_opts_path}"
+        )
+        if os.path.exists(instr_opts_path):
+            print(f"Loading instr_opts from {instr_opts_path}")
+            with open(instr_opts_path, "r") as f:
+                instr_opts = json.load(f)
+            self.instr_opts = instr_opts
+            get_instrumentation_logger_for_process().info(
+                f"Loaded instr_opts: {json.dumps(instr_opts, indent=4)}"
             )
 
     def instrument(self) -> int:
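As the checks above enforce, full instrumentation and selective dumping are mutually exclusive, and construction fails fast when both are requested. A minimal illustration (the flag values are assumed for the example):

import torch
from mldaikon.instrumentor.tracer import Instrumentor

Instrumentor(
    torch,
    scan_proxy_in_args=False,
    use_full_instr=True,  # dump everything...
    funcs_to_instr=["torch._VariableFunctionsClass.matmul"],  # ...but also be selective
)  # logs a fatal message and raises RuntimeError in __init__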
@@ -876,20 +903,18 @@ def instrument(self) -> int:
             )
 
         # do some simple checking for correctness:
-        # 1. if funcs_of_inv_interest is provided, then METRIC_INSTRUMENTED_FUNC_LIST["dump"] should be equal to funcs_of_inv_interest
-        if self.funcs_of_inv_interest is not None:
+        # 1. if funcs_to_instr is provided, then METRIC_INSTRUMENTED_FUNC_LIST["dump"] should be equal to funcs_to_instr
+        if self.funcs_to_instr is not None:
             # assert set(METRIC_INSTRUMENTED_FUNC_LIST["dump"]) == set(
-            #     self.funcs_of_inv_interest
-            # ), f"METRIC_INSTRUMENTED_FUNC_LIST['dump'] != funcs_of_inv_interest, diff: {set(METRIC_INSTRUMENTED_FUNC_LIST['dump']) ^ set(self.funcs_of_inv_interest)}"
+            #     self.funcs_to_instr
+            # ), f"METRIC_INSTRUMENTED_FUNC_LIST['dump'] != funcs_to_instr, diff: {set(METRIC_INSTRUMENTED_FUNC_LIST['dump']) ^ set(self.funcs_to_instr)}"
             assert set(METRIC_INSTRUMENTED_FUNC_LIST["dump"]).issubset(
-                set(self.funcs_of_inv_interest)
-            ), f"Actual functions being instrumented are not a subset of the functions required by the provided invariants, diff: {set(METRIC_INSTRUMENTED_FUNC_LIST['dump']) ^ set(self.funcs_of_inv_interest)}"
+                set(self.funcs_to_instr)
+            ), f"Actual functions being instrumented are not a subset of the functions required by the provided invariants, diff: {set(METRIC_INSTRUMENTED_FUNC_LIST['dump']) ^ set(self.funcs_to_instr)}"
 
-            if set(METRIC_INSTRUMENTED_FUNC_LIST["dump"]) != set(
-                self.funcs_of_inv_interest
-            ):
+            if set(METRIC_INSTRUMENTED_FUNC_LIST["dump"]) != set(self.funcs_to_instr):
                 get_instrumentation_logger_for_process().warning(
-                    f"Not all functions required by the provided invariants are instrumented (e.g. due to transfering ), some invariants might not be active at all, funcs not instrumented: {set(METRIC_INSTRUMENTED_FUNC_LIST['dump']) ^ set(self.funcs_of_inv_interest)}"
+                    f"Not all functions required by the provided invariants are instrumented (e.g. due to transferring), some invariants might not be active at all, funcs not instrumented: {set(METRIC_INSTRUMENTED_FUNC_LIST['dump']) ^ set(self.funcs_to_instr)}"
                 )
                 # TODO: report a number of functions not instrumented and thus the invariants that will not be active
         IS_INSTRUMENTING = False
""" if self.use_full_instr: return False - if self.funcs_of_inv_interest is not None: - if typename(attr) in self.funcs_of_inv_interest: + if self.funcs_to_instr is not None: + if typename(attr) in self.funcs_to_instr: return False return True @@ -992,6 +1017,34 @@ def should_disable_dump(self, attr) -> bool: return True return False + def get_wrapped_function(self, func_obj: Callable) -> Callable | None: + """Get the wrapped function for the provided function object""" + if self.instr_opts is not None: + func_name = typename(func_obj) + if func_name not in self.instr_opts["funcs_instr_opts"]: + return None + + instr_opts = self.instr_opts["funcs_instr_opts"][func_name] + return wrapper( + func_obj, + is_bound_method=is_API_bound_method(func_obj), + scan_proxy_in_args=instr_opts["scan_proxy_in_args"], + disable_dump=self.should_disable_dump(func_obj), + dump_stack_trace=self.API_dump_stack_trace, + cond_dump=self.cond_dump, + dump_args=instr_opts["dump_args"], + dump_ret=instr_opts["dump_ret"], + ) + + return wrapper( + func_obj, + is_bound_method=is_API_bound_method(func_obj), + scan_proxy_in_args=self.scan_proxy_in_args, + disable_dump=self.should_disable_dump(func_obj), + dump_stack_trace=self.API_dump_stack_trace, + cond_dump=self.cond_dump, + ) + def _instrument_module( self, pymodule: types.ModuleType | type, @@ -1079,14 +1132,16 @@ def _instrument_module( ) attr = funcs_to_be_replaced[typename(attr)] - wrapped = wrapper( - attr, - is_bound_method=is_API_bound_method(attr), - scan_proxy_in_args=self.scan_proxy_in_args, - disable_dump=self.should_disable_dump(attr), - dump_stack_trace=self.API_dump_stack_trace, - cond_dump=self.cond_dump, - ) + wrapped = self.get_wrapped_function(attr) + if wrapped is None: + log_instrumentation_progress( + depth, + "Skipping function due to selective dumping", + attr, + attr_name, + pymodule, + ) + continue try: setattr(pymodule, attr_name, wrapped) except Exception as e: diff --git a/mldaikon/runner/runner.py b/mldaikon/runner/runner.py index 4088b61b..80c2795b 100644 --- a/mldaikon/runner/runner.py +++ b/mldaikon/runner/runner.py @@ -1,4 +1,3 @@ -import atexit import logging import os import signal @@ -72,7 +71,6 @@ def kill_running_process_on_except(typ, value, tb): def register_hook_closing_program(): signal.signal(signal.SIGTERM, handle_SIGTERM) signal.signal(signal.SIGINT, handle_SIGINT) - atexit.register(kill_running_process) sys.excepthook = kill_running_process_on_except