Skip to content

Commit ffebd2a

Browse files
committed
add: multi-trace inference single-cpu version
1 parent 9e1bca3 commit ffebd2a

12 files changed

+287
-179
lines changed

mldaikon/infer_engine.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,13 @@
77
import time
88

99
import mldaikon.config.config as config
10-
from mldaikon.invariant.base_cls import FailedHypothesis, Invariant, Relation
10+
from mldaikon.invariant.base_cls import (
11+
FailedHypothesis,
12+
Hypothesis,
13+
Invariant,
14+
Relation,
15+
)
16+
from mldaikon.invariant.precondition import find_precondition
1117
from mldaikon.invariant.relation_pool import relation_pool
1218
from mldaikon.trace import MDNONEJSONEncoder, select_trace_implementation
1319
from mldaikon.utils import register_custom_excepthook
@@ -47,6 +53,61 @@ def infer(self, disabled_relations: list[Relation]):
4753
)
4854
return all_invs, all_failed_hypos
4955

56+
def infer_multi_trace(self, disabled_relations: list[Relation]):
    """Run the multi-trace inference pipeline.

    The three stages run strictly in order: hypothesis generation,
    example collection, then precondition inference.

    Args:
        disabled_relations: Relations passed through to
            ``generate_hypothesis`` so they can be skipped.

    Returns:
        The ``(invariants, failed_hypos)`` pair produced by
        ``infer_precondition``.
    """
    generated = self.generate_hypothesis(disabled_relations)
    # The return value of collect_examples is not used; the same
    # `generated` object is handed to the next stage.
    self.collect_examples(generated)
    return self.infer_precondition(generated)
61+
62+
def generate_hypothesis(
    self, disabled_relations: list[Relation]
) -> list[list[Hypothesis]]:
    """Generate hypotheses from every trace, grouped by originating trace.

    Args:
        disabled_relations: Relations to skip. ``None`` means every
            relation in ``relation_pool`` is used.

    Returns:
        A list with exactly one entry per element of ``self.traces``, in
        the same order. Entry ``i`` holds the hypotheses that all enabled
        relations generated from ``self.traces[i]``.
    """
    hypotheses: list[list[Hypothesis]] = []
    for trace in self.traces:
        # One bucket per trace so callers can index the result by trace.
        trace_hypotheses: list[Hypothesis] = []
        for relation in relation_pool:
            if disabled_relations is not None and relation in disabled_relations:
                logger.info(
                    f"Skipping relation {relation.__name__} as it is disabled"
                )
                continue
            logger.info(f"Infering invariants for relation: {relation.__name__}")
            # NOTE(review): assumes relation.generate_hypothesis returns a
            # list of Hypothesis objects for the trace -- confirm against
            # the Relation base class.
            relation_hypotheses = relation.generate_hypothesis(trace)
            logger.info(
                f"Found {len(relation_hypotheses)} hypotheses for relation: {relation.__name__}"
            )
            trace_hypotheses.extend(relation_hypotheses)
        hypotheses.append(trace_hypotheses)
    return hypotheses
84+
85+
def collect_examples(self, hypotheses: list[list[Hypothesis]]):
    """Collect examples for each hypothesis on every trace but its own.

    Args:
        hypotheses: One list of hypotheses per trace, where
            ``hypotheses[j]`` was generated from ``self.traces[j]``.

    Each hypothesis is passed, together with every trace other than the
    one that generated it, to the ``collect_examples`` method of its
    invariant's relation. Nothing is returned.
    """
    for i, trace in enumerate(self.traces):
        for j, trace_hypotheses in enumerate(hypotheses):
            if j == i:
                # already collected examples for these hypotheses on the
                # same trace that generated them
                continue
            for hypothesis in trace_hypotheses:
                hypothesis.invariant.relation.collect_examples(trace, hypothesis)
92+
93+
def infer_precondition(self, hypotheses: list[list[Hypothesis]]):
    """Try to find a precondition for every hypothesis.

    Args:
        hypotheses: Nested lists of hypotheses; the nesting is flattened
            before processing, preserving order.

    Returns:
        A tuple ``(invariants, failed_hypos)``. ``invariants`` holds the
        ``invariant`` of each hypothesis for which ``find_precondition``
        returned a non-None value (that value is stored on
        ``invariant.precondition``). ``failed_hypos`` holds a
        ``FailedHypothesis`` wrapper for every other hypothesis.
    """
    # Flatten up front so the processing order is fixed before any
    # precondition search runs.
    flat_hypotheses = [hypo for group in hypotheses for hypo in group]

    invariants = []
    failed_hypos = []
    for hypo in flat_hypotheses:
        found = find_precondition(hypo, self.traces)
        if found is None:
            failed_hypos.append(FailedHypothesis(hypo))
            continue
        hypo.invariant.precondition = found
        invariants.append(hypo.invariant)

    return invariants, failed_hypos
110+
50111

51112
def save_invs(invs: list[Invariant], output_file: str):
52113
with open(output_file, "w") as f:

mldaikon/invariant/DistinctArgumentRelation.py

Lines changed: 85 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,34 @@
11
import logging
22
from itertools import combinations
3-
from typing import Any, Dict, List, Set, Tuple, Iterable
3+
from typing import Any, Dict, Iterable, List, Set, Tuple
44

55
from tqdm import tqdm
66

7-
from mldaikon.invariant.base_cls import (
8-
Param,
7+
from mldaikon.invariant.base_cls import ( # GroupedPreconditions,
98
APIParam,
109
CheckerResult,
1110
Example,
1211
ExampleList,
1312
FailedHypothesis,
14-
# GroupedPreconditions,
1513
Hypothesis,
1614
Invariant,
1715
Relation,
1816
)
19-
2017
from mldaikon.invariant.precondition import find_precondition
2118
from mldaikon.trace.trace import Trace
2219

2320
EXP_GROUP_NAME = "distinct_arg"
24-
MAX_FUNC_NUM_CONSECUTIVE_CALL = 6
25-
IOU_THRESHHOLD = 0.1 # pre-defined threshhold for IOU
21+
MAX_FUNC_NUM_CONSECUTIVE_CALL = 6
22+
IOU_THRESHHOLD = 0.1 # pre-defined threshhold for IOU
23+
2624

2725
def calculate_IOU(list1, list2):
    """Return the intersection-over-union of the two inputs taken as sets.

    Duplicates are ignored because both inputs are converted to sets.
    When both inputs are empty the union is empty and 0 is returned
    rather than dividing by zero.
    """
    first, second = set(list1), set(list2)
    combined = first | second
    if not combined:
        return 0
    return len(first & second) / len(combined)
3432

3533

3634
def get_func_names_to_deal_with(trace: Trace) -> List[str]:
@@ -54,16 +52,20 @@ def get_func_names_to_deal_with(trace: Trace) -> List[str]:
5452

5553
return list(function_pool)
5654

55+
5756
def get_event_data_per_function_per_step(trace: Trace, function_pool: Set[Any]):
58-
listed_arguments: Dict[str, Dict[int, Dict[Tuple[str, str], List[dict[str, Any]]]]] = (
59-
{}
60-
)
57+
listed_arguments: Dict[
58+
str, Dict[int, Dict[Tuple[str, str], List[dict[str, Any]]]]
59+
] = {}
6160
for func_name in function_pool.copy():
6261
func_call_ids = trace.get_func_call_ids(func_name)
6362
keep_this_func = False
6463
for func_call_id in func_call_ids:
6564
event = trace.query_func_call_event(func_call_id)
66-
if (event.pre_record["meta_vars.step"] is None or "args" not in event.pre_record):
65+
if (
66+
event.pre_record["meta_vars.step"] is None
67+
or "args" not in event.pre_record
68+
):
6769
continue
6870
keep_this_func = True
6971
process_id = event.pre_record["process_id"]
@@ -73,24 +75,26 @@ def get_event_data_per_function_per_step(trace: Trace, function_pool: Set[Any]):
7375
listed_arguments[func_name] = {}
7476
listed_arguments[func_name][step] = {}
7577
listed_arguments[func_name][step][(process_id, thread_id)] = []
76-
78+
7779
if step not in listed_arguments[func_name]:
7880
listed_arguments[func_name][step] = {}
7981
listed_arguments[func_name][step][(process_id, thread_id)] = []
80-
82+
8183
if (process_id, thread_id) not in listed_arguments[func_name][step]:
8284
listed_arguments[func_name][step][(process_id, thread_id)] = []
83-
84-
listed_arguments[func_name][step][(process_id, thread_id)].append(event.pre_record)
85-
85+
86+
listed_arguments[func_name][step][(process_id, thread_id)].append(
87+
event.pre_record
88+
)
89+
8690
if not keep_this_func:
8791
function_pool.remove(func_name)
8892

8993
return function_pool, listed_arguments
9094

9195

9296
def get_event_list(trace: Trace, function_pool: Iterable[str]):
93-
listed_events: List[dict[str, Any]]= []
97+
listed_events: List[dict[str, Any]] = []
9498
# for all func_ids, get their corresponding events
9599
for func_name in function_pool:
96100
func_call_ids = trace.get_func_call_ids(func_name)
@@ -109,11 +113,16 @@ def get_event_list(trace: Trace, function_pool: Iterable[str]):
109113

110114
return listed_events
111115

112-
def compare_argument(value1, value2, IOU_criteria = True):
116+
117+
def compare_argument(value1, value2, IOU_criteria=True):
113118
if type(value1) != type(value2):
114119
return False
115120
if isinstance(value1, list):
116-
if IOU_criteria and all(isinstance(item, int) for item in value1) and all(isinstance(item, int) for item in value2):
121+
if (
122+
IOU_criteria
123+
and all(isinstance(item, int) for item in value1)
124+
and all(isinstance(item, int) for item in value2)
125+
):
117126
return calculate_IOU(value1, value2) >= IOU_THRESHHOLD
118127
if len(value1) != len(value2):
119128
return False
@@ -134,6 +143,7 @@ def compare_argument(value1, value2, IOU_criteria = True):
134143
return abs(value1 - value2) < 1e-8
135144
return value1 == value2
136145

146+
137147
def is_arguments_list_same(args1: list, args2: list):
138148
if len(args1) != len(args2):
139149
return False
@@ -144,6 +154,7 @@ def is_arguments_list_same(args1: list, args2: list):
144154
return False
145155
return True
146156

157+
147158
# class APIArgsParam(Param):
148159
# def __init__(
149160
# self, api_full_name: str, arg_name: str
@@ -181,7 +192,9 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]:
181192

182193
# 1. Pre-process all the events
183194
print("Start preprocessing....")
184-
listed_arguments: Dict[str, Dict[int, Dict[Tuple[str, str], List[dict[str, Any]]]]] = {}
195+
listed_arguments: Dict[
196+
str, Dict[int, Dict[Tuple[str, str], List[dict[str, Any]]]]
197+
] = {}
185198
function_pool: Set[Any] = set()
186199

187200
function_pool = set(get_func_names_to_deal_with(trace))
@@ -205,9 +218,7 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]:
205218
func_name: Hypothesis(
206219
invariant=Invariant(
207220
relation=DistinctArgumentRelation,
208-
params=[
209-
APIParam(func_name)
210-
],
221+
params=[APIParam(func_name)],
211222
precondition=None,
212223
text_description=f"{func_name} has distinct input arguments on difference PT for each step",
213224
),
@@ -226,29 +237,39 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]:
226237
for PT_pair1, PT_pair2 in combinations(records.keys(), 2):
227238
for event1 in records[PT_pair1]:
228239
for event2 in records[PT_pair2]:
229-
if(not is_arguments_list_same(event1["args"], event2["args"])):
240+
if not is_arguments_list_same(
241+
event1["args"], event2["args"]
242+
):
230243
flag = True
231244
pos = Example()
232245
pos.add_group(EXP_GROUP_NAME, [event1, event2])
233-
hypothesis_with_examples[func_name].positive_examples.add_example(pos)
246+
hypothesis_with_examples[
247+
func_name
248+
].positive_examples.add_example(pos)
234249
else:
235250
neg = Example()
236251
neg.add_group(EXP_GROUP_NAME, [event1, event2])
237-
hypothesis_with_examples[func_name].negative_examples.add_example(neg)
238-
252+
hypothesis_with_examples[
253+
func_name
254+
].negative_examples.add_example(neg)
255+
239256
for PT_pair in records.keys():
240257
for event1, event2 in combinations(records[PT_pair], 2):
241-
if(not is_arguments_list_same(event1["args"], event2["args"])):
258+
if not is_arguments_list_same(event1["args"], event2["args"]):
242259
flag = True
243260
pos = Example()
244261
pos.add_group(EXP_GROUP_NAME, [event1, event2])
245-
hypothesis_with_examples[func_name].positive_examples.add_example(pos)
262+
hypothesis_with_examples[
263+
func_name
264+
].positive_examples.add_example(pos)
246265
else:
247266
neg = Example()
248267
neg.add_group(EXP_GROUP_NAME, [event1, event2])
249-
hypothesis_with_examples[func_name].negative_examples.add_example(neg)
250-
251-
if(not flag):
268+
hypothesis_with_examples[
269+
func_name
270+
].negative_examples.add_example(neg)
271+
272+
if not flag:
252273
hypothesis_with_examples.pop(func_name)
253274

254275
print("End adding examples")
@@ -261,13 +282,11 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]:
261282
logger.debug(
262283
f"Finding Precondition for {hypo}: {hypothesis_with_examples[hypo].invariant.text_description}"
263284
)
264-
preconditions = find_precondition(hypothesis_with_examples[hypo], trace)
285+
preconditions = find_precondition(hypothesis_with_examples[hypo], [trace])
265286
logger.debug(f"Preconditions for {hypo}:\n{str(preconditions)}")
266287

267288
if preconditions is not None:
268-
hypothesis_with_examples[hypo].invariant.precondition = (
269-
preconditions
270-
)
289+
hypothesis_with_examples[hypo].invariant.precondition = preconditions
271290
else:
272291
logger.debug(f"Precondition not found for {hypo}")
273292
failed_hypothesis.append(
@@ -281,13 +300,10 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]:
281300
print("End precondition inference")
282301

283302
return (
284-
list(
285-
[hypo.invariant for hypo in hypothesis_with_examples.values()]
286-
),
303+
list([hypo.invariant for hypo in hypothesis_with_examples.values()]),
287304
failed_hypothesis,
288305
)
289306

290-
291307
@staticmethod
292308
def evaluate(value_group: list) -> bool:
293309
"""Given a group of values, should return a boolean value
@@ -317,13 +333,13 @@ def static_check_all(
317333

318334
# 1. Pre-process all the events
319335
print("Start preprocessing....")
320-
listed_arguments: Dict[str, Dict[int, Dict[Tuple[str, str], List[dict[str, Any]]]]] = {}
336+
listed_arguments: Dict[
337+
str, Dict[int, Dict[Tuple[str, str], List[dict[str, Any]]]]
338+
] = {}
321339
function_pool: Set[Any] = set()
322-
func= inv.params[0]
340+
func = inv.params[0]
323341

324-
assert isinstance(
325-
func, APIParam
326-
), "Invariant parameters should be APIParam."
342+
assert isinstance(func, APIParam), "Invariant parameters should be APIParam."
327343

328344
func_name = func.api_full_name
329345
function_pool.add(func_name)
@@ -352,30 +368,34 @@ def static_check_all(
352368
)
353369

354370
for step, records in listed_arguments[func_name].items():
355-
for PT_pair1, PT_pair2 in combinations(records.keys(), 2):
356-
for event1 in records[PT_pair1]:
357-
for event2 in records[PT_pair2]:
358-
if(is_arguments_list_same(event1["args"], event2["args"])):
359-
return CheckerResult(
360-
trace=[event1, event2],
361-
invariant=inv,
362-
check_passed=True,
363-
triggered=True,
364-
)
365-
366-
for PT_pair in records.keys():
367-
for event1, event2 in combinations(records[PT_pair], 2):
368-
if(is_arguments_list_same(event1["args"], event2["args"])):
371+
for PT_pair1, PT_pair2 in combinations(records.keys(), 2):
372+
for event1 in records[PT_pair1]:
373+
for event2 in records[PT_pair2]:
374+
if is_arguments_list_same(event1["args"], event2["args"]):
369375
return CheckerResult(
370376
trace=[event1, event2],
371377
invariant=inv,
372378
check_passed=True,
373379
triggered=True,
374380
)
375381

382+
for PT_pair in records.keys():
383+
for event1, event2 in combinations(records[PT_pair], 2):
384+
if is_arguments_list_same(event1["args"], event2["args"]):
385+
return CheckerResult(
386+
trace=[event1, event2],
387+
invariant=inv,
388+
check_passed=True,
389+
triggered=True,
390+
)
391+
376392
return CheckerResult(
377393
trace=None,
378394
invariant=inv,
379395
check_passed=True,
380396
triggered=True,
381397
)
398+
399+
@staticmethod
400+
def get_precondition_infer_keys_to_skip(hypothesis: Hypothesis) -> list[str]:
401+
return []

0 commit comments

Comments
 (0)