Skip to content

Commit 6c89592

Browse files
tobiasgrosserineolarthur-adjedjjavra
authored
Lean: add matchbv support (rems-project#1222)
The only semantic change this PR proposes is to not assert on argument pattern that we cannot yet translate, but instead to record them as `TODO_ARG_PATTERN`. Co-authored-by: Léo Stefanesco <[email protected]> Co-authored-by: arthur-adjedj <[email protected]> Co-authored-by: Jakob von Raumer <[email protected]>
1 parent 501a5a5 commit 6c89592

32 files changed

+436
-410
lines changed

src/lib/rewrites.ml

Lines changed: 105 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -913,92 +913,32 @@ let pats_complete l env ps typ =
913913
PC.is_complete l ctx ps typ
914914

915915
(* Rewrite guarded patterns into a combination of if-expressions and
916-
unguarded pattern matches
917-
918-
Strategy:
919-
- Split clauses into groups where the first pattern subsumes all the
920-
following ones
921-
- Translate the groups in reverse order, using the next group as a
922-
fall-through target, if there is one
923-
- Within a group,
924-
- translate the sequence of clauses to an if-then-else cascade using the
925-
guards as long as the patterns are equivalent modulo substitution, or
926-
- recursively translate the remaining clauses to a pattern match if
927-
there is a difference in the patterns.
916+
unguarded pattern matches
917+
918+
If [fun_only] is [true], do not rewrite bitvector patterns that are not
919+
function parameters. The motivation is that the Lean backend can natively
920+
handle those.
921+
922+
Strategy:
923+
- Split clauses into groups where the first pattern subsumes all the
924+
following ones
925+
- Translate the groups in reverse order, using the next group as a
926+
fall-through target, if there is one
927+
- Within a group,
928+
- translate the sequence of clauses to an if-then-else cascade using the
929+
guards as long as the patterns are equivalent modulo substitution, or
930+
- recursively translate the remaining clauses to a pattern match if
931+
there is a difference in the patterns.
928932
929933
TODO: Compare this more closely with the algorithm in the CPP'18 paper of
930934
Spector-Zabusky et al, who seem to use the opposite grouping and merging
931935
strategy to ours: group *mutually exclusive* clauses, and try to merge them
932936
into a pattern match first instead of an if-then-else cascade.
933937
*)
934-
let rewrite_toplevel_guarded_clauses mk_fallthrough l env pat_typ typ
938+
let rewrite_toplevel_guarded_clauses fun_only mk_fallthrough l env pat_typ typ
935939
(cs : (tannot pat * tannot exp option * tannot exp * tannot clause_annot) list) =
936940
let annot_from_clause (def_annot, tannot) = (def_annot.loc, tannot) in
937941
let fix_fallthrough (pat, guard, exp, (l, tannot)) = (pat, guard, exp, (mk_def_annot l (), tannot)) in
938-
939-
let rec group fallthrough clauses =
940-
let add_clause (pat, cls, annot) c = (pat, cls @ [c], annot) in
941-
let rec group_aux current acc = function
942-
| ((pat, guard, body, annot) as c) :: cs -> (
943-
let current_pat, _, _ = current in
944-
match subsumes_pat current_pat pat with
945-
| Some substs ->
946-
let pat' = List.fold_left subst_id_pat pat substs in
947-
let guard' =
948-
match guard with Some exp -> Some (List.fold_left subst_id_exp exp substs) | None -> None
949-
in
950-
let body' = List.fold_left subst_id_exp body substs in
951-
let c' = (pat', guard', body', annot) in
952-
group_aux (add_clause current c') acc cs
953-
| None ->
954-
let pat = match cs with _ :: _ -> remove_wildcards "g__" pat | _ -> pat in
955-
group_aux (pat, [c], annot_from_clause annot) (acc @ [current]) cs
956-
)
957-
| [] -> acc @ [current]
958-
in
959-
let groups =
960-
match clauses with
961-
| [((pat, guard, body, annot) as c)] -> [(pat, [c], annot_from_clause annot)]
962-
| ((pat, guard, body, annot) as c) :: cs ->
963-
group_aux (remove_wildcards "g__" pat, [c], annot_from_clause annot) [] cs
964-
| _ -> raise (Reporting.err_unreachable l __POS__ "group given empty list in rewrite_guarded_clauses")
965-
in
966-
let add_group cs groups = if_pexp (groups @ fallthrough) cs :: groups in
967-
List.fold_right add_group groups []
968-
and if_pexp fallthrough (pat, cs, annot) =
969-
match cs with
970-
| c :: _ ->
971-
let body = if_exp fallthrough pat cs in
972-
(pat, body, annot)
973-
| [] -> raise (Reporting.err_unreachable l __POS__ "if_pexp given empty list in rewrite_guarded_clauses")
974-
and if_exp fallthrough current_pat = function
975-
| (pat, guard, body, annot) :: ((pat', guard', body', annot') as c') :: cs -> (
976-
match guard with
977-
| Some exp ->
978-
let env = env_of exp in
979-
let else_exp =
980-
if equiv_pats current_pat pat' then if_exp fallthrough current_pat (c' :: cs)
981-
else case_exp (pat_to_exp env current_pat) (typ_of body') (group fallthrough (c' :: cs))
982-
in
983-
annot_exp (E_if (exp, body, else_exp)) (fst annot).loc env (typ_of body)
984-
| None -> body
985-
)
986-
| [(pat, guard, body, annot)] -> (
987-
(* For singleton clauses with a guard, use fallthrough clauses if the
988-
guard is not satisfied, but only those fallthrough clauses that are
989-
not disjoint with the current pattern *)
990-
let overlapping_clause (pat, _, _) = not (disjoint_pat env current_pat pat) in
991-
let fallthrough = List.filter overlapping_clause fallthrough in
992-
match (guard, fallthrough) with
993-
| Some exp, _ :: _ ->
994-
let env = env_of exp in
995-
let else_exp = case_exp (pat_to_exp env current_pat) (typ_of body) fallthrough in
996-
annot_exp (E_if (exp, body, else_exp)) (fst annot).loc env (typ_of body)
997-
| _, _ -> body
998-
)
999-
| [] -> raise (Reporting.err_unreachable l __POS__ "if_exp given empty list in rewrite_guarded_clauses")
1000-
in
1001-
1002942
let is_complete =
1003943
pats_complete l env
1004944
(List.map (fun (pat, guard, body, cl_annot) -> construct_pexp (pat, guard, body, annot_from_clause cl_annot)) cs)
@@ -1007,13 +947,79 @@ let rewrite_toplevel_guarded_clauses mk_fallthrough l env pat_typ typ
1007947
let fallthrough =
1008948
if not is_complete then [fix_fallthrough (destruct_pexp (mk_fallthrough l env pat_typ typ))] else []
1009949
in
1010-
group [] (cs @ fallthrough)
950+
if fun_only && is_bitvector_typ pat_typ then
951+
List.map (fun (pat, guard, body, annot) -> (pat, guard, body, annot_from_clause annot)) (cs @ fallthrough)
952+
else (
953+
let rec group fallthrough clauses =
954+
let add_clause (pat, cls, annot) c = (pat, cls @ [c], annot) in
955+
let rec group_aux current acc = function
956+
| ((pat, guard, body, annot) as c) :: cs -> (
957+
let current_pat, _, _ = current in
958+
match subsumes_pat current_pat pat with
959+
| Some substs ->
960+
let pat' = List.fold_left subst_id_pat pat substs in
961+
let guard' =
962+
match guard with Some exp -> Some (List.fold_left subst_id_exp exp substs) | None -> None
963+
in
964+
let body' = List.fold_left subst_id_exp body substs in
965+
let c' = (pat', guard', body', annot) in
966+
group_aux (add_clause current c') acc cs
967+
| None ->
968+
let pat = match cs with _ :: _ -> remove_wildcards "g__" pat | _ -> pat in
969+
group_aux (pat, [c], annot_from_clause annot) (acc @ [current]) cs
970+
)
971+
| [] -> acc @ [current]
972+
in
973+
let groups =
974+
match clauses with
975+
| [((pat, guard, body, annot) as c)] -> [(pat, [c], annot_from_clause annot)]
976+
| ((pat, guard, body, annot) as c) :: cs ->
977+
group_aux (remove_wildcards "g__" pat, [c], annot_from_clause annot) [] cs
978+
| _ -> raise (Reporting.err_unreachable l __POS__ "group given empty list in rewrite_guarded_clauses")
979+
in
980+
let add_group cs groups = if_pexp (groups @ fallthrough) cs :: groups in
981+
List.fold_right add_group groups []
982+
and if_pexp fallthrough (pat, cs, annot) =
983+
match cs with
984+
| c :: _ ->
985+
let body = if_exp fallthrough pat cs in
986+
(pat, body, annot)
987+
| [] -> raise (Reporting.err_unreachable l __POS__ "if_pexp given empty list in rewrite_guarded_clauses")
988+
and if_exp fallthrough current_pat = function
989+
| (pat, guard, body, annot) :: ((pat', _, body', _) as c') :: cs -> (
990+
match guard with
991+
| Some exp ->
992+
let env = env_of exp in
993+
let else_exp =
994+
if equiv_pats current_pat pat' then if_exp fallthrough current_pat (c' :: cs)
995+
else case_exp (pat_to_exp env current_pat) (typ_of body') (group fallthrough (c' :: cs))
996+
in
997+
annot_exp (E_if (exp, body, else_exp)) (fst annot).loc env (typ_of body)
998+
| None -> body
999+
)
1000+
| [(pat, guard, body, annot)] -> (
1001+
(* For singleton clauses with a guard, use fallthrough clauses if the
1002+
guard is not satisfied, but only those fallthrough clauses that are
1003+
not disjoint with the current pattern *)
1004+
let overlapping_clause (pat, _, _) = not (disjoint_pat env current_pat pat) in
1005+
let fallthrough = List.filter overlapping_clause fallthrough in
1006+
match (guard, fallthrough) with
1007+
| Some exp, _ :: _ ->
1008+
let env = env_of exp in
1009+
let else_exp = case_exp (pat_to_exp env current_pat) (typ_of body) fallthrough in
1010+
annot_exp (E_if (exp, body, else_exp)) (fst annot).loc env (typ_of body)
1011+
| _, _ -> body
1012+
)
1013+
| [] -> raise (Reporting.err_unreachable l __POS__ "if_exp given empty list in rewrite_guarded_clauses")
1014+
in
1015+
List.map (fun (pat, exp, annot) -> (pat, None, exp, annot)) (group [] (cs @ fallthrough))
1016+
)
10111017

1012-
let rewrite_guarded_clauses mk_fallthrough l env pat_typ typ
1018+
let rewrite_guarded_clauses fun_only mk_fallthrough l env pat_typ typ
10131019
(cs : (tannot pat * tannot exp option * tannot exp * tannot annot) list) =
10141020
let map_clause_annot f cs = List.map (fun (pat, guard, body, annot) -> (pat, guard, body, f annot)) cs in
10151021
let cs = map_clause_annot (fun (l, tannot) -> (mk_def_annot l (), tannot)) cs in
1016-
rewrite_toplevel_guarded_clauses mk_fallthrough l env pat_typ typ cs
1022+
rewrite_toplevel_guarded_clauses fun_only mk_fallthrough l env pat_typ typ cs
10171023

10181024
let mk_pattern_match_failure_pexp l env pat_typ typ =
10191025
let p = P_aux (P_wild, (gen_loc l, mk_tannot env pat_typ)) in
@@ -1429,7 +1435,7 @@ let rewrite_bit_lists_to_lits env =
14291435

14301436
(* Remove pattern guards by rewriting them to if-expressions within the
14311437
pattern expression. *)
1432-
let rewrite_exp_guarded_pats rewriters (E_aux (exp, (l, annot)) as full_exp) =
1438+
let rewrite_exp_guarded_pats fun_only rewriters (E_aux (exp, (l, annot)) as full_exp) =
14331439
let rewrap e = E_aux (e, (l, annot)) in
14341440
let rewrite_rec = rewriters.rewrite_exp rewriters in
14351441
let rewrite_base = rewrite_exp rewriters in
@@ -1448,8 +1454,11 @@ let rewrite_exp_guarded_pats rewriters (E_aux (exp, (l, annot)) as full_exp) =
14481454
| Pat_aux (Pat_when (pat, guard, body), annot) -> (pat, Some (rewrite_rec guard), rewrite_rec body, annot)
14491455
in
14501456
let clauses =
1451-
rewrite_guarded_clauses mk_pattern_match_failure_pexp l (env_of full_exp) (typ_of e) (typ_of full_exp)
1452-
(List.map clause ps)
1457+
List.map
1458+
(fun (pat, _, body, annot) -> (pat, body, annot))
1459+
(rewrite_guarded_clauses fun_only mk_pattern_match_failure_pexp l (env_of full_exp) (typ_of e)
1460+
(typ_of full_exp) (List.map clause ps)
1461+
)
14531462
in
14541463
let e = rewrite_rec e in
14551464
if effectful e then (
@@ -1468,14 +1477,15 @@ let rewrite_exp_guarded_pats rewriters (E_aux (exp, (l, annot)) as full_exp) =
14681477
| Pat_aux (Pat_when (pat, guard, body), annot) -> (pat, Some (rewrite_rec guard), rewrite_rec body, annot)
14691478
in
14701479
let clauses =
1471-
rewrite_guarded_clauses mk_rethrow_pexp l (env_of full_exp) exc_typ (typ_of full_exp) (List.map clause ps)
1480+
rewrite_guarded_clauses fun_only mk_rethrow_pexp l (env_of full_exp) exc_typ (typ_of full_exp)
1481+
(List.map clause ps)
14721482
in
1473-
let pexp (pat, body, annot) = Pat_aux (Pat_exp (pat, body), annot) in
1483+
let pexp (pat, _, body, annot) = Pat_aux (Pat_exp (pat, body), annot) in
14741484
let ps = List.map pexp clauses in
14751485
annot_exp (E_try (e, ps)) l (env_of full_exp) (typ_of full_exp)
14761486
| _ -> rewrite_base full_exp
14771487

1478-
let rewrite_fun_guarded_pats rewriters (FD_aux (FD_function (r, t, funcls), (l, fdannot))) =
1488+
let rewrite_fun_guarded_pats fun_only rewriters (FD_aux (FD_function (r, t, funcls), (l, fdannot))) =
14791489
let funcls =
14801490
match funcls with
14811491
| FCL_aux (FCL_funcl (id, pexp), fcl_annot) :: _ ->
@@ -1496,14 +1506,14 @@ let rewrite_fun_guarded_pats rewriters (FD_aux (FD_function (r, t, funcls), (l,
14961506
| exception _ -> (pexp_pat_typ, pexp_ret_typ)
14971507
in
14981508
let cs =
1499-
rewrite_toplevel_guarded_clauses mk_pattern_match_failure_pexp l
1509+
rewrite_toplevel_guarded_clauses fun_only mk_pattern_match_failure_pexp l
15001510
(env_of_tannot (snd fcl_annot))
15011511
pat_typ ret_typ (List.map clause funcls)
15021512
in
15031513
List.map
1504-
(fun (pat, exp, annot) ->
1514+
(fun (pat, guard, exp, annot) ->
15051515
FCL_aux
1506-
( FCL_funcl (id, construct_pexp (pat, None, exp, (Parse_ast.Unknown, empty_tannot))),
1516+
( FCL_funcl (id, construct_pexp (pat, guard, exp, (Parse_ast.Unknown, empty_tannot))),
15071517
(mk_def_annot (fst annot) (), snd annot)
15081518
)
15091519
)
@@ -1514,7 +1524,11 @@ let rewrite_fun_guarded_pats rewriters (FD_aux (FD_function (r, t, funcls), (l,
15141524

15151525
let rewrite_ast_guarded_pats env =
15161526
rewrite_ast_base
1517-
{ rewriters_base with rewrite_exp = rewrite_exp_guarded_pats; rewrite_fun = rewrite_fun_guarded_pats }
1527+
{ rewriters_base with rewrite_exp = rewrite_exp_guarded_pats false; rewrite_fun = rewrite_fun_guarded_pats false }
1528+
1529+
let rewrite_ast_fun_guarded_pats env =
1530+
rewrite_ast_base
1531+
{ rewriters_base with rewrite_exp = rewrite_exp_guarded_pats true; rewrite_fun = rewrite_fun_guarded_pats true }
15181532

15191533
let rec rewrite_lexp_to_rhs (LE_aux (lexp, ((l, _) as annot)) as le) =
15201534
match lexp with
@@ -4737,6 +4751,7 @@ let all_rewriters =
47374751
("remove_bitvector_pats", basic_rewriter rewrite_ast_remove_bitvector_pats);
47384752
("remove_numeral_pats", basic_rewriter rewrite_ast_remove_numeral_pats);
47394753
("guarded_pats", basic_rewriter rewrite_ast_guarded_pats);
4754+
("fun_guarded_pats", basic_rewriter rewrite_ast_fun_guarded_pats);
47404755
("bit_lists_to_lits", basic_rewriter rewrite_bit_lists_to_lits);
47414756
("exp_lift_assign", basic_rewriter rewrite_ast_exp_lift_assign);
47424757
("early_return", base_rewriter rewrite_ast_early_return);

0 commit comments

Comments
 (0)