@@ -913,92 +913,32 @@ let pats_complete l env ps typ =
913
913
PC. is_complete l ctx ps typ
914
914
915
915
(* Rewrite guarded patterns into a combination of if-expressions and
916
- unguarded pattern matches
917
-
918
- Strategy:
919
- - Split clauses into groups where the first pattern subsumes all the
920
- following ones
921
- - Translate the groups in reverse order, using the next group as a
922
- fall-through target, if there is one
923
- - Within a group,
924
- - translate the sequence of clauses to an if-then-else cascade using the
925
- guards as long as the patterns are equivalent modulo substitution, or
926
- - recursively translate the remaining clauses to a pattern match if
927
- there is a difference in the patterns.
916
+ unguarded pattern matches
917
+
918
+ If [fun_only] is [true], do not rewrite bitvector patterns that are not
919
+ function parameters. The motivation is that the Lean backend can natively
920
+ handle those.
921
+
922
+ Strategy:
923
+ - Split clauses into groups where the first pattern subsumes all the
924
+ following ones
925
+ - Translate the groups in reverse order, using the next group as a
926
+ fall-through target, if there is one
927
+ - Within a group,
928
+ - translate the sequence of clauses to an if-then-else cascade using the
929
+ guards as long as the patterns are equivalent modulo substitution, or
930
+ - recursively translate the remaining clauses to a pattern match if
931
+ there is a difference in the patterns.
928
932
929
933
TODO: Compare this more closely with the algorithm in the CPP'18 paper of
930
934
Spector-Zabusky et al, who seem to use the opposite grouping and merging
931
935
strategy to ours: group *mutually exclusive* clauses, and try to merge them
932
936
into a pattern match first instead of an if-then-else cascade.
933
937
*)
934
- let rewrite_toplevel_guarded_clauses mk_fallthrough l env pat_typ typ
938
+ let rewrite_toplevel_guarded_clauses fun_only mk_fallthrough l env pat_typ typ
935
939
(cs : (tannot pat * tannot exp option * tannot exp * tannot clause_annot) list ) =
936
940
let annot_from_clause (def_annot , tannot ) = (def_annot.loc, tannot) in
937
941
let fix_fallthrough (pat , guard , exp , (l , tannot )) = (pat, guard, exp, (mk_def_annot l () , tannot)) in
938
-
939
- let rec group fallthrough clauses =
940
- let add_clause (pat , cls , annot ) c = (pat, cls @ [c], annot) in
941
- let rec group_aux current acc = function
942
- | ((pat , guard , body , annot ) as c ) :: cs -> (
943
- let current_pat, _, _ = current in
944
- match subsumes_pat current_pat pat with
945
- | Some substs ->
946
- let pat' = List. fold_left subst_id_pat pat substs in
947
- let guard' =
948
- match guard with Some exp -> Some (List. fold_left subst_id_exp exp substs) | None -> None
949
- in
950
- let body' = List. fold_left subst_id_exp body substs in
951
- let c' = (pat', guard', body', annot) in
952
- group_aux (add_clause current c') acc cs
953
- | None ->
954
- let pat = match cs with _ :: _ -> remove_wildcards " g__" pat | _ -> pat in
955
- group_aux (pat, [c], annot_from_clause annot) (acc @ [current]) cs
956
- )
957
- | [] -> acc @ [current]
958
- in
959
- let groups =
960
- match clauses with
961
- | [((pat, guard, body, annot) as c)] -> [(pat, [c], annot_from_clause annot)]
962
- | ((pat , guard , body , annot ) as c ) :: cs ->
963
- group_aux (remove_wildcards " g__" pat, [c], annot_from_clause annot) [] cs
964
- | _ -> raise (Reporting. err_unreachable l __POS__ " group given empty list in rewrite_guarded_clauses" )
965
- in
966
- let add_group cs groups = if_pexp (groups @ fallthrough) cs :: groups in
967
- List. fold_right add_group groups []
968
- and if_pexp fallthrough (pat , cs , annot ) =
969
- match cs with
970
- | c :: _ ->
971
- let body = if_exp fallthrough pat cs in
972
- (pat, body, annot)
973
- | [] -> raise (Reporting. err_unreachable l __POS__ " if_pexp given empty list in rewrite_guarded_clauses" )
974
- and if_exp fallthrough current_pat = function
975
- | (pat , guard , body , annot ) :: ((pat' , guard' , body' , annot' ) as c' ) :: cs -> (
976
- match guard with
977
- | Some exp ->
978
- let env = env_of exp in
979
- let else_exp =
980
- if equiv_pats current_pat pat' then if_exp fallthrough current_pat (c' :: cs)
981
- else case_exp (pat_to_exp env current_pat) (typ_of body') (group fallthrough (c' :: cs))
982
- in
983
- annot_exp (E_if (exp, body, else_exp)) (fst annot).loc env (typ_of body)
984
- | None -> body
985
- )
986
- | [(pat, guard, body, annot)] -> (
987
- (* For singleton clauses with a guard, use fallthrough clauses if the
988
- guard is not satisfied, but only those fallthrough clauses that are
989
- not disjoint with the current pattern *)
990
- let overlapping_clause (pat , _ , _ ) = not (disjoint_pat env current_pat pat) in
991
- let fallthrough = List. filter overlapping_clause fallthrough in
992
- match (guard, fallthrough) with
993
- | Some exp , _ :: _ ->
994
- let env = env_of exp in
995
- let else_exp = case_exp (pat_to_exp env current_pat) (typ_of body) fallthrough in
996
- annot_exp (E_if (exp, body, else_exp)) (fst annot).loc env (typ_of body)
997
- | _ , _ -> body
998
- )
999
- | [] -> raise (Reporting. err_unreachable l __POS__ " if_exp given empty list in rewrite_guarded_clauses" )
1000
- in
1001
-
1002
942
let is_complete =
1003
943
pats_complete l env
1004
944
(List. map (fun (pat , guard , body , cl_annot ) -> construct_pexp (pat, guard, body, annot_from_clause cl_annot)) cs)
@@ -1007,13 +947,79 @@ let rewrite_toplevel_guarded_clauses mk_fallthrough l env pat_typ typ
1007
947
let fallthrough =
1008
948
if not is_complete then [fix_fallthrough (destruct_pexp (mk_fallthrough l env pat_typ typ))] else []
1009
949
in
1010
- group [] (cs @ fallthrough)
950
+ if fun_only && is_bitvector_typ pat_typ then
951
+ List. map (fun (pat , guard , body , annot ) -> (pat, guard, body, annot_from_clause annot)) (cs @ fallthrough)
952
+ else (
953
+ let rec group fallthrough clauses =
954
+ let add_clause (pat , cls , annot ) c = (pat, cls @ [c], annot) in
955
+ let rec group_aux current acc = function
956
+ | ((pat , guard , body , annot ) as c ) :: cs -> (
957
+ let current_pat, _, _ = current in
958
+ match subsumes_pat current_pat pat with
959
+ | Some substs ->
960
+ let pat' = List. fold_left subst_id_pat pat substs in
961
+ let guard' =
962
+ match guard with Some exp -> Some (List. fold_left subst_id_exp exp substs) | None -> None
963
+ in
964
+ let body' = List. fold_left subst_id_exp body substs in
965
+ let c' = (pat', guard', body', annot) in
966
+ group_aux (add_clause current c') acc cs
967
+ | None ->
968
+ let pat = match cs with _ :: _ -> remove_wildcards " g__" pat | _ -> pat in
969
+ group_aux (pat, [c], annot_from_clause annot) (acc @ [current]) cs
970
+ )
971
+ | [] -> acc @ [current]
972
+ in
973
+ let groups =
974
+ match clauses with
975
+ | [((pat, guard, body, annot) as c)] -> [(pat, [c], annot_from_clause annot)]
976
+ | ((pat , guard , body , annot ) as c ) :: cs ->
977
+ group_aux (remove_wildcards " g__" pat, [c], annot_from_clause annot) [] cs
978
+ | _ -> raise (Reporting. err_unreachable l __POS__ " group given empty list in rewrite_guarded_clauses" )
979
+ in
980
+ let add_group cs groups = if_pexp (groups @ fallthrough) cs :: groups in
981
+ List. fold_right add_group groups []
982
+ and if_pexp fallthrough (pat , cs , annot ) =
983
+ match cs with
984
+ | c :: _ ->
985
+ let body = if_exp fallthrough pat cs in
986
+ (pat, body, annot)
987
+ | [] -> raise (Reporting. err_unreachable l __POS__ " if_pexp given empty list in rewrite_guarded_clauses" )
988
+ and if_exp fallthrough current_pat = function
989
+ | (pat , guard , body , annot ) :: ((pat' , _ , body' , _ ) as c' ) :: cs -> (
990
+ match guard with
991
+ | Some exp ->
992
+ let env = env_of exp in
993
+ let else_exp =
994
+ if equiv_pats current_pat pat' then if_exp fallthrough current_pat (c' :: cs)
995
+ else case_exp (pat_to_exp env current_pat) (typ_of body') (group fallthrough (c' :: cs))
996
+ in
997
+ annot_exp (E_if (exp, body, else_exp)) (fst annot).loc env (typ_of body)
998
+ | None -> body
999
+ )
1000
+ | [(pat, guard, body, annot)] -> (
1001
+ (* For singleton clauses with a guard, use fallthrough clauses if the
1002
+ guard is not satisfied, but only those fallthrough clauses that are
1003
+ not disjoint with the current pattern *)
1004
+ let overlapping_clause (pat , _ , _ ) = not (disjoint_pat env current_pat pat) in
1005
+ let fallthrough = List. filter overlapping_clause fallthrough in
1006
+ match (guard, fallthrough) with
1007
+ | Some exp , _ :: _ ->
1008
+ let env = env_of exp in
1009
+ let else_exp = case_exp (pat_to_exp env current_pat) (typ_of body) fallthrough in
1010
+ annot_exp (E_if (exp, body, else_exp)) (fst annot).loc env (typ_of body)
1011
+ | _ , _ -> body
1012
+ )
1013
+ | [] -> raise (Reporting. err_unreachable l __POS__ " if_exp given empty list in rewrite_guarded_clauses" )
1014
+ in
1015
+ List. map (fun (pat , exp , annot ) -> (pat, None , exp, annot)) (group [] (cs @ fallthrough))
1016
+ )
1011
1017
1012
- let rewrite_guarded_clauses mk_fallthrough l env pat_typ typ
1018
+ let rewrite_guarded_clauses fun_only mk_fallthrough l env pat_typ typ
1013
1019
(cs : (tannot pat * tannot exp option * tannot exp * tannot annot) list ) =
1014
1020
let map_clause_annot f cs = List. map (fun (pat , guard , body , annot ) -> (pat, guard, body, f annot)) cs in
1015
1021
let cs = map_clause_annot (fun (l , tannot ) -> (mk_def_annot l () , tannot)) cs in
1016
- rewrite_toplevel_guarded_clauses mk_fallthrough l env pat_typ typ cs
1022
+ rewrite_toplevel_guarded_clauses fun_only mk_fallthrough l env pat_typ typ cs
1017
1023
1018
1024
let mk_pattern_match_failure_pexp l env pat_typ typ =
1019
1025
let p = P_aux (P_wild , (gen_loc l, mk_tannot env pat_typ)) in
@@ -1429,7 +1435,7 @@ let rewrite_bit_lists_to_lits env =
1429
1435
1430
1436
(* Remove pattern guards by rewriting them to if-expressions within the
1431
1437
pattern expression. *)
1432
- let rewrite_exp_guarded_pats rewriters (E_aux (exp , (l , annot )) as full_exp ) =
1438
+ let rewrite_exp_guarded_pats fun_only rewriters (E_aux (exp , (l , annot )) as full_exp ) =
1433
1439
let rewrap e = E_aux (e, (l, annot)) in
1434
1440
let rewrite_rec = rewriters.rewrite_exp rewriters in
1435
1441
let rewrite_base = rewrite_exp rewriters in
@@ -1448,8 +1454,11 @@ let rewrite_exp_guarded_pats rewriters (E_aux (exp, (l, annot)) as full_exp) =
1448
1454
| Pat_aux (Pat_when (pat , guard , body ), annot ) -> (pat, Some (rewrite_rec guard), rewrite_rec body, annot)
1449
1455
in
1450
1456
let clauses =
1451
- rewrite_guarded_clauses mk_pattern_match_failure_pexp l (env_of full_exp) (typ_of e) (typ_of full_exp)
1452
- (List. map clause ps)
1457
+ List. map
1458
+ (fun (pat , _ , body , annot ) -> (pat, body, annot))
1459
+ (rewrite_guarded_clauses fun_only mk_pattern_match_failure_pexp l (env_of full_exp) (typ_of e)
1460
+ (typ_of full_exp) (List. map clause ps)
1461
+ )
1453
1462
in
1454
1463
let e = rewrite_rec e in
1455
1464
if effectful e then (
@@ -1468,14 +1477,15 @@ let rewrite_exp_guarded_pats rewriters (E_aux (exp, (l, annot)) as full_exp) =
1468
1477
| Pat_aux (Pat_when (pat , guard , body ), annot ) -> (pat, Some (rewrite_rec guard), rewrite_rec body, annot)
1469
1478
in
1470
1479
let clauses =
1471
- rewrite_guarded_clauses mk_rethrow_pexp l (env_of full_exp) exc_typ (typ_of full_exp) (List. map clause ps)
1480
+ rewrite_guarded_clauses fun_only mk_rethrow_pexp l (env_of full_exp) exc_typ (typ_of full_exp)
1481
+ (List. map clause ps)
1472
1482
in
1473
- let pexp (pat , body , annot ) = Pat_aux (Pat_exp (pat, body), annot) in
1483
+ let pexp (pat , _ , body , annot ) = Pat_aux (Pat_exp (pat, body), annot) in
1474
1484
let ps = List. map pexp clauses in
1475
1485
annot_exp (E_try (e, ps)) l (env_of full_exp) (typ_of full_exp)
1476
1486
| _ -> rewrite_base full_exp
1477
1487
1478
- let rewrite_fun_guarded_pats rewriters (FD_aux (FD_function (r , t , funcls ), (l , fdannot ))) =
1488
+ let rewrite_fun_guarded_pats fun_only rewriters (FD_aux (FD_function (r , t , funcls ), (l , fdannot ))) =
1479
1489
let funcls =
1480
1490
match funcls with
1481
1491
| FCL_aux (FCL_funcl (id , pexp ), fcl_annot ) :: _ ->
@@ -1496,14 +1506,14 @@ let rewrite_fun_guarded_pats rewriters (FD_aux (FD_function (r, t, funcls), (l,
1496
1506
| exception _ -> (pexp_pat_typ, pexp_ret_typ)
1497
1507
in
1498
1508
let cs =
1499
- rewrite_toplevel_guarded_clauses mk_pattern_match_failure_pexp l
1509
+ rewrite_toplevel_guarded_clauses fun_only mk_pattern_match_failure_pexp l
1500
1510
(env_of_tannot (snd fcl_annot))
1501
1511
pat_typ ret_typ (List. map clause funcls)
1502
1512
in
1503
1513
List. map
1504
- (fun (pat , exp , annot ) ->
1514
+ (fun (pat , guard , exp , annot ) ->
1505
1515
FCL_aux
1506
- ( FCL_funcl (id, construct_pexp (pat, None , exp, (Parse_ast. Unknown , empty_tannot))),
1516
+ ( FCL_funcl (id, construct_pexp (pat, guard , exp, (Parse_ast. Unknown , empty_tannot))),
1507
1517
(mk_def_annot (fst annot) () , snd annot)
1508
1518
)
1509
1519
)
@@ -1514,7 +1524,11 @@ let rewrite_fun_guarded_pats rewriters (FD_aux (FD_function (r, t, funcls), (l,
1514
1524
1515
1525
let rewrite_ast_guarded_pats env =
1516
1526
rewrite_ast_base
1517
- { rewriters_base with rewrite_exp = rewrite_exp_guarded_pats; rewrite_fun = rewrite_fun_guarded_pats }
1527
+ { rewriters_base with rewrite_exp = rewrite_exp_guarded_pats false ; rewrite_fun = rewrite_fun_guarded_pats false }
1528
+
1529
+ let rewrite_ast_fun_guarded_pats env =
1530
+ rewrite_ast_base
1531
+ { rewriters_base with rewrite_exp = rewrite_exp_guarded_pats true ; rewrite_fun = rewrite_fun_guarded_pats true }
1518
1532
1519
1533
let rec rewrite_lexp_to_rhs (LE_aux (lexp , ((l , _ ) as annot )) as le ) =
1520
1534
match lexp with
@@ -4737,6 +4751,7 @@ let all_rewriters =
4737
4751
(" remove_bitvector_pats" , basic_rewriter rewrite_ast_remove_bitvector_pats);
4738
4752
(" remove_numeral_pats" , basic_rewriter rewrite_ast_remove_numeral_pats);
4739
4753
(" guarded_pats" , basic_rewriter rewrite_ast_guarded_pats);
4754
+ (" fun_guarded_pats" , basic_rewriter rewrite_ast_fun_guarded_pats);
4740
4755
(" bit_lists_to_lits" , basic_rewriter rewrite_bit_lists_to_lits);
4741
4756
(" exp_lift_assign" , basic_rewriter rewrite_ast_exp_lift_assign);
4742
4757
(" early_return" , base_rewriter rewrite_ast_early_return);
0 commit comments