12 changes: 12 additions & 0 deletions mldaikon/config/config.py
@@ -123,3 +123,15 @@

ENABLE_COND_DUMP = False
INSTR_DESCRIPTORS = False


ALL_STAGE_NAMES = {
"init",
"training",
"evaluation",
"inference",
"testing",
"checkpointing",
"preprocessing",
"postprocessing",
}
16 changes: 3 additions & 13 deletions mldaikon/developer/annotations.py
@@ -1,4 +1,5 @@
import mldaikon.instrumentor.tracer as tracer
from mldaikon.config.config import ALL_STAGE_NAMES
from mldaikon.instrumentor import meta_vars


@@ -11,20 +12,9 @@ def annotate_stage(stage_name: str):
Note that it is your responsibility to make sure this function is called on all threads that potentially can generate invariant candidates.
"""

valid_stage_names = [
"init",
"training",
"evaluation",
"inference",
"testing",
"checkpointing",
"preprocessing",
"postprocessing",
]

assert (
stage_name in valid_stage_names
), f"Invalid stage name: {stage_name}, valid ones are {valid_stage_names}"
stage_name in ALL_STAGE_NAMES
), f"Invalid stage name: {stage_name}, valid ones are {ALL_STAGE_NAMES}"

meta_vars["stage"] = stage_name

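For context, a short usage sketch of the annotation API after this change. The surrounding training-flow structure is an illustrative assumption, not part of the diff; only annotate_stage and ALL_STAGE_NAMES come from the code above.

from mldaikon.developer.annotations import annotate_stage

annotate_stage("init")            # must be a member of config.ALL_STAGE_NAMES
# ... build the model and data loaders here ...

annotate_stage("training")        # subsequent traced events record stage == "training" in meta_vars
# ... training loop here; per the docstring, call annotate_stage on every
# thread that can produce invariant candidates ...

annotate_stage("checkpointing")
# ... save a checkpoint ...

# annotate_stage("warmup")        # not in ALL_STAGE_NAMES -> AssertionError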
101 changes: 96 additions & 5 deletions mldaikon/infer_engine.py
@@ -2,11 +2,18 @@
import datetime
import json
import logging
import os
import random
import time

import mldaikon.config.config as config
from mldaikon.invariant.base_cls import FailedHypothesis, Invariant, Relation
from mldaikon.invariant.base_cls import (
FailedHypothesis,
Hypothesis,
Invariant,
Relation,
)
from mldaikon.invariant.precondition import find_precondition
from mldaikon.invariant.relation_pool import relation_pool
from mldaikon.trace import MDNONEJSONEncoder, select_trace_implementation
from mldaikon.utils import register_custom_excepthook
@@ -46,6 +53,62 @@ def infer(self, disabled_relations: list[Relation]):
)
return all_invs, all_failed_hypos

def infer_multi_trace(self, disabled_relations: list[Relation]):
hypotheses = self.generate_hypothesis(disabled_relations)
self.collect_examples(hypotheses)
invariants, failed_hypos = self.infer_precondition(hypotheses)
return invariants, failed_hypos

def generate_hypothesis(
self, disabled_relations: list[Relation]
) -> list[list[Hypothesis]]:
hypotheses = []
for trace in self.traces:
for relation in relation_pool:
if disabled_relations is not None and relation in disabled_relations:
logger.info(
f"Skipping relation {relation.__name__} as it is disabled"
)
continue
logger.info(f"Generating hypotheses for relation: {relation.__name__}")
hypotheses.append(relation.generate_hypothesis(trace))
logger.info(
f"Found {len(hypotheses[-1])} hypotheses for relation: {relation.__name__}"
)
return hypotheses

def collect_examples(self, hypotheses: list[list[Hypothesis]]):
for i, trace in enumerate(self.traces):
for j, hypothesis in enumerate(hypotheses[i]):
if j == i:
# already collected examples for this hypothesis on the same trace that generated it
continue
hypothesis.invariant.relation.collect_examples(trace, hypothesis)

def infer_precondition(self, hypotheses: list[list[Hypothesis]]):
all_hypotheses: list[Hypothesis] = []
for trace_hypotheses in hypotheses:
for hypothesis in trace_hypotheses:
all_hypotheses.append(hypothesis)

invariants = []
failed_hypos = []
for hypothesis in all_hypotheses:
hypothesis.invariant.num_positive_examples = len(
hypothesis.positive_examples
)
hypothesis.invariant.num_negative_examples = len(
hypothesis.negative_examples
)
precondition = find_precondition(hypothesis, self.traces)
if precondition is None:
failed_hypos.append(FailedHypothesis(hypothesis))
else:
hypothesis.invariant.precondition = precondition
invariants.append(hypothesis.invariant)

return invariants, failed_hypos


def save_invs(invs: list[Invariant], output_file: str):
with open(output_file, "w") as f:
@@ -69,9 +132,15 @@ def save_failed_hypos(failed_hypos: list[FailedHypothesis], output_file: str):
"-t",
"--traces",
nargs="+",
required=True,
required=False,
help="Traces files to infer invariants on",
)
parser.add_argument(
"-f",
"--trace-folders",
nargs="+",
help='Folders containing trace files to infer invariants on. Trace files should start with "trace_" or "proxy_log.json"',
)
parser.add_argument(
"-d",
"--debug",
@@ -110,6 +179,14 @@ def save_failed_hypos(failed_hypos: list[FailedHypothesis], output_file: str):
)
args = parser.parse_args()

# check if either traces or trace folders are provided
if args.traces is None and args.trace_folders is None:
# print help message if neither traces nor trace folders are provided
parser.print_help()
parser.error(
"Please provide either traces or trace folders to infer invariants"
)

Trace, read_trace_file = select_trace_implementation(args.backend)

if args.debug:
@@ -138,14 +215,28 @@ def save_failed_hypos(failed_hypos: list[FailedHypothesis], output_file: str):
config.PRECOND_SAMPLING_THRESHOLD = args.precond_sampling_threshold

time_start = time.time()
logger.info("Reading traces from %s", "\n".join(args.traces))
traces = [read_trace_file(args.traces)]

traces = []
if args.traces is not None:
logger.info("Reading traces from %s", "\n".join(args.traces))
traces.append(read_trace_file(args.traces))
if args.trace_folders is not None:
for trace_folder in args.trace_folders:
# file discovery
trace_files = [
f"{trace_folder}/{file}"
for file in os.listdir(trace_folder)
if file.startswith("trace_") or file.startswith("proxy_log.json")
]
logger.info("Reading traces from %s", "\n".join(trace_files))
traces.append(read_trace_file(trace_files))

time_end = time.time()
logger.info(f"Traces read successfully in {time_end - time_start} seconds.")

time_start = time.time()
engine = InferEngine(traces)
invs, failed_hypos = engine.infer(disabled_relations=disabled_relations)
invs, failed_hypos = engine.infer_multi_trace(disabled_relations=disabled_relations)
time_end = time.time()
logger.info(f"Inference completed in {time_end - time_start} seconds.")

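For reference, a sketch of driving the new multi-trace path programmatically, mirroring what the CLI now does with -f/--trace-folders: one Trace is built per folder, hypotheses are generated per trace, collect_examples re-evaluates each hypothesis on the other traces (the generating trace is skipped, per the inline comment), and preconditions are inferred over the combined examples. The sketch assumes mldaikon.infer_engine can be imported without triggering its CLI section; the folder names, backend string, and output path are illustrative assumptions.

import os

from mldaikon.infer_engine import InferEngine, save_invs
from mldaikon.trace import select_trace_implementation

trace_folders = ["./run_a", "./run_b"]   # hypothetical trace folders

# Backend name is an assumption; match whatever --backend you normally pass.
Trace, read_trace_file = select_trace_implementation("pandas")

traces = []
for folder in trace_folders:
    # Same discovery rule as the CLI: files named trace_* or proxy_log.json
    trace_files = [
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if name.startswith("trace_") or name.startswith("proxy_log.json")
    ]
    traces.append(read_trace_file(trace_files))   # one Trace per folder

engine = InferEngine(traces)
invs, failed_hypos = engine.infer_multi_trace(disabled_relations=[])
save_invs(invs, "invariants.json")

From the command line, the equivalent run would be something like `python -m mldaikon.infer_engine -f ./run_a ./run_b` (module invocation path assumed); -t/--traces can still be passed instead of, or alongside, -f.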
10 changes: 6 additions & 4 deletions mldaikon/instrumentor/tracer.py
@@ -327,7 +327,7 @@ def find_proxy_in_args(args):
"func_call_id": func_call_id,
"thread_id": thread_id,
"process_id": process_id,
"meta_vars": get_meta_vars(),
"meta_vars": pre_meta_vars,
"type": TraceLineType.FUNC_CALL_POST_EXCEPTION,
"function": func_name,
# "args": [f"{arg}" for arg in args],
@@ -346,7 +346,7 @@
pre_record.copy()
) # copy the pre_record (though we don't actually need to copy anything)
post_record["type"] = TraceLineType.FUNC_CALL_POST
post_record["meta_vars"] = get_meta_vars()
post_record["meta_vars"] = pre_meta_vars

result_to_dump = result

@@ -1011,6 +1011,7 @@ def __init__(self, var, dump_tensor_hash: bool = True):
self.param_versions = {} # type: ignore
timestamp = datetime.datetime.now().timestamp()

curr_meta_vars = get_meta_vars()
for param in self._get_state_copy():
attributes = param["attributes"]
if dump_tensor_hash:
@@ -1026,7 +1027,7 @@
"var_type": param["type"],
"process_id": os.getpid(),
"thread_id": threading.current_thread().ident,
"meta_vars": get_meta_vars(),
"meta_vars": curr_meta_vars,
"type": TraceLineType.STATE_CHANGE,
"attributes": attributes,
"time": timestamp,
@@ -1054,14 +1055,15 @@ def dump_sample(self):

timestamp = datetime.datetime.now().timestamp()

curr_meta_vars = get_meta_vars()
for param in self._get_state_copy():
dump_trace_VAR(
{
"var_name": param["name"],
"var_type": param["type"], # FIXME: hardcoding the type for now
"process_id": os.getpid(),
"thread_id": threading.current_thread().ident,
"meta_vars": get_meta_vars(),
"meta_vars": curr_meta_vars,
"type": TraceLineType.STATE_CHANGE,
"attributes": param["attributes"],
"time": timestamp,
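The tracer hunks replace repeated get_meta_vars() reads with a value captured once per event (pre_meta_vars at call entry, curr_meta_vars at the start of a state dump), so all records emitted for one event report the same meta variables. Below is a minimal, self-contained sketch of that snapshot pattern; the names (meta_vars, emit, dump_state) are illustrative stand-ins, not the tracer's real internals.

import copy
import threading

meta_vars = {"stage": "training"}   # shared dict, mutated by annotate_stage()
_records = []

def emit(record: dict) -> None:
    _records.append(record)

def dump_state(params: list[dict]) -> None:
    # Snapshot once per dump; every parameter record below sees the same
    # meta variables even if another thread changes the stage mid-dump.
    snapshot = copy.deepcopy(meta_vars)
    for param in params:
        emit({
            "var_name": param["name"],
            "meta_vars": snapshot,      # reused, not re-read per parameter
            "thread_id": threading.current_thread().ident,
        })

dump_state([{"name": "layer1.weight"}, {"name": "layer1.bias"}])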