Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion eval_scripts/perf_benchmark/overhead-micro/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
torch,
scan_proxy_in_args=True,
use_full_instr=False,
funcs_of_inv_interest=[
funcs_to_instr=[
"torch.cuda.is_available",
"torch._VariableFunctionsClass.matmul",
"torch._VariableFunctionsClass.bmm",
Expand Down
169 changes: 141 additions & 28 deletions mldaikon/collect_trace.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import datetime
import json
import logging
import os

Expand All @@ -9,7 +10,22 @@
import mldaikon.instrumentor as instrumentor
import mldaikon.proxy_wrapper.proxy_config as proxy_config
import mldaikon.runner as runner
from mldaikon.invariant.base_cls import APIParam, Invariant, read_inv_file
from mldaikon.invariant.base_cls import (
APIParam,
Arguments,
Invariant,
VarNameParam,
VarTypeParam,
read_inv_file,
)
from mldaikon.invariant.consistency_relation import ConsistencyRelation
from mldaikon.invariant.consistency_transient_vars import (
ConsistentInputOutputRelation,
ConsistentOutputRelation,
ThresholdRelation,
)
from mldaikon.invariant.contain_relation import VAR_GROUP_NAME, APIContainRelation
from mldaikon.invariant.DistinctArgumentRelation import DistinctArgumentRelation


def get_list_of_funcs_from_invariants(invariants: list[Invariant]) -> list[str]:
Expand All @@ -24,6 +40,92 @@ def get_list_of_funcs_from_invariants(invariants: list[Invariant]) -> list[str]:
return sorted(list(funcs))


def get_per_func_instr_opts(
    invariants: "list[Invariant]",
) -> dict[str, dict[str, bool]]:
    """Derive per-function instrumentation options from the given invariants.

    Every API referenced by an ``APIParam`` of any invariant gets an entry with
    three boolean flags, all initially ``False``:

    * ``dump_args`` -- switched on for ``ConsistentInputOutputRelation``,
      ``ThresholdRelation`` and ``DistinctArgumentRelation`` invariants, and
      for ``APIContainRelation`` invariants whose API param carries an
      ``Arguments`` object.
    * ``dump_ret`` -- switched on for ``ConsistentInputOutputRelation``,
      ``ThresholdRelation`` and ``ConsistentOutputRelation`` invariants.
    * ``scan_proxy_in_args`` -- switched on when an ``APIContainRelation``
      invariant describes a variable (second param is a ``VarNameParam`` or
      ``VarTypeParam``) and its precondition on ``VAR_GROUP_NAME`` is not
      unconditional.

    Flags are only ever switched on, never back off, so the result does not
    depend on the order in which the invariants are listed.

    Args:
        invariants: The invariants that are going to be checked.

    Returns:
        A mapping from the API full name to its option flags.

    Raises:
        AssertionError: If an ``APIContainRelation`` invariant does not have an
            ``APIParam`` as its first param, or has no precondition.
    """
    func_instr_opts: dict[str, dict[str, bool]] = {}
    for inv in invariants:
        for param in inv.params:
            if not isinstance(param, APIParam):
                continue

            opts = func_instr_opts.setdefault(
                param.api_full_name,
                {
                    "scan_proxy_in_args": False,
                    "dump_args": False,
                    "dump_ret": False,
                },
            )

            # configure whether the arguments or return values need to be dumped
            if inv.relation in (ConsistentInputOutputRelation, ThresholdRelation):
                opts["dump_args"] = True
                opts["dump_ret"] = True
            elif inv.relation == ConsistentOutputRelation:
                opts["dump_ret"] = True
            elif inv.relation == DistinctArgumentRelation:
                opts["dump_args"] = True

            if inv.relation == APIContainRelation:
                # if the argument of this param is not empty, then dump it
                if isinstance(param.arguments, Arguments):
                    opts["dump_args"] = True

        if inv.relation == APIContainRelation:
            assert isinstance(inv.params[0], APIParam)
            assert inv.precondition is not None
            if (
                isinstance(inv.params[1], (VarNameParam, VarTypeParam))
                and VAR_GROUP_NAME in inv.precondition.get_group_names()
                and not inv.precondition.get_group(VAR_GROUP_NAME).is_unconditional()
            ):
                # The invariant describes a variable and its precondition is not
                # unconditional on that variable, so the function's arguments
                # have to be scanned. Deliberately no `else` branch resetting
                # the flag: another invariant on the same API may already have
                # requested the scan, and the flag defaults to False anyway.
                func_instr_opts[inv.params[0].api_full_name][
                    "scan_proxy_in_args"
                ] = True

    return func_instr_opts


class InstrOpt:
    """Instrumentation options derived from the invariants to be checked.

    Attributes are plain JSON-compatible values so the options can be written
    to disk with :meth:`to_json` and restored with :meth:`from_json`.
    """

    def __init__(self, invariants: "list[Invariant]"):
        """Compute the options for ``invariants``.

        ``model_tracker_style`` is chosen as follows:

        * ``"proxy"`` if any ``APIContainRelation`` invariant has a param
          describing a variable (``VarNameParam`` / ``VarTypeParam``);
        * otherwise ``"sampler"`` if any invariant is a ``ConsistencyRelation``;
        * otherwise ``None``.

        ``funcs_instr_opts`` is the result of ``get_per_func_instr_opts``.
        """
        # per-function flags, keyed by the API full name
        self.funcs_instr_opts: dict[str, dict[str, bool]] = {}
        # "proxy", "sampler", or None when no invariant needs a model tracker
        self.model_tracker_style: str | None = None

        for inv in invariants:
            if inv.relation == APIContainRelation and any(
                isinstance(param, (VarNameParam, VarTypeParam))
                for param in inv.params
            ):
                # "proxy" takes precedence over "sampler"; nothing can override
                # it, so stop scanning.
                self.model_tracker_style = "proxy"
                break
            if inv.relation == ConsistencyRelation:
                if self.model_tracker_style is None:
                    self.model_tracker_style = "sampler"

        # determine funcs_instr_opts
        self.funcs_instr_opts = get_per_func_instr_opts(invariants)

    def to_json(self) -> str:
        """Serialize the options to a JSON object string."""
        return json.dumps(
            {
                "funcs_instr_opts": self.funcs_instr_opts,
                "model_tracker_style": self.model_tracker_style,
            }
        )

    def from_json(self, instr_opt_json_str: str):
        """Overwrite this instance's options from a string made by ``to_json``.

        Parsed with ``json.loads`` so that reading mirrors writing.

        Raises:
            json.JSONDecodeError: If the string is not valid JSON.
            KeyError: If an expected top-level key is missing.
        """
        instr_opt_dict = json.loads(instr_opt_json_str)
        self.funcs_instr_opts = instr_opt_dict["funcs_instr_opts"]
        self.model_tracker_style = instr_opt_dict["model_tracker_style"]


def dump_env(output_dir: str):
with open(os.path.join(output_dir, "env_dump.txt"), "w") as f:
f.write("Arguments:\n")
Expand Down Expand Up @@ -188,13 +290,6 @@ def get_default_output_folder(args: argparse.Namespace) -> str:
)

## variable tracker configs
# parser.add_argument(
# "--proxy-module",
# type=str,
# default="",
# help="The module to be traced by the proxy wrapper",
# )

parser.add_argument(
"--models-to-track",
nargs="*",
Expand Down Expand Up @@ -277,12 +372,6 @@ def get_default_output_folder(args: argparse.Namespace) -> str:
os.makedirs(output_dir)
dump_env(output_dir)

# selective instrumentation if invariants are provided, only funcs_of_inv_interest will be instrumented with trace collection
funcs_of_inv_interest = None
if args.invariants is not None:
invariants = read_inv_file(args.invariants)
funcs_of_inv_interest = get_list_of_funcs_from_invariants(invariants)

# set up adjusted proxy_config
proxy_basic_config: dict[str, int | bool | str] = {}
for configs in [
Expand Down Expand Up @@ -330,20 +419,44 @@ def get_default_output_folder(args: argparse.Namespace) -> str:
if not args.models_to_track or args.model_tracker_style != "proxy":
scan_proxy_in_args = False

source_code = instrumentor.instrument_file(
path=args.pyscript,
modules_to_instr=args.modules_to_instr,
scan_proxy_in_args=scan_proxy_in_args,
use_full_instr=args.use_full_instr,
funcs_of_inv_interest=funcs_of_inv_interest,
models_to_track=args.models_to_track,
model_tracker_style=args.model_tracker_style,
adjusted_proxy_config=adjusted_proxy_config, # type: ignore
API_dump_stack_trace=args.API_dump_stack_trace,
cond_dump=args.cond_dump,
output_dir=output_dir,
instr_descriptors=args.instr_descriptors,
)
if args.invariants:
# selective instrumentation if invariants are provided, only funcs_to_instr will be instrumented with trace collection
invariants = read_inv_file(args.invariants)
instr_opts = InstrOpt(invariants)
with open(os.path.join(output_dir, config.INSTR_OPTS_FILE), "w") as f:
f.write(instr_opts.to_json())
models_to_track = (
args.models_to_track if instr_opts.model_tracker_style else None
)
source_code = instrumentor.instrument_file(
path=args.pyscript,
modules_to_instr=args.modules_to_instr,
scan_proxy_in_args=scan_proxy_in_args,
use_full_instr=args.use_full_instr,
funcs_to_instr=None,
models_to_track=models_to_track,
model_tracker_style=instr_opts.model_tracker_style,
adjusted_proxy_config=adjusted_proxy_config, # type: ignore
API_dump_stack_trace=args.API_dump_stack_trace,
cond_dump=args.cond_dump,
output_dir=output_dir,
instr_descriptors=args.instr_descriptors,
)
else:
source_code = instrumentor.instrument_file(
path=args.pyscript,
modules_to_instr=args.modules_to_instr,
scan_proxy_in_args=scan_proxy_in_args,
use_full_instr=args.use_full_instr,
funcs_to_instr=None,
models_to_track=args.models_to_track,
model_tracker_style=args.model_tracker_style,
adjusted_proxy_config=adjusted_proxy_config, # type: ignore
API_dump_stack_trace=args.API_dump_stack_trace,
cond_dump=args.cond_dump,
output_dir=output_dir,
instr_descriptors=args.instr_descriptors,
)

if args.copy_all_files:
# copy all files in the same directory as the pyscript to the output directory
Expand Down
2 changes: 1 addition & 1 deletion mldaikon/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

# tracer + instrumentor configs
TMP_FILE_PREFIX = "_ml_daikon_"
INSTR_OPTS_FILE = "instr_opts.json"
INSTR_MODULES_TO_INSTR = ["torch"]
INSTR_MODULES_TO_SKIP = [
"torch.fx",
Expand Down Expand Up @@ -151,7 +152,6 @@
ENABLE_COND_DUMP = False
INSTR_DESCRIPTORS = False


ALL_STAGE_NAMES = {
"init",
"training",
Expand Down
2 changes: 1 addition & 1 deletion mldaikon/e2e_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def parse_input_env(input_env: str) -> dict[str, str]:
"proxy_module": args.proxy_module,
# "scan_proxy_in_args": False,
# "allow_disable_dump": False,
# "funcs_of_inv_interest": None,
# "funcs_to_instr": None,
"output_dir": output_dir,
"API_LOG_DIR": api_log_dir,
"profiling": str(args.profiling),
Expand Down
16 changes: 8 additions & 8 deletions mldaikon/instrumentor/source_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __init__(
modules_to_instr: list[str],
scan_proxy_in_args: bool,
use_full_instr: bool,
funcs_of_inv_interest: list[str] | None,
funcs_to_instr: list[str] | None,
API_dump_stack_trace: bool,
cond_dump: bool,
):
Expand All @@ -29,13 +29,13 @@ def __init__(
self.modules_to_instr = modules_to_instr
self.scan_proxy_in_args = scan_proxy_in_args
self.use_full_instr = use_full_instr
self.funcs_of_inv_interest = funcs_of_inv_interest
self.funcs_to_instr = funcs_to_instr
self.API_dump_stack_trace = API_dump_stack_trace
self.cond_dump = cond_dump

def get_instrument_node(self, module_name: str):
return ast.parse(
f"from mldaikon.instrumentor.tracer import Instrumentor; Instrumentor({module_name}, scan_proxy_in_args={self.scan_proxy_in_args}, use_full_instr={self.use_full_instr}, funcs_of_inv_interest={str(self.funcs_of_inv_interest)}, API_dump_stack_trace={self.API_dump_stack_trace}, cond_dump={self.cond_dump}).instrument()"
f"from mldaikon.instrumentor.tracer import Instrumentor; Instrumentor({module_name}, scan_proxy_in_args={self.scan_proxy_in_args}, use_full_instr={self.use_full_instr}, funcs_to_instr={str(self.funcs_to_instr)}, API_dump_stack_trace={self.API_dump_stack_trace}, cond_dump={self.cond_dump}).instrument()"
).body

def visit_Import(self, node):
Expand Down Expand Up @@ -81,7 +81,7 @@ def instrument_library(
modules_to_instr: list[str],
scan_proxy_in_args: bool,
use_full_instr: bool,
funcs_of_inv_interest: list[str] | None,
funcs_to_instr: list[str] | None,
API_dump_stack_trace: bool,
cond_dump: bool,
) -> str:
Expand All @@ -103,7 +103,7 @@ def instrument_library(
modules_to_instr,
scan_proxy_in_args,
use_full_instr,
funcs_of_inv_interest,
funcs_to_instr,
API_dump_stack_trace,
cond_dump,
)
Expand Down Expand Up @@ -322,9 +322,9 @@ def instrument_file(
modules_to_instr: list[str],
scan_proxy_in_args: bool,
use_full_instr: bool,
funcs_of_inv_interest: list[str] | None,
funcs_to_instr: list[str] | None,
models_to_track: list[str] | None,
model_tracker_style: str,
model_tracker_style: str | None,
adjusted_proxy_config: list[dict[str, int | bool | str]],
API_dump_stack_trace: bool,
cond_dump: bool,
Expand All @@ -344,7 +344,7 @@ def instrument_file(
modules_to_instr,
scan_proxy_in_args,
use_full_instr,
funcs_of_inv_interest,
funcs_to_instr,
API_dump_stack_trace,
cond_dump,
)
Expand Down
Loading