Skip to content

Commit 4462105

Browse files
committed
✨ No more special pred type
1 parent 0f20467 commit 4462105

File tree

3 files changed

+65
-67
lines changed

3 files changed

+65
-67
lines changed

lib/compiler.ml

Lines changed: 46 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@ let rec peval : type a. (a, det) texp -> (a, det) texp =
2525
match peval te with
2626
| { exp = Value v; _ } -> { ty; exp = Value (uop.op v) }
2727
| e -> { ty; exp = Uop (uop, e) })
28-
| If_pred (pred, te_con, te_alt) -> (
29-
match peval_pred pred with
30-
| True -> peval { ty; exp = If_just te_con }
31-
| False -> peval { ty; exp = If_just te_alt }
28+
| If_pred (te_pred, te_con, te_alt) -> (
29+
match peval te_pred with
30+
| { exp = Value true; _ } -> peval { ty; exp = If_just te_con }
31+
| { exp = Value false; _ } -> peval { ty; exp = If_just te_alt }
3232
| p -> { ty; exp = If_pred (p, peval te_con, peval te_alt) })
3333
| Call (f, args) -> (
3434
match peval_args args with
@@ -48,9 +48,10 @@ let rec peval : type a. (a, det) texp -> (a, det) texp =
4848
in
4949
{ ty; exp = Call (f_dist, []) })
5050
| If_pred_dist (p, de) -> (
51-
match peval_pred p with
52-
| True -> peval de
53-
| False -> { ty; exp = Call (Dist.one (dty_of_dist_ty ty), []) }
51+
match peval p with
52+
| { exp = Value true; _ } -> peval de
53+
| { exp = Value false; _ } ->
54+
{ ty; exp = Call (Dist.one (dty_of_dist_ty ty), []) }
5455
| p -> { ty; exp = If_pred_dist (p, peval de) })
5556
| If_just de -> { ty; exp = If_just (peval de) }
5657

@@ -63,23 +64,26 @@ and peval_args : type a. (a, det) args -> (a, det) args * a vargs option =
6364
({ ty; exp = Value v } :: tl, Some ((dty_of_dat_ty ty, v) :: vargs))
6465
| te, (tl, _) -> (te :: tl, None))
6566

66-
and peval_pred : pred -> pred = function
67-
| Empty -> failwith "[Bug] Empty predicate"
68-
| True -> True
69-
| False -> False
70-
| And (p, de) -> (
71-
match peval de with
72-
| { exp = Value true; _ } -> peval_pred p
73-
| { exp = Value false; _ } -> False
74-
| de -> And (p, de))
75-
| And_not (p, de) -> (
76-
match peval de with
77-
| { exp = Value true; _ } -> False
78-
| { exp = Value false; _ } -> peval_pred p
79-
| de -> And_not (p, de))
67+
let ( &&& ) :
68+
type s1 s2 s.
69+
((bool, _) dat_ty, det) texp ->
70+
((bool, _) dat_ty, det) texp ->
71+
bool some_dat_det_texp =
72+
fun ({ ty = Dat_ty (Tyb, s1); _ } as p1) ({ ty = Dat_ty (Tyb, s2); _ } as p2) ->
73+
let (Ex (ms, s)) = merge_stamps s1 s2 in
74+
Ex
75+
(peval
76+
{
77+
ty = Dat_ty (Tyb, s);
78+
exp = Bop ({ name = "&&"; op = ( && ) }, p1, p2, ms);
79+
})
8080

81-
let ( &&& ) p de = peval_pred (And (p, de))
82-
let ( &&! ) p de = peval_pred (And_not (p, de))
81+
let ( &&! ) :
82+
type s1 s2 s.
83+
((bool, _) dat_ty, det) texp ->
84+
((bool, _) dat_ty, det) texp ->
85+
bool some_dat_det_texp =
86+
fun p1 p2 -> p1 &&& { ty = p2.ty; exp = Uop ({ name = "not"; op = not }, p2) }
8387

8488
let rec score : type a. (a dist_ty, det) texp -> (a dist_ty, det) texp =
8589
function
@@ -88,9 +92,12 @@ let rec score : type a. (a dist_ty, det) texp -> (a dist_ty, det) texp =
8892
| { exp = Call _; _ } as e -> e
8993

9094
let rec compile :
91-
type a s. env:env -> ?pred:pred -> (a, ndet) texp -> Graph.t * (a, det) texp
92-
=
93-
fun ~env ?(pred = Empty) { ty; exp } ->
95+
type a s.
96+
env:env ->
97+
pred:((bool, s) dat_ty, det) texp ->
98+
(a, ndet) texp ->
99+
Graph.t * (a, det) texp =
100+
fun ~env ~pred { ty; exp } ->
94101
match exp with
95102
| Value _ as exp -> (Graph.empty, { ty; exp })
96103
| Var x -> (
@@ -107,8 +114,8 @@ let rec compile :
107114
(g, peval { ty; exp = Uop (op, te) })
108115
| If (e_pred, e_con, e_alt, _, _) ->
109116
let g1, de_pred = compile ~env ~pred e_pred in
110-
let pred_con = pred &&& de_pred in
111-
let pred_alt = pred &&! de_pred in
117+
let (Ex pred_con) = pred &&& de_pred in
118+
let (Ex pred_alt) = pred &&! de_pred in
112119
let g2, de_con = compile ~env ~pred:pred_con e_con in
113120
let g3, de_alt = compile ~env ~pred:pred_alt e_alt in
114121
let g = Graph.(g1 @| g2 @| g3) in
@@ -143,7 +150,7 @@ let rec compile :
143150
let v = gen_vertex () in
144151
let f1 = score de1 in
145152
let f = { ty = f1.ty; exp = If_pred_dist (pred, f1) } in
146-
let fvs = Id.(fv de1.exp @| fv_pred pred) in
153+
let fvs = Id.(fv de1.exp @| fv pred.exp) in
147154
if not (Set.is_empty (fv de2.exp)) then
148155
failwith "[Bug] Not closed observation";
149156
let g' =
@@ -158,7 +165,11 @@ let rec compile :
158165
Graph.(g1 @| g2 @| g', { ty = Dat_ty (Tyu, Val); exp = Value () })
159166

160167
and compile_args :
161-
type a. env -> pred -> (a, ndet) args -> Graph.t * (a, det) args =
168+
type a s.
169+
env ->
170+
((bool, s) dat_ty, det) texp ->
171+
(a, ndet) args ->
172+
Graph.t * (a, det) args =
162173
fun env pred args ->
163174
match args with
164175
| [] -> (Graph.empty, [])
@@ -177,7 +188,11 @@ let compile_program (prog : program) : Graph.t * Evaluator.query =
177188
m "Inlined program %a" Sexp.pp_hum [%sexp (exp : Parse_tree.exp)]);
178189

179190
let (Ex e) = Typing.check exp in
180-
let g, { ty; exp } = compile ~env:Id.Map.empty e in
191+
let g, { ty; exp } =
192+
compile ~env:Id.Map.empty
193+
~pred:{ ty = Dat_ty (Tyb, Val); exp = Value true }
194+
e
195+
in
181196
match ty with
182197
| Dat_ty (_, Rv) -> (g, Ex { ty; exp })
183198
| _ -> raise Query_not_found

lib/evaluator.ml

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,24 +22,18 @@ let rec eval_dat : type a s. Ctx.t -> ((a, s) dat_ty, det) texp -> a =
2222
| None -> assert false)
2323
| Bop ({ op; _ }, te1, te2, _) -> op (eval_dat ctx te1) (eval_dat ctx te2)
2424
| Uop ({ op; _ }, te) -> op (eval_dat ctx te)
25-
| If_pred (pred, te_con, te_alt) ->
26-
if eval_pred ctx pred then eval_dat ctx te_con else eval_dat ctx te_alt
25+
| If_pred (te_pred, te_con, te_alt) ->
26+
if eval_dat ctx te_pred then eval_dat ctx te_con else eval_dat ctx te_alt
2727
| If_just te -> eval_dat ctx te
2828

2929
and eval_dist : type a. Ctx.t -> (a dist_ty, det) texp -> a =
3030
fun ctx { ty = Dist_ty dty as ty; exp } ->
3131
match exp with
3232
| Call (f, args) -> f.sampler (eval_args ctx args)
3333
| If_pred_dist (pred, dist) ->
34-
if eval_pred ctx pred then eval_dist ctx dist
34+
if eval_dat ctx pred then eval_dist ctx dist
3535
else eval_dist ctx { ty; exp = Call (Dist.one dty, []) }
3636

37-
and eval_pred (ctx : Ctx.t) : pred -> bool = function
38-
| Empty | True -> true
39-
| False -> false
40-
| And (p, de) -> eval_dat ctx de && eval_pred ctx p
41-
| And_not (p, de) -> (not (eval_dat ctx de)) && eval_pred ctx p
42-
4337
and eval_args : type a. Ctx.t -> (a, det) args -> a vargs =
4438
fun ctx -> function
4539
| [] -> []
@@ -51,7 +45,7 @@ let rec eval_pmdf :
5145
fun ctx { ty = Dist_ty dty as ty; exp } ->
5246
match exp with
5347
| If_pred_dist (pred, te) ->
54-
if eval_pred ctx pred then eval_pmdf ctx te
48+
if eval_dat ctx pred then eval_pmdf ctx te
5549
else eval_pmdf ctx { ty; exp = Call (Dist.one dty, []) }
5650
| Call (f, args) ->
5751
let pmdf (Ex (ty', v) : some_val) =

lib/typed_tree.ml

Lines changed: 15 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -40,18 +40,10 @@ type ('a, 'b) dist = {
4040
log_pmdf : 'b vargs -> 'a -> real;
4141
}
4242

43-
(* TODO: Why args should also be det? *)
4443
type (_, _) args =
4544
| [] : (unit, _) args
4645
| ( :: ) : (('a, _) dat_ty, 'd) texp * ('b, 'd) args -> ('a * 'b, 'd) args
4746

48-
and pred =
49-
| Empty : pred
50-
| True : pred
51-
| False : pred
52-
| And : pred * ((bool, _) dat_ty, det) texp -> pred
53-
| And_not : pred * ((bool, _) dat_ty, det) texp -> pred
54-
5547
and ('a, 'd) texp = { ty : 'a ty; exp : ('a, 'd) exp }
5648

5749
and (_, _) exp =
@@ -73,9 +65,13 @@ and (_, _) exp =
7365
* ('s_pred, 's_ca, 's) merge_stamp
7466
-> (('a, 's) dat_ty, ndet) exp
7567
| If_pred :
76-
pred * (('a, _) dat_ty, det) texp * (('a, _) dat_ty, det) texp
68+
((bool, _) dat_ty, det) texp
69+
* (('a, _) dat_ty, det) texp
70+
* (('a, _) dat_ty, det) texp
7771
-> (('a, _) dat_ty, det) exp
78-
| If_pred_dist : pred * ('a dist_ty, det) texp -> ('a dist_ty, det) exp
72+
| If_pred_dist :
73+
((bool, _) dat_ty, det) texp * ('a dist_ty, det) texp
74+
-> ('a dist_ty, det) exp
7975
| If_just : (('a, _) dat_ty, det) texp -> (('a, _) dat_ty, det) exp
8076
| Let : Id.t * ('a, ndet) texp * ('b, ndet) texp -> ('b, ndet) exp
8177
| Call : ('a, 'b) dist * ('b, 'd) args -> ('a dist_ty, 'd) exp
@@ -92,6 +88,9 @@ type _ some_texp = Ex : (_, 'd) texp -> 'd some_texp
9288
type _ some_dat_ndet_texp =
9389
| Ex : (('a, _) dat_ty, ndet) texp -> 'a some_dat_ndet_texp
9490

91+
type _ some_dat_det_texp =
92+
| Ex : (('a, _) dat_ty, det) texp -> 'a some_dat_det_texp
93+
9594
type _ some_val_texp = Ex : ((_, value) dat_ty, 'd) texp -> 'd some_val_texp
9695
type _ some_rv_texp = Ex : ((_, rv) dat_ty, 'd) texp -> 'd some_rv_texp
9796
type _ some_dat_texp = Ex : (_ dat_ty, 'd) texp -> 'd some_dat_texp
@@ -184,21 +183,17 @@ let rec fv : type a. (a, det) exp -> Id.Set.t = function
184183
| Rvar x -> Id.Set.singleton x
185184
| Bop (_, { exp = e1; _ }, { exp = e2; _ }, _) -> Id.(fv e1 @| fv e2)
186185
| Uop (_, { exp; _ }) -> fv exp
187-
| If_pred (pred, { exp = e_con; _ }, { exp = e_alt; _ }) ->
188-
Id.(fv_pred pred @| fv e_con @| fv e_alt)
189-
| If_pred_dist (pred, { exp = e_con; _ }) -> Id.(fv_pred pred @| fv e_con)
186+
| If_pred ({ exp = e_pred; _ }, { exp = e_con; _ }, { exp = e_alt; _ }) ->
187+
Id.(fv e_pred @| fv e_con @| fv e_alt)
188+
| If_pred_dist ({ exp = e_pred; _ }, { exp = e_con; _ }) ->
189+
Id.(fv e_pred @| fv e_con)
190190
| If_just { exp; _ } -> fv exp
191191
| Call (_, args) -> fv_args args
192192

193193
and fv_args : type a. (a, det) args -> Id.Set.t = function
194194
| [] -> Id.Set.empty
195195
| { exp; _ } :: es -> Id.(fv exp @| fv_args es)
196196

197-
and fv_pred : pred -> Id.Set.t = function
198-
| Empty | True | False -> Id.Set.empty
199-
| And (p, { exp = de; _ }) -> Id.(fv de @| fv_pred p)
200-
| And_not (p, { exp = de; _ }) -> Id.(fv de @| fv_pred p)
201-
202197
module Erased = struct
203198
type exp =
204199
| Value : string -> exp
@@ -220,8 +215,8 @@ module Erased = struct
220215
fun { ty; exp } ->
221216
match exp with
222217
| If (pred, con, alt, _, _) -> If (of_exp pred, of_exp con, of_exp alt)
223-
| If_pred (pred, con, alt) -> If (of_pred pred, of_exp con, of_exp alt)
224-
| If_pred_dist (pred, con) -> If (of_pred pred, of_exp con, Value "1")
218+
| If_pred (pred, con, alt) -> If (of_exp pred, of_exp con, of_exp alt)
219+
| If_pred_dist (pred, con) -> If (of_exp pred, of_exp con, Value "1")
225220
| If_just exp -> If_just (of_exp exp)
226221
| Value v -> (
227222
match ty with
@@ -241,11 +236,5 @@ module Erased = struct
241236
| [] -> []
242237
| arg :: args -> of_exp arg :: of_args args
243238

244-
and of_pred : pred -> exp = function
245-
| Empty | True -> Value "true"
246-
| False -> Value "false"
247-
| And (pred, exp) -> Bop ("&&", of_pred pred, of_exp exp)
248-
| And_not (pred, exp) -> Bop ("&&", of_pred pred, Uop ("not", of_exp exp))
249-
250239
let of_rv (Ex rv : _ some_rv_texp) = rv |> of_exp
251240
end

0 commit comments

Comments
 (0)