11import logging
22from itertools import combinations
3- from typing import Any , Dict , List , Set , Tuple , Iterable
3+ from typing import Any , Dict , Iterable , List , Set , Tuple
44
55from tqdm import tqdm
66
7- from mldaikon .invariant .base_cls import (
8- Param ,
7+ from mldaikon .invariant .base_cls import ( # GroupedPreconditions,
98 APIParam ,
109 CheckerResult ,
1110 Example ,
1211 ExampleList ,
1312 FailedHypothesis ,
14- # GroupedPreconditions,
1513 Hypothesis ,
1614 Invariant ,
1715 Relation ,
1816)
19-
2017from mldaikon .invariant .precondition import find_precondition
2118from mldaikon .trace .trace import Trace
2219
2320EXP_GROUP_NAME = "distinct_arg"
24- MAX_FUNC_NUM_CONSECUTIVE_CALL = 6
25- IOU_THRESHHOLD = 0.1 # pre-defined threshhold for IOU
21+ MAX_FUNC_NUM_CONSECUTIVE_CALL = 6
22+ IOU_THRESHHOLD = 0.1 # pre-defined threshhold for IOU
23+
2624
2725def calculate_IOU (list1 , list2 ):
28- set1 = set (list1 )
29- set2 = set (list2 )
30- intersection = set1 .intersection (set2 )
31- union = set1 .union (set2 )
32- iou = len (intersection ) / len (union ) if len (union ) != 0 else 0
33- return iou
26+ set1 = set (list1 )
27+ set2 = set (list2 )
28+ intersection = set1 .intersection (set2 )
29+ union = set1 .union (set2 )
30+ iou = len (intersection ) / len (union ) if len (union ) != 0 else 0
31+ return iou
3432
3533
3634def get_func_names_to_deal_with (trace : Trace ) -> List [str ]:
@@ -54,16 +52,20 @@ def get_func_names_to_deal_with(trace: Trace) -> List[str]:
5452
5553 return list (function_pool )
5654
55+
5756def get_event_data_per_function_per_step (trace : Trace , function_pool : Set [Any ]):
58- listed_arguments : Dict [str , Dict [ int , Dict [ Tuple [ str , str ], List [ dict [ str , Any ]]]]] = (
59- {}
60- )
57+ listed_arguments : Dict [
58+ str , Dict [ int , Dict [ Tuple [ str , str ], List [ dict [ str , Any ]]]]
59+ ] = {}
6160 for func_name in function_pool .copy ():
6261 func_call_ids = trace .get_func_call_ids (func_name )
6362 keep_this_func = False
6463 for func_call_id in func_call_ids :
6564 event = trace .query_func_call_event (func_call_id )
66- if (event .pre_record ["meta_vars.step" ] is None or "args" not in event .pre_record ):
65+ if (
66+ event .pre_record ["meta_vars.step" ] is None
67+ or "args" not in event .pre_record
68+ ):
6769 continue
6870 keep_this_func = True
6971 process_id = event .pre_record ["process_id" ]
@@ -73,24 +75,26 @@ def get_event_data_per_function_per_step(trace: Trace, function_pool: Set[Any]):
7375 listed_arguments [func_name ] = {}
7476 listed_arguments [func_name ][step ] = {}
7577 listed_arguments [func_name ][step ][(process_id , thread_id )] = []
76-
78+
7779 if step not in listed_arguments [func_name ]:
7880 listed_arguments [func_name ][step ] = {}
7981 listed_arguments [func_name ][step ][(process_id , thread_id )] = []
80-
82+
8183 if (process_id , thread_id ) not in listed_arguments [func_name ][step ]:
8284 listed_arguments [func_name ][step ][(process_id , thread_id )] = []
83-
84- listed_arguments [func_name ][step ][(process_id , thread_id )].append (event .pre_record )
85-
85+
86+ listed_arguments [func_name ][step ][(process_id , thread_id )].append (
87+ event .pre_record
88+ )
89+
8690 if not keep_this_func :
8791 function_pool .remove (func_name )
8892
8993 return function_pool , listed_arguments
9094
9195
9296def get_event_list (trace : Trace , function_pool : Iterable [str ]):
93- listed_events : List [dict [str , Any ]]= []
97+ listed_events : List [dict [str , Any ]] = []
9498 # for all func_ids, get their corresponding events
9599 for func_name in function_pool :
96100 func_call_ids = trace .get_func_call_ids (func_name )
@@ -109,11 +113,16 @@ def get_event_list(trace: Trace, function_pool: Iterable[str]):
109113
110114 return listed_events
111115
112- def compare_argument (value1 , value2 , IOU_criteria = True ):
116+
117+ def compare_argument (value1 , value2 , IOU_criteria = True ):
113118 if type (value1 ) != type (value2 ):
114119 return False
115120 if isinstance (value1 , list ):
116- if IOU_criteria and all (isinstance (item , int ) for item in value1 ) and all (isinstance (item , int ) for item in value2 ):
121+ if (
122+ IOU_criteria
123+ and all (isinstance (item , int ) for item in value1 )
124+ and all (isinstance (item , int ) for item in value2 )
125+ ):
117126 return calculate_IOU (value1 , value2 ) >= IOU_THRESHHOLD
118127 if len (value1 ) != len (value2 ):
119128 return False
@@ -134,6 +143,7 @@ def compare_argument(value1, value2, IOU_criteria = True):
134143 return abs (value1 - value2 ) < 1e-8
135144 return value1 == value2
136145
146+
137147def is_arguments_list_same (args1 : list , args2 : list ):
138148 if len (args1 ) != len (args2 ):
139149 return False
@@ -144,6 +154,7 @@ def is_arguments_list_same(args1: list, args2: list):
144154 return False
145155 return True
146156
157+
147158# class APIArgsParam(Param):
148159# def __init__(
149160# self, api_full_name: str, arg_name: str
@@ -181,7 +192,9 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]:
181192
182193 # 1. Pre-process all the events
183194 print ("Start preprocessing...." )
184- listed_arguments : Dict [str , Dict [int , Dict [Tuple [str , str ], List [dict [str , Any ]]]]] = {}
195+ listed_arguments : Dict [
196+ str , Dict [int , Dict [Tuple [str , str ], List [dict [str , Any ]]]]
197+ ] = {}
185198 function_pool : Set [Any ] = set ()
186199
187200 function_pool = set (get_func_names_to_deal_with (trace ))
@@ -205,9 +218,7 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]:
205218 func_name : Hypothesis (
206219 invariant = Invariant (
207220 relation = DistinctArgumentRelation ,
208- params = [
209- APIParam (func_name )
210- ],
221+ params = [APIParam (func_name )],
211222 precondition = None ,
212223 text_description = f"{ func_name } has distinct input arguments on difference PT for each step" ,
213224 ),
@@ -226,29 +237,39 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]:
226237 for PT_pair1 , PT_pair2 in combinations (records .keys (), 2 ):
227238 for event1 in records [PT_pair1 ]:
228239 for event2 in records [PT_pair2 ]:
229- if (not is_arguments_list_same (event1 ["args" ], event2 ["args" ])):
240+ if not is_arguments_list_same (
241+ event1 ["args" ], event2 ["args" ]
242+ ):
230243 flag = True
231244 pos = Example ()
232245 pos .add_group (EXP_GROUP_NAME , [event1 , event2 ])
233- hypothesis_with_examples [func_name ].positive_examples .add_example (pos )
246+ hypothesis_with_examples [
247+ func_name
248+ ].positive_examples .add_example (pos )
234249 else :
235250 neg = Example ()
236251 neg .add_group (EXP_GROUP_NAME , [event1 , event2 ])
237- hypothesis_with_examples [func_name ].negative_examples .add_example (neg )
238-
252+ hypothesis_with_examples [
253+ func_name
254+ ].negative_examples .add_example (neg )
255+
239256 for PT_pair in records .keys ():
240257 for event1 , event2 in combinations (records [PT_pair ], 2 ):
241- if ( not is_arguments_list_same (event1 ["args" ], event2 ["args" ]) ):
258+ if not is_arguments_list_same (event1 ["args" ], event2 ["args" ]):
242259 flag = True
243260 pos = Example ()
244261 pos .add_group (EXP_GROUP_NAME , [event1 , event2 ])
245- hypothesis_with_examples [func_name ].positive_examples .add_example (pos )
262+ hypothesis_with_examples [
263+ func_name
264+ ].positive_examples .add_example (pos )
246265 else :
247266 neg = Example ()
248267 neg .add_group (EXP_GROUP_NAME , [event1 , event2 ])
249- hypothesis_with_examples [func_name ].negative_examples .add_example (neg )
250-
251- if (not flag ):
268+ hypothesis_with_examples [
269+ func_name
270+ ].negative_examples .add_example (neg )
271+
272+ if not flag :
252273 hypothesis_with_examples .pop (func_name )
253274
254275 print ("End adding examples" )
@@ -261,13 +282,11 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]:
261282 logger .debug (
262283 f"Finding Precondition for { hypo } : { hypothesis_with_examples [hypo ].invariant .text_description } "
263284 )
264- preconditions = find_precondition (hypothesis_with_examples [hypo ], trace )
285+ preconditions = find_precondition (hypothesis_with_examples [hypo ], [ trace ] )
265286 logger .debug (f"Preconditions for { hypo } :\n { str (preconditions )} " )
266287
267288 if preconditions is not None :
268- hypothesis_with_examples [hypo ].invariant .precondition = (
269- preconditions
270- )
289+ hypothesis_with_examples [hypo ].invariant .precondition = preconditions
271290 else :
272291 logger .debug (f"Precondition not found for { hypo } " )
273292 failed_hypothesis .append (
@@ -281,13 +300,10 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]:
281300 print ("End precondition inference" )
282301
283302 return (
284- list (
285- [hypo .invariant for hypo in hypothesis_with_examples .values ()]
286- ),
303+ list ([hypo .invariant for hypo in hypothesis_with_examples .values ()]),
287304 failed_hypothesis ,
288305 )
289306
290-
291307 @staticmethod
292308 def evaluate (value_group : list ) -> bool :
293309 """Given a group of values, should return a boolean value
@@ -317,13 +333,13 @@ def static_check_all(
317333
318334 # 1. Pre-process all the events
319335 print ("Start preprocessing...." )
320- listed_arguments : Dict [str , Dict [int , Dict [Tuple [str , str ], List [dict [str , Any ]]]]] = {}
336+ listed_arguments : Dict [
337+ str , Dict [int , Dict [Tuple [str , str ], List [dict [str , Any ]]]]
338+ ] = {}
321339 function_pool : Set [Any ] = set ()
322- func = inv .params [0 ]
340+ func = inv .params [0 ]
323341
324- assert isinstance (
325- func , APIParam
326- ), "Invariant parameters should be APIParam."
342+ assert isinstance (func , APIParam ), "Invariant parameters should be APIParam."
327343
328344 func_name = func .api_full_name
329345 function_pool .add (func_name )
@@ -352,30 +368,34 @@ def static_check_all(
352368 )
353369
354370 for step , records in listed_arguments [func_name ].items ():
355- for PT_pair1 , PT_pair2 in combinations (records .keys (), 2 ):
356- for event1 in records [PT_pair1 ]:
357- for event2 in records [PT_pair2 ]:
358- if (is_arguments_list_same (event1 ["args" ], event2 ["args" ])):
359- return CheckerResult (
360- trace = [event1 , event2 ],
361- invariant = inv ,
362- check_passed = True ,
363- triggered = True ,
364- )
365-
366- for PT_pair in records .keys ():
367- for event1 , event2 in combinations (records [PT_pair ], 2 ):
368- if (is_arguments_list_same (event1 ["args" ], event2 ["args" ])):
371+ for PT_pair1 , PT_pair2 in combinations (records .keys (), 2 ):
372+ for event1 in records [PT_pair1 ]:
373+ for event2 in records [PT_pair2 ]:
374+ if is_arguments_list_same (event1 ["args" ], event2 ["args" ]):
369375 return CheckerResult (
370376 trace = [event1 , event2 ],
371377 invariant = inv ,
372378 check_passed = True ,
373379 triggered = True ,
374380 )
375381
382+ for PT_pair in records .keys ():
383+ for event1 , event2 in combinations (records [PT_pair ], 2 ):
384+ if is_arguments_list_same (event1 ["args" ], event2 ["args" ]):
385+ return CheckerResult (
386+ trace = [event1 , event2 ],
387+ invariant = inv ,
388+ check_passed = True ,
389+ triggered = True ,
390+ )
391+
376392 return CheckerResult (
377393 trace = None ,
378394 invariant = inv ,
379395 check_passed = True ,
380396 triggered = True ,
381397 )
398+
399+ @staticmethod
400+ def get_precondition_infer_keys_to_skip (hypothesis : Hypothesis ) -> list [str ]:
401+ return []
0 commit comments