Skip to content

Commit 2fb6b37

Browse files
committed
Improve tuple handling in remove_e_assign rewrite
- don't rewrite unnecessarily - add type information where possible
1 parent 14d4827 commit 2fb6b37

7 files changed

+40
-24
lines changed

src/lib/rewrites.ml

+2-2
Original file line numberDiff line numberDiff line change
@@ -2974,7 +2974,7 @@ let rec rewrite_var_updates (E_aux (expaux, ((l, _) as annot)) as exp) =
29742974
we would introduce a new variable rather than using a wildcard and unit literal
29752975
*)
29762976
let is_trivial = function E_aux ((E_id _ | E_lit _), _) -> true | _ -> false in
2977-
if List.for_all is_trivial exps then exp
2977+
if find_updated_vars exp |> IdSet.is_empty then exp
29782978
else (
29792979
let tuple_typ = typ_of exp in
29802980
let typs =
@@ -3004,7 +3004,7 @@ let rec rewrite_var_updates (E_aux (expaux, ((l, _) as annot)) as exp) =
30043004
else (
30053005
let lb =
30063006
if is_unit_typ typ then LB_aux (LB_val (P_aux (P_wild, swaptyp typ annot), exp), annot)
3007-
else LB_aux (LB_val (P_aux (P_id id, swaptyp typ annot), exp), annot)
3007+
else LB_aux (LB_val (add_p_typ env typ (P_aux (P_id id, swaptyp typ annot)), exp), annot)
30083008
in
30093009
E_aux (E_let (lb, tup), annot)
30103010
)

test/lean/match.expected.lean

+3-6
Original file line numberDiff line numberDiff line change
@@ -181,21 +181,18 @@ def match_read (x : E) : SailM Unit := do
181181
| C => readReg r_C)
182182

183183
def const16 (_ : Unit) : ((BitVec 16) × Bool) :=
184-
let t__4 := (0xFFFF : (BitVec 16))
185-
(t__4, true)
184+
((0xFFFF : (BitVec 16)), true)
186185

187186
def const32 (_ : Unit) : ((BitVec 32) × Bool) :=
188-
let t__2 := (0xEEEEEEEE : (BitVec 32))
189-
(t__2, false)
187+
((0xEEEEEEEE : (BitVec 32)), false)
190188

191189
/-- Type quantifiers: k_n : Nat, k_n ≥ 0 -/
192190
def match_width (x : (BitVec k_n)) : (BitVec (2 * k_n)) :=
193191
let (foo, _) : ((BitVec k_n) × Bool) :=
194192
match (Sail.BitVec.length x) with
195193
| 16 => (const16 ())
196194
| 32 => (const32 ())
197-
| n => (let t__0 := (BitVec.zero n)
198-
(t__0, false))
195+
| n => ((BitVec.zero n), false)
199196
(foo ++ foo)
200197

201198
def initialize_registers (_ : Unit) : SailM Unit := do

test/lean/tuples.expected.lean

+5-8
Original file line numberDiff line numberDiff line change
@@ -35,24 +35,21 @@ open Sail
3535

3636
namespace Out.Functions
3737

38-
def let5 := (20, 300000000000000000000000)
38+
def let0 := (20, 300000000000000000000000)
3939

4040
def y :=
41-
let (y, z) := let5
41+
let (y, z) := let0
4242
y
4343

4444
def z :=
45-
let (y, z) := let5
45+
let (y, z) := let0
4646
z
4747

4848
def tuple1 (_ : Unit) : (Int × Int × ((BitVec 2) × Unit)) :=
49-
let t__4 := ((0b10 : (BitVec 2)), ())
50-
(3, 5, t__4)
49+
(3, 5, ((0b10 : (BitVec 2)), ()))
5150

5251
def tuple2 (_ : Unit) : SailM (Int × Int) := do
53-
let t__0 ← do (undefined_int ())
54-
let t__1 ← do (undefined_int ())
55-
(pure (t__0, t__1))
52+
(pure ((← (undefined_int ())), (← (undefined_int ()))))
5653

5754
def initialize_registers (_ : Unit) : Unit :=
5855
()

test/lean/typquant.expected.lean

+1-2
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,7 @@ def hex_bits_signed2_forwards (bv : (BitVec k_nn)) : (Nat × String) :=
158158
bif (BEq.beq (BitVec.access bv (len -i 1)) 1#1)
159159
then "stub1"
160160
else "stub2"
161-
let t__3 := (Sail.BitVec.length bv)
162-
(t__3, s)
161+
((Sail.BitVec.length bv), s)
163162

164163
/-- Type quantifiers: k_nn : Nat, k_nn > 0 -/
165164
def hex_bits_signed2_forwards_matches (bv : (BitVec k_nn)) : Bool :=

test/lean/undefined.expected.lean

+2-6
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,8 @@ namespace Out.Functions
3737

3838
/-- Type quantifiers: n : Int -/
3939
def foo (n : Int) : SailM (Bool × (BitVec 1) × Int × Nat × (BitVec 3)) := do
40-
let t__0 ← do (undefined_bool ())
41-
let t__1 ← do (undefined_bit ())
42-
let t__2 ← do (undefined_int ())
43-
let t__3 ← do (undefined_nat ())
44-
let t__4 ← do (undefined_bitvector 3)
45-
(pure (t__0, t__1, t__2, t__3, t__4))
40+
(pure ((← (undefined_bool ())), (← (undefined_bit ())), (← (undefined_int ())), (← (undefined_nat
41+
())), (← (undefined_bitvector 3))))
4642

4743
/-- Type quantifiers: n : Int -/
4844
def bar (n : Int) : SailM (Vector Int 4) := do
+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
default Order dec
2+
$include <prelude.sail>
3+
4+
/* The remove_e_assign rewrite was lifting out the subexpressions from the tuple,
5+
but failing to include enough typing information.
6+
7+
Note that this test is really for the Rocq and Lem backends; plain type
8+
checking should be straightforward.
9+
*/
10+
11+
function test(xs : list(int), ys : list(int)) -> (list(int), list(int)) = {
12+
var x : int = 1;
13+
(({x = x + 1; x}) :: xs, 3 :: ys)
14+
}
+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
default Order dec
2+
$include <prelude.sail>
3+
4+
/* The remove_e_assign rewrite was lifting out the subexpressions from the tuple,
5+
but failing to include enough typing information. This simple case ought to
6+
be left alone.
7+
8+
Note that this test is really for the Rocq and Lem backends; plain type
9+
checking should be straightforward.
10+
*/
11+
12+
function test(xs : list(int), ys : list(int)) -> (list(int), list(int)) =
13+
(5 :: xs, 3 :: ys)

0 commit comments

Comments
 (0)