diff --git a/examples/imp/imp.k b/examples/imp/imp.k index 092b45e37b..f5b2bee615 100644 --- a/examples/imp/imp.k +++ b/examples/imp/imp.k @@ -53,6 +53,6 @@ module IMP rule while (B) S => if (B) {S while (B) S} else {} [structural] rule int (X,Xs => Xs);_ Rho:Map (.Map => X|->0) - // requires notBool (X in keys(Rho)) + requires notBool (X in keys(Rho)) rule int .Ids; S => S [structural] -endmodule \ No newline at end of file +endmodule diff --git a/ml/rewrite/rewrite.py b/ml/rewrite/rewrite.py index 96eddf4f00..a5076a457b 100644 --- a/ml/rewrite/rewrite.py +++ b/ml/rewrite/rewrite.py @@ -42,6 +42,8 @@ def __init__(self, env: KoreComposer): "LblnotBool'Unds'": BooleanNotEvaluator(env), "Lbl'UndsEqlsEqls'K'Unds'": KEqualityEvaluator(env), "Lbl'UndsEqlsSlshEqls'K'Unds'": KNotEqualityEvaluator(env), + "Lblkeys'LParUndsRParUnds'MAP'Unds'Set'Unds'Map": MapKeysEvaluator(env), + "LblSet'Coln'in": SetInEvaluator(env), "LblMap'Coln'lookup": MapLookupEvaluator(env), } self.disjoint_gen = DisjointnessProofGenerator(env) @@ -100,8 +102,7 @@ def rewrite_from_pattern( unification_gen = UnificationProofGenerator(self.composer) if rule_hint is not None: - assert rule_hint in self.composer.rewrite_axioms, \ - f"unable to find axiom with id {rule_hint} in the hint" + assert (rule_hint in self.composer.rewrite_axioms), f"unable to find axiom with id {rule_hint} in the hint" axioms = [self.composer.rewrite_axioms[rule_hint]] else: axioms = list(self.composer.rewrite_axioms.values()) @@ -194,8 +195,9 @@ def prove_rewriting_step( self.composer.encode_pattern(initial_pattern) - assert len(rewriting_step.applied_rules) == 1 and len(rewriting_step.remainders) == 0, \ - "non-determinism not supported" + assert ( + len(rewriting_step.applied_rules) == 1 and len(rewriting_step.remainders) == 0 + ), "non-determinism not supported" rule = rewriting_step.applied_rules[0] @@ -220,11 +222,11 @@ def check_equal_or_unify(self, given: kore.Pattern, expected: kore.Pattern) -> O return None unification_result = UnificationProofGenerator(self.composer).unify_patterns(expected, given) - assert unification_result is not None, \ - f"expecting the following patterns to be equal or unifiable: {given} and {expected}" + assert ( + unification_result is not None + ), f"expecting the following patterns to be equal or unifiable: {given} and {expected}" - assert len(unification_result.substitution) == 0, \ - "patterns should be concrete" + assert len(unification_result.substitution) == 0, "patterns should be concrete" simplification_claim = self.apply_reflexivity(expected) @@ -278,8 +280,9 @@ def prove_rewriting_task( step_claim = self.prove_rewriting_step(step) lhs, rhs = self.decompose_concrete_rewrite_claim(step_claim) - assert step_initial == lhs, \ - f"unexpected rewriting claim, expected to rewrite from {step_initial}, but got {lhs}" + assert ( + step_initial == lhs + ), f"unexpected rewriting claim, expected to rewrite from {step_initial}, but got {lhs}" self.composer.load_comment(f"\nrewriting step:\n{lhs}\n=>\n{rhs}\n") step_claim = self.composer.load_provable_claim_as_theorem( @@ -438,21 +441,20 @@ def prove_negation_requires(self, pattern: kore.Pattern) -> Optional[Proof]: # the following chunk of nonsense basically # checks that body is of the form # top /\ ( /\ top) - if isinstance(body, kore.MLPattern) and \ - body.construct == kore.MLPattern.AND and \ - isinstance(body.arguments[0], kore.MLPattern) and \ - body.arguments[0].construct == kore.MLPattern.TOP and \ - isinstance(body.arguments[1], kore.MLPattern) and \ - body.arguments[1].construct == kore.MLPattern.AND and \ - isinstance(body.arguments[1].arguments[1], kore.MLPattern) and \ - body.arguments[1].arguments[1].construct == kore.MLPattern.TOP: + if (isinstance(body, kore.MLPattern) and body.construct == kore.MLPattern.AND + and isinstance(body.arguments[0], kore.MLPattern) + and body.arguments[0].construct == kore.MLPattern.TOP + and isinstance(body.arguments[1], kore.MLPattern) + and body.arguments[1].construct == kore.MLPattern.AND + and isinstance(body.arguments[1].arguments[1], kore.MLPattern) + and body.arguments[1].arguments[1].construct == kore.MLPattern.TOP): inner_condition = body.arguments[1].arguments[0] if isinstance(inner_condition, kore.MLPattern): - if inner_condition.construct == kore.MLPattern.CEIL and \ - isinstance(inner_condition.arguments[0], kore.MLPattern) and \ - inner_condition.arguments[0].construct == kore.MLPattern.AND: + if (inner_condition.construct == kore.MLPattern.CEIL + and isinstance(inner_condition.arguments[0], kore.MLPattern) + and inner_condition.arguments[0].construct == kore.MLPattern.AND): # ( \ceil ( \and ) ) lemma = "owise-var-1-cond-0" left, right = inner_condition.arguments[0].arguments @@ -461,10 +463,13 @@ def prove_negation_requires(self, pattern: kore.Pattern) -> Optional[Proof]: lemma = "owise-var-1-cond-0-alt" left, right = inner_condition.arguments else: - assert False, f"expecting disjointness condition, got {inner_condition}" + assert (False), f"expecting disjointness condition, got {inner_condition}" output_sort = KoreUtils.infer_sort(pattern) - claim = kore.Claim([], kore.MLPattern(kore.MLPattern.NOT, [output_sort], [pattern])) + claim = kore.Claim( + [], + kore.MLPattern(kore.MLPattern.NOT, [output_sort], [pattern]), + ) claim.resolve(self.composer.module) print("> proving disjointness claim") @@ -538,8 +543,8 @@ def resolve_unification_obligations_in_requires_clause(self, requires: kore.Patt if requires.construct == kore.MLPattern.AND: # resolve each side separately and combine left, right = requires.arguments - left_unification = self.resolve_unification_obligations_in_requires_clause(left) - right_unification = self.resolve_unification_obligations_in_requires_clause(right) + left_unification = (self.resolve_unification_obligations_in_requires_clause(left)) + right_unification = (self.resolve_unification_obligations_in_requires_clause(right)) if left_unification is None or right_unification is None: return None @@ -573,7 +578,7 @@ def match_and_instantiate_anywhere_axiom(self, axiom: ProvableClaim, # there could be more unification conditions in the require clause requires_substituted = KoreUtils.copy_and_substitute_pattern(requires, unification_result.substitution) - side_unification_result = self.resolve_unification_obligations_in_requires_clause(requires_substituted) + side_unification_result = (self.resolve_unification_obligations_in_requires_clause(requires_substituted)) if side_unification_result is None: return None @@ -589,8 +594,8 @@ def match_and_instantiate_anywhere_axiom(self, axiom: ProvableClaim, return None # eliminate all universal quantifiers - instantiated_axiom = QuantifierProofGenerator(self.composer) \ - .prove_forall_elim(axiom, unification_result.substitution) + instantiated_axiom = QuantifierProofGenerator(self.composer + ).prove_forall_elim(axiom, unification_result.substitution) # apply equations used in unification for equation, path in unification_result.applied_equations: @@ -609,7 +614,7 @@ def match_and_instantiate_anywhere_axiom(self, axiom: ProvableClaim, if requires_proof is None: return None - sort_param, = instantiated_axiom.claim.sort_variables + (sort_param, ) = instantiated_axiom.claim.sort_variables encoded_sort_param = self.composer.encode_pattern(sort_param) removed_requires = self.composer.get_theorem("kore-mp-v1").apply( @@ -808,6 +813,9 @@ class InnermostFunctionPathVisitor(KoreVisitor[Union[kore.Pattern, kore.Axiom], "Lbl'UndsPipe'-'-GT-Unds'", "Lbl'Unds'Map'Unds'", "Lbl'Stop'Map", + "Lbl'Stop'Set", + "LblSetItem", + "Lbl'Unds'Set'Unds", } def postvisit_variable(self, variable: kore.Variable) -> Optional[PatternPath]: @@ -867,8 +875,7 @@ def build_equation(self, application: kore.Application, result: kore.Pattern) -> sort_var = kore.SortVariable("R") output_sort = KoreUtils.infer_sort(application) - assert output_sort == KoreUtils.infer_sort(result), \ - f"result {result} has a different sort than {application}" + assert output_sort == KoreUtils.infer_sort(result), f"result {result} has a different sort than {application}" claim = kore.Claim( [sort_var], @@ -1013,6 +1020,71 @@ def prove_evaluation(self, application: kore.Application) -> ProvableClaim: assert isinstance(map_pattern, kore.Application) found = self.lookup(map_pattern, key_pattern) - assert found is not None, f"key {key_pattern} does not exist in map pattern {map_pattern}" + assert (found is not None), f"key {key_pattern} does not exist in map pattern {map_pattern}" return self.build_equation(application, found) + + +class SetInEvaluator(BuiltinFunctionEvaluator): + def is_element(self, app: kore.Application, element: kore.Pattern) -> bool: + + if KoreTemplates.is_set_unit_pattern(app): + return False + elif KoreTemplates.is_set_merge_pattern(app): + + left, right = app.arguments + assert isinstance(left, kore.Application) + assert isinstance(right, kore.Application) + + element_in_left = self.is_element(left, element) + element_in_right = self.is_element(right, element) + + return element_in_left or element_in_right + elif KoreTemplates.is_set_singleton_pattern(app): + found_element = app.arguments + + return element == found_element + + return False + + def prove_evaluation(self, application: kore.Application) -> ProvableClaim: + key, set_pattern = application.arguments + + assert isinstance(set_pattern, kore.Application) + return self.build_arithmetic_equation(application, self.is_element(set_pattern, key)) + + +class MapKeysEvaluator(BuiltinFunctionEvaluator): + def get_keys(self, app: kore.Application) -> kore.Pattern: + if KoreTemplates.is_map_unit_pattern(app): + unit_symbol = self.composer.module.get_symbol_by_name("Lbl'Stop'Set") + assert unit_symbol is not None + unit_symbol_instance = kore.SymbolInstance(unit_symbol, []) + return kore.Application(unit_symbol_instance, []) + elif KoreTemplates.is_map_merge_pattern(app): + + left, right = app.arguments + merge_symbol = self.composer.module.get_symbol_by_name("Lbl'Unds'Set'Unds") + assert isinstance(left, kore.Application) + assert isinstance(right, kore.Application) + assert merge_symbol is not None + + left_keys = self.get_keys(left) + right_keys = self.get_keys(right) + + merge_symbol_instance = kore.SymbolInstance(merge_symbol, []) + return kore.Application(merge_symbol_instance, [left_keys, right_keys]) + else: + assert KoreTemplates.is_map_mapsto_pattern(app) + key, value = app.arguments + + singleton_symbol = self.composer.module.get_symbol_by_name("LblSetItem") + assert singleton_symbol is not None + singleton_symbol_instance = kore.SymbolInstance(singleton_symbol, []) + return kore.Application(singleton_symbol_instance, [key]) + + def prove_evaluation(self, application: kore.Application) -> ProvableClaim: + map_pattern = application.arguments[0] + + assert isinstance(map_pattern, kore.Application) + return self.build_equation(application, self.get_keys(map_pattern)) diff --git a/ml/rewrite/templates.py b/ml/rewrite/templates.py index 3893456dbf..7907fdaaf7 100644 --- a/ml/rewrite/templates.py +++ b/ml/rewrite/templates.py @@ -78,7 +78,7 @@ def get_symbol_of_equational_axiom(axiom: kore.Axiom, ) -> Optional[kore.SymbolI return eqn_lhs.symbol @staticmethod - def get_sorts_of_subsort_axiom(axiom: kore.Axiom) -> Optional[Tuple[kore.SortInstance, kore.SortInstance]]: + def get_sorts_of_subsort_axiom(axiom: kore.Axiom, ) -> Optional[Tuple[kore.SortInstance, kore.SortInstance]]: attribute = axiom.get_attribute_by_symbol("subsort") if attribute is None: return None @@ -104,7 +104,7 @@ def get_axiom_unique_id(axiom: kore.Axiom) -> Optional[str]: return id_term.arguments[0].content @staticmethod - def get_symbol_of_functional_axiom(axiom: kore.Axiom) -> Optional[kore.SymbolInstance]: + def get_symbol_of_functional_axiom(axiom: kore.Axiom, ) -> Optional[kore.SymbolInstance]: """ Get the corresponding symbol instance of the given functional axiom """ @@ -134,7 +134,7 @@ def get_symbol_of_functional_axiom(axiom: kore.Axiom) -> Optional[kore.SymbolIns return rhs.symbol @staticmethod - def get_sort_symbol_of_no_junk_axiom(axiom: kore.Axiom) -> Optional[kore.SortInstance]: + def get_sort_symbol_of_no_junk_axiom(axiom: kore.Axiom, ) -> Optional[kore.SortInstance]: """ A no junk axiom should be a disjunction of existential patterns """ @@ -156,7 +156,7 @@ def get_sort_symbol_of_no_junk_axiom(axiom: kore.Axiom) -> Optional[kore.SortIns return sort @staticmethod - def get_symbol_for_no_confusion_same_constructor_axiom(axiom: kore.Axiom) -> Optional[kore.SymbolInstance]: + def get_symbol_for_no_confusion_same_constructor_axiom(axiom: kore.Axiom, ) -> Optional[kore.SymbolInstance]: r""" Axiom of the form f(ph1, ..., phn) /\ f(ph1', ..., phn') => f(ph1 /\ ph1', ..., phn /\ phn') @@ -190,6 +190,31 @@ def get_symbols_for_no_confusion_different_constructor_axiom( return left.symbol, right.symbol + ################################### + # Utils methods for set patterns. # + ################################### + + @staticmethod + def is_set_merge_pattern(pattern: kore.Pattern) -> bool: + return ( + isinstance(pattern, kore.Application) and pattern.symbol.get_symbol_name() == "Lbl'Unds'Set'Unds'" + and len(pattern.arguments) == 2 + ) + + @staticmethod + def is_set_singleton_pattern(pattern: kore.Pattern) -> bool: + return ( + isinstance(pattern, kore.Application) and pattern.symbol.get_symbol_name() == "LblSetItem" + and len(pattern.arguments) == 1 + ) + + @staticmethod + def is_set_unit_pattern(pattern: kore.Pattern) -> bool: + return ( + isinstance(pattern, kore.Application) and pattern.symbol.get_symbol_name() == "Lbl'Stop'Set" + and len(pattern.arguments) == 0 + ) + ################################### # Utils methods for map patterns. # ################################### @@ -222,13 +247,13 @@ def is_map_pattern(pattern: kore.Pattern) -> bool: phi ::= merge(phi1, phi2) | phi1 |-> phi2 | .map """ - if KoreTemplates.is_map_unit_pattern(pattern) or \ - KoreTemplates.is_map_mapsto_pattern(pattern): + if KoreTemplates.is_map_unit_pattern(pattern) or KoreTemplates.is_map_mapsto_pattern(pattern): return True if KoreTemplates.is_map_merge_pattern(pattern): - return KoreTemplates.is_map_pattern(KoreTemplates.get_map_merge_left(pattern)) and \ - KoreTemplates.is_map_pattern(KoreTemplates.get_map_merge_right(pattern)) + return KoreTemplates.is_map_pattern( + KoreTemplates.get_map_merge_left(pattern) + ) and KoreTemplates.is_map_pattern(KoreTemplates.get_map_merge_right(pattern)) return False @@ -248,8 +273,10 @@ def get_map_merge_right(pattern: kore.Pattern) -> kore.Pattern: def in_place_swap_map_merge_pattern(pattern: kore.Pattern) -> None: assert KoreTemplates.is_map_merge_pattern(pattern) assert isinstance(pattern, kore.Application) - pattern.arguments[0], pattern.arguments[1] = \ - pattern.arguments[1], pattern.arguments[0] + pattern.arguments[0], pattern.arguments[1] = ( + pattern.arguments[1], + pattern.arguments[0], + ) @staticmethod def in_place_rotate_right_map_merge_pattern(pattern: kore.Pattern) -> None: @@ -272,7 +299,7 @@ def in_place_rotate_right_map_merge_pattern(pattern: kore.Pattern) -> None: left.arguments[1] = right @staticmethod - def get_path_to_smallest_key_in_map_pattern(pattern: kore.Pattern) -> Tuple[kore.Pattern, PatternPath]: + def get_path_to_smallest_key_in_map_pattern(pattern: kore.Pattern, ) -> Tuple[kore.Pattern, PatternPath]: r""" Return the path to the pattern with the smallest key. 0 means "left branch" and 1 means "right branch". @@ -281,12 +308,10 @@ def get_path_to_smallest_key_in_map_pattern(pattern: kore.Pattern) -> Tuple[kore assert KoreTemplates.is_map_pattern(pattern) assert isinstance(pattern, kore.Application) - if KoreTemplates.is_map_mapsto_pattern(pattern) or \ - KoreTemplates.is_map_unit_pattern(pattern): + if KoreTemplates.is_map_mapsto_pattern(pattern) or KoreTemplates.is_map_unit_pattern(pattern): return pattern, [] - assert KoreTemplates.is_map_merge_pattern(pattern), \ - f"expecting a map merge, got {pattern}" + assert KoreTemplates.is_map_merge_pattern(pattern), f"expecting a map merge, got {pattern}" lhs = KoreTemplates.get_map_merge_left(pattern) rhs = KoreTemplates.get_map_merge_right(pattern)