Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions mldaikon/collect_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,13 @@ def get_default_output_folder(args: argparse.Namespace) -> str:
action="store_true",
help="Only instrument and dump the modified file",
)
parser.add_argument(
"--instr-descriptors",
action="store_true",
help="""Instrument functions that can only be accessed through descriptors,
Set this to true if you want to instrument built-in types like torch.Tensor,
at the cost of larger (5x) instrumentation overhead and more interference with the program""",
)
parser.add_argument(
"--profiling",
type=str, # @ziming-zh: why is this a string?
Expand Down Expand Up @@ -321,6 +328,7 @@ def get_default_output_folder(args: argparse.Namespace) -> str:
API_dump_stack_trace=args.API_dump_stack_trace,
cond_dump=args.cond_dump,
output_dir=output_dir,
instr_descriptors=args.instr_descriptors,
)

# call into the program runner
Expand Down
2 changes: 1 addition & 1 deletion mldaikon/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
# "torch._VariableFunctionsClass",
# "torch.get_default_dtype",
]

ANALYSIS_SKIP_FUNC_NAMES = [
"cuda.is_available",
"torch.get_default_dtype",
Expand Down Expand Up @@ -123,3 +122,4 @@
]

ENABLE_COND_DUMP = False
INSTR_DESCRIPTORS = False
14 changes: 8 additions & 6 deletions mldaikon/instrumentor/dumper.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,15 +223,17 @@ def convert_var_to_dict(var, include_tensor_data=True) -> dict:
if type(attr) in primitive_types:
result[attr_name] = attr

elif include_tensor_data and isinstance(attr, torch.Tensor):
result[attr_name] = dump_tensor(attr)
elif isinstance(attr, torch.Tensor):
result[f"_ML_DAIKON_{attr_name}_ID"] = id(attr)
if include_tensor_data:
result[attr_name] = dump_tensor(attr)

elif include_tensor_data and isinstance(attr, torch.nn.parameter.Parameter):
result[attr_name] = attr.__class__.__name__ + "(Parameter)"
result[attr_name] = dump_tensor(attr.data)
elif isinstance(attr, torch.nn.parameter.Parameter):
result[f"_ML_DAIKON_{attr_name}_ID"] = id(attr)
if include_tensor_data:
result[attr_name] = dump_tensor(attr.data)

elif include_tensor_data and isinstance(attr, torch.nn.Module):
result[attr_name] = attr.__class__.__name__ + "(nn.Module)"
# dump out all tensors inside the nn.Module
for name, param in attr.named_parameters():
result[attr_name] += f"\n{name}: {dump_tensor(param)}" # type: ignore
Expand Down
2 changes: 2 additions & 0 deletions mldaikon/instrumentor/source_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ def instrument_file(
API_dump_stack_trace: bool,
cond_dump: bool,
output_dir: str,
instr_descriptors: bool,
) -> str:
"""
Instruments the given file and returns the instrumented source code.
Expand Down Expand Up @@ -358,6 +359,7 @@ def instrument_file(
general_config_update = f"""
import mldaikon.config.config as general_config
general_config.ENABLE_COND_DUMP = {cond_dump}
general_config.INSTR_DESCRIPTORS = {instr_descriptors}
"""

if models_to_track:
Expand Down
34 changes: 26 additions & 8 deletions mldaikon/instrumentor/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import torch
import torch.utils

import mldaikon.config.config as config # needed to allow for change of values after import
from mldaikon.config.config import (
INSTR_MODULES_TO_SKIP,
TRAIN_STEP_NAMES,
Expand Down Expand Up @@ -756,6 +757,18 @@ def _should_skip_module_or_cls(self, pymodule: object) -> str | None:
def _should_skip_instr_attr(self, attr_name: str, pymodule: object) -> str | None:
# 1. skip attrs with no objects (e.g. __abstractmethods__ and C extension functions)
attr = pymodule.__dict__.get(attr_name, None)
if attr is None:
# try getting it in case it is a descriptor (almost certainly will be)
try:
attr = getattr(pymodule, attr_name)
if not (
config.INSTR_DESCRIPTORS and "method_descriptor" in str(type(attr))
):
# print("TRIGGERED", attr_name)
attr = None
except Exception:
pass

if attr is None:
return "Skipping attribute as it is None"

Expand All @@ -778,17 +791,22 @@ def _should_skip_instr_attr(self, attr_name: str, pymodule: object) -> str | Non
):
return "Skipping magic functions"

attr_full_name = typename(attr)
# 4. Skip if the attribute is in INSTR_MODULES_TO_SKIP | MANUAL CONFIG
if typename(attr) in INSTR_MODULES_TO_SKIP:
if attr_full_name in INSTR_MODULES_TO_SKIP:
return "Skipping attribute as it is one of INSTR_MODULES_TO_SKIP"

# 5. Skip if the attribute is in modules_to_skip_prefix | MANUAL CONFIG
for modules_to_skip_prefix in INSTR_MODULES_TO_SKIP:
if typename(attr).startswith(modules_to_skip_prefix):
if attr_full_name.startswith(modules_to_skip_prefix):
return "Skipping attribute as it is in INSTR_MODULES_TO_SKIP"

# 6. Skip if the attribute does not belong to the target root module
if not typename(attr).startswith(self.root_module):
if not attr_full_name.startswith(self.root_module) and not (
config.INSTR_DESCRIPTORS
and ("method_descriptor" in attr_full_name or "Tensor" in attr_full_name)
):
# builtin methods in torch.Tensor's qualname does not start with torch for some reason
return "Skipping attribute as it does not belong to the root module"

return None
Expand Down Expand Up @@ -858,10 +876,7 @@ def _instrument_module(
if attr is None:
# try access this attribute to handle lazy loading
try:
_ = getattr(pymodule, attr_name)
attr = pymodule.__dict__.get(
attr_name
) # this attr should be loaded now
attr = getattr(pymodule, attr_name)
except Exception as e:
get_instrumentation_logger_for_process().debug(
f"Depth: {depth}, lazy loading failed for attribute: {attr_name}, Module: {target_name}: {e}"
Expand All @@ -882,7 +897,10 @@ def _instrument_module(

if isinstance(
attr, (types.FunctionType, types.BuiltinFunctionType, _instancemethod_t)
):
) or (
config.INSTR_DESCRIPTORS and "method_descriptor" in str(type(attr))
): # instrumented with potential accuracy issues as descriptor-controlled method access might change what to return based on given information, but is needed to get tensor method invocations
assert callable(attr), f"{attr} is not callable"
assert not (
recurse_into_sub_module and is_API_instrumented(attr)
), f"{attr} is already instrumented"
Expand Down
76 changes: 72 additions & 4 deletions mldaikon/invariant/DistinctArgumentRelation.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,26 @@ def get_event_data_per_function_per_step(trace: Trace, function_pool: Set[Any]):
return function_pool, listed_arguments


def get_event_list(trace: Trace, function_pool: Iterable[str]):
listed_events: List[dict[str, Any]]= []
# for all func_ids, get their corresponding events
for func_name in function_pool:
func_call_ids = trace.get_func_call_ids(func_name)
for func_call_id in func_call_ids:
event = trace.query_func_call_event(func_call_id)
listed_events.extend(
# [event.pre_record, event.post_record]
[event.pre_record]
)

# sort the listed_events
# for (process_id, thread_id), events_list in listed_events.items():
# listed_events[(process_id, thread_id)] = sorted(
# events_list, key=lambda x: x["time"]
# )

return listed_events

def compare_argument(value1, value2, IOU_criteria = True):
if type(value1) != type(value2):
return False
Expand All @@ -114,7 +134,7 @@ def compare_argument(value1, value2, IOU_criteria = True):
return abs(value1 - value2) < 1e-8
return value1 == value2

def is_arguments_list_distinct(args1: list, args2: list):
def is_arguments_list_same(args1: list, args2: list):
if len(args1) != len(args2):
return False
for index in range(len(args1)):
Expand Down Expand Up @@ -206,7 +226,7 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]:
for PT_pair1, PT_pair2 in combinations(records.keys(), 2):
for event1 in records[PT_pair1]:
for event2 in records[PT_pair2]:
if(not is_arguments_list_distinct(event1["args"], event2["args"])):
if(not is_arguments_list_same(event1["args"], event2["args"])):
flag = True
pos = Example()
pos.add_group(EXP_GROUP_NAME, [event1, event2])
Expand All @@ -218,7 +238,7 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]:

for PT_pair in records.keys():
for event1, event2 in combinations(records[PT_pair], 2):
if(not is_arguments_list_distinct(event1["args"], event2["args"])):
if(not is_arguments_list_same(event1["args"], event2["args"])):
flag = True
pos = Example()
pos.add_group(EXP_GROUP_NAME, [event1, event2])
Expand Down Expand Up @@ -295,6 +315,34 @@ def static_check_all(
"""
assert inv.precondition is not None, "Invariant should have a precondition."

# 1. Pre-process all the events
print("Start preprocessing....")
listed_arguments: Dict[str, Dict[int, Dict[Tuple[str, str], List[dict[str, Any]]]]] = {}
function_pool: Set[Any] = set()
func= inv.params[0]

assert isinstance(
func, APIParam
), "Invariant parameters should be APIParam."

func_name = func.api_full_name
function_pool.add(func_name)

function_pool, listed_arguments = get_event_data_per_function_per_step(
trace, function_pool
)

if not function_pool:
return CheckerResult(
trace=None,
invariant=inv,
check_passed=True,
triggered=False,
)

events_list = get_event_list(trace, function_pool)
print("End preprocessing")

if not inv.precondition.verify([events_list], EXP_GROUP_NAME):
return CheckerResult(
trace=None,
Expand All @@ -303,7 +351,27 @@ def static_check_all(
triggered=False,
)

# TODO: checking process
for step, records in listed_arguments[func_name].items():
for PT_pair1, PT_pair2 in combinations(records.keys(), 2):
for event1 in records[PT_pair1]:
for event2 in records[PT_pair2]:
if(is_arguments_list_same(event1["args"], event2["args"])):
return CheckerResult(
trace=[event1, event2],
invariant=inv,
check_passed=True,
triggered=True,
)

for PT_pair in records.keys():
for event1, event2 in combinations(records[PT_pair], 2):
if(is_arguments_list_same(event1["args"], event2["args"])):
return CheckerResult(
trace=[event1, event2],
invariant=inv,
check_passed=True,
triggered=True,
)

return CheckerResult(
trace=None,
Expand Down
Loading