Skip to content

Commit feedf34

Browse files
committed
coupling rule: one-sided and two-sided
1 parent 23cf445 commit feedf34

File tree

9 files changed

+258
-26
lines changed

9 files changed

+258
-26
lines changed

src/ecHiTacticals.ml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ and process1_phl (_ : ttenv) (t : phltactic located) (tc : tcenv1) =
200200
| Psetmatch info -> EcPhlCodeTx.process_set_match info
201201
| Pweakmem info -> EcPhlCodeTx.process_weakmem info
202202
| Prnd (side, pos, info) -> EcPhlRnd.process_rnd side pos info
203+
| Pcoupling (side, f) -> EcPhlRnd.process_coupling side f
203204
| Prndsem (red, side, pos) -> EcPhlRnd.process_rndsem ~reduce:red side pos
204205
| Pconseq (opt, info) -> EcPhlConseq.process_conseq_opt opt info
205206
| Pconseqauto cm -> process_conseqauto cm

src/ecLexer.mll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@
145145
"cfold" , CFOLD ; (* KW: tactic *)
146146
"rnd" , RND ; (* KW: tactic *)
147147
"rndsem" , RNDSEM ; (* KW: tactic *)
148+
"coupling" , COUPLING ; (* KW: tactic *)
148149
"pr_bounded" , PRBOUNDED ; (* KW: tactic *)
149150
"bypr" , BYPR ; (* KW: tactic *)
150151
"byphoare" , BYPHOARE ; (* KW: tactic *)

src/ecParser.mly

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,7 @@
416416
%token CONSEQ
417417
%token CONST
418418
%token COQ
419+
%token COUPLING
419420
%token CHECK
420421
%token EDIT
421422
%token FIX
@@ -3000,9 +3001,13 @@ interleave_info:
30003001
| RND s=side? info=rnd_info c=prefix(COLON, semrndpos)?
30013002
{ Prnd (s, c, info) }
30023003

3004+
30033005
| RNDSEM red=boption(STAR) s=side? c=codepos1
30043006
{ Prndsem (red, s, c) }
30053007

3008+
| COUPLING s=side? f=sform
3009+
{ Pcoupling (s, f) }
3010+
30063011
| INLINE s=side? u=inlineopt? o=occurences?
30073012
{ Pinline (`ByName(s, u, ([], o))) }
30083013

src/ecParsetree.ml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -748,6 +748,7 @@ type phltactic =
748748
| Pasgncase of (oside * pcodepos)
749749
| Prnd of oside * psemrndpos option * rnd_tac_info_f
750750
| Prndsem of bool * oside * pcodepos1
751+
| Pcoupling of oside * pformula
751752
| Palias of (oside * pcodepos * osymbol_r)
752753
| Pweakmem of (oside * psymbol * fun_params)
753754
| Pset of (oside * pcodepos * bool * psymbol * pexpr)

src/phl/ecPhlPrRw.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,4 +339,4 @@ let t_pr_rewrite (s, f) tc =
339339
let hyps = LDecl.push_active_ss mp hyps in
340340
{m;inv=EcProofTyping.process_form hyps f ty}
341341
in
342-
t_pr_rewrite_low (s, omap to_env f) tc
342+
t_pr_rewrite_low (s, omap to_env f) tc

src/phl/ecPhlRnd.ml

Lines changed: 125 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ open EcSubst
1010

1111
open EcMatching.Position
1212
open EcCoreGoal
13+
open EcCoreFol
1314
open EcLowGoal
1415
open EcLowPhlGoal
1516

@@ -442,6 +443,84 @@ module Core = struct
442443
| `Right -> f_equivS (snd es.es_ml) mt (es_pr es) es.es_sl s (es_po es) in
443444
FApi.xmutate1 tc (`RndSem pos) [concl]
444445

446+
(* -------------------------------------------------------------------- *)
447+
let t_equiv_coupling_r side g tc =
448+
(* process the following pRHL goal, where g is a coupling of g1 and g2 *)
449+
(* {phi} c1; x <$ g1 ~ c2; y <$ g2 {psi} *)
450+
let env = FApi.tc1_env tc in
451+
let es = tc1_as_equivS tc in
452+
let ml = fst es.es_ml in
453+
let mr = fst es.es_mr in
454+
let (lvL, muL), sl' = tc1_last_rnd tc es.es_sl in
455+
let (lvR, muR), sr' = tc1_last_rnd tc es.es_sr in
456+
let tyL = proj_distr_ty env (e_ty muL) in
457+
let tyR = proj_distr_ty env (e_ty muR) in
458+
let muL = ss_inv_generalize_right (EcFol.ss_inv_of_expr ml muL) mr in
459+
let muR = ss_inv_generalize_left (EcFol.ss_inv_of_expr mr muR) ml in
460+
461+
let goal =
462+
match side with
463+
| None ->
464+
(* Goal: {phi} c1 ~ c2 {iscoupling g muL muR /\ (forall a b, (a, b) \in supp(g) => psi[x -> a, y -> b])} *)
465+
(* Generate two free variables a and b and the pair (a, b) *)
466+
let a_id = EcIdent.create "a" in
467+
let b_id = EcIdent.create "b" in
468+
let a = f_local a_id tyL in
469+
let b = f_local b_id tyR in
470+
let ab = f_tuple [a; b] in
471+
472+
(* Generate the coupling distribution type: (tyL * tyR) distr *)
473+
let coupling_ty = ttuple [tyL; tyR] in
474+
let g_app = map_ts_inv1 (fun g -> f_app_simpl g [] (tdistr coupling_ty)) g in
475+
476+
let iscoupling_op = EcPath.extend EcCoreLib.p_top ["Distr"; "iscoupling"] in
477+
let iscoupling_ty = tfun (tdistr tyL) (tfun (tdistr tyR) (tfun (tdistr coupling_ty) tbool)) in
478+
let iscoupling_pred = map_ts_inv (fun f_list ->
479+
f_app
480+
(f_op iscoupling_op [tyL; tyR] iscoupling_ty)
481+
f_list
482+
tbool)
483+
[muL; muR; g_app] in
484+
485+
(* Substitute in the postcondition *)
486+
let post = (es_po es) in
487+
let post_subst = subst_form_lv_left env lvL {ml=ml;mr=mr;inv=a} post in
488+
let post_subst = subst_form_lv_right env lvR {ml=ml;mr=mr;inv=b} post_subst in
489+
490+
let goal = map_ts_inv2 f_imp (map_ts_inv1 (f_in_supp ab) g_app) post_subst in
491+
let goal = map_ts_inv1 (f_forall_simpl [(a_id, GTty tyL); (b_id, GTty tyR)]) goal in
492+
map_ts_inv2 f_and iscoupling_pred goal
493+
| Some side ->
494+
(* Goal (left): {phi} c1 ~ c2 {dmap d1 g = d2 /\ forall a b, b = g(a) => psi[x -> a, y -> b]} *)
495+
(* Goal (right): {phi} c1 ~ c2 {dmap d1 g = d2 /\ forall a b, a = g(b) => psi[x -> a, y -> b]} *)
496+
let dmap_op = EcPath.extend EcCoreLib.p_top ["Distr"; "dmap"] in
497+
let dmap_ty = tfun (tdistr tyL) (tfun (tfun tyL tyR) (tdistr tyR)) in
498+
let dmap_pred = map_ts_inv2 f_eq
499+
(map_ts_inv (fun f_list ->
500+
f_app (f_op dmap_op [tyL; tyR] dmap_ty) f_list (tdistr tyR))
501+
[muL; g])
502+
muR in
503+
504+
let a_id = EcIdent.create "a" in
505+
let b_id = EcIdent.create "b" in
506+
let a = f_local a_id tyL in
507+
let b = f_local b_id tyR in
508+
let post = (es_po es) in
509+
let post_subst = subst_form_lv_left env lvL {ml=ml;mr=mr;inv=a} post in
510+
let post_subst = subst_form_lv_right env lvR {ml=ml;mr=mr;inv=b} post_subst in
511+
512+
let eq_condition =
513+
match side with
514+
| `Left -> map_ts_inv1 (fun g -> f_eq (f_app_simpl g [a] tyR) b) g
515+
| `Right -> map_ts_inv1 (fun g -> f_eq a (f_app_simpl g [b] tyL)) g in
516+
517+
let goal = map_ts_inv2 f_imp eq_condition post_subst in
518+
let goal = map_ts_inv1 (f_forall_simpl [(a_id, GTty tyL); (b_id, GTty tyR)]) goal in
519+
map_ts_inv2 f_and dmap_pred goal
520+
in
521+
let goal = f_equivS (snd es.es_ml) (snd es.es_mr) (es_pr es) sl' sr' goal in
522+
523+
FApi.xmutate1 tc `Rnd [goal]
445524
end (* Core *)
446525

447526
(* -------------------------------------------------------------------- *)
@@ -666,32 +745,32 @@ let process_rnd side pos tac_info tc =
666745
t_bdhoare_rnd tac_info tc
667746

668747
| _, _, _ when is_equivS concl ->
669-
let process_form f ty1 ty2 =
670-
TTC.tc1_process_prhl_form tc (tfun ty1 ty2) f in
671-
672-
let bij_info =
673-
match tac_info with
674-
| PNoRndParams -> None, None
675-
| PSingleRndParam f -> Some (process_form f), None
676-
| PTwoRndParams (f, finv) -> Some (process_form f), Some (process_form finv)
677-
| _ -> tc_error !!tc "invalid arguments"
678-
in
748+
let process_form f ty1 ty2 =
749+
TTC.tc1_process_prhl_form tc (tfun ty1 ty2) f in
750+
751+
let bij_info =
752+
match tac_info with
753+
| PNoRndParams -> None, None
754+
| PSingleRndParam f -> Some (process_form f), None
755+
| PTwoRndParams (f, finv) -> Some (process_form f), Some (process_form finv)
756+
| _ -> tc_error !!tc "invalid arguments"
757+
in
679758

680-
let pos = pos |> Option.map (function
681-
| Single (b, p) ->
682-
let p =
683-
if Option.is_some side then
684-
EcProofTyping.tc1_process_codepos1 tc (side, p)
685-
else EcTyping.trans_codepos1 (FApi.tc1_env tc) p
686-
in Single (b, p)
687-
| Double ((b1, p1), (b2, p2)) ->
688-
let p1 = EcProofTyping.tc1_process_codepos1 tc (Some `Left , p1) in
689-
let p2 = EcProofTyping.tc1_process_codepos1 tc (Some `Right, p2) in
690-
Double ((b1, p1), (b2, p2))
691-
)
692-
in
693-
694-
t_equiv_rnd side ?pos bij_info tc
759+
let pos = pos |> Option.map (function
760+
| Single (b, p) ->
761+
let p =
762+
if Option.is_some side then
763+
EcProofTyping.tc1_process_codepos1 tc (side, p)
764+
else EcTyping.trans_codepos1 (FApi.tc1_env tc) p
765+
in Single (b, p)
766+
| Double ((b1, p1), (b2, p2)) ->
767+
let p1 = EcProofTyping.tc1_process_codepos1 tc (Some `Left , p1) in
768+
let p2 = EcProofTyping.tc1_process_codepos1 tc (Some `Right, p2) in
769+
Double ((b1, p1), (b2, p2))
770+
)
771+
in
772+
773+
t_equiv_rnd side ?pos bij_info tc
695774

696775
| _ -> tc_error !!tc "invalid arguments"
697776

@@ -713,3 +792,24 @@ let process_rndsem ~reduce side pos tc =
713792
| Some side when is_equivS concl ->
714793
t_equiv_rndsem reduce side pos tc
715794
| _ -> tc_error !!tc "invalid arguments"
795+
796+
let process_coupling side g tc =
797+
let concl = FApi.tc1_goal tc in
798+
799+
if not (is_equivS concl) then
800+
tc_error !!tc "coupling can only be used on pRHL goals"
801+
else
802+
let env = FApi.tc1_env tc in
803+
let es = tc1_as_equivS tc in
804+
let (_, muL), _ = tc1_last_rnd tc es.es_sl in
805+
let (_, muR), _ = tc1_last_rnd tc es.es_sr in
806+
let tyL = proj_distr_ty env (e_ty muL) in
807+
let tyR = proj_distr_ty env (e_ty muR) in
808+
809+
let coupling_ty =
810+
match side with
811+
| None -> tdistr (ttuple [tyL; tyR])
812+
| Some `Left -> tfun tyL tyR
813+
| Some `Right -> tfun tyR tyL in
814+
let g_form = TTC.tc1_process_prhl_form tc coupling_ty g in
815+
Core.t_equiv_coupling_r side g_form tc

src/phl/ecPhlRnd.mli

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,8 @@ val t_equiv_rnd : ?pos:semrndpos -> oside -> (mkbij_t option) pair -> backward
2424
(* -------------------------------------------------------------------- *)
2525
val process_rnd : oside -> psemrndpos option -> rnd_infos_t -> backward
2626

27+
(* -------------------------------------------------------------------- *)
28+
val process_coupling : oside -> pformula -> backward
29+
2730
(* -------------------------------------------------------------------- *)
2831
val process_rndsem : reduce:bool -> oside -> pcodepos1 -> backward

tests/coupling-rnd.ec

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
require import AllCore List.
2+
require import Distr DBool DList.
3+
(*---*) import Biased.
4+
require import StdBigop.
5+
(*---*) import Bigreal Bigreal.BRA.
6+
7+
type t.
8+
9+
op test : t -> bool.
10+
op D : t distr.
11+
12+
axiom D_ll : is_lossless D.
13+
14+
module B = {
15+
proc f() : bool = {
16+
var x: bool;
17+
var garbage: int;
18+
(* put some garbage to avoid the random sampling being the only instruction in a procedure *)
19+
garbage <- 0;
20+
x <$ dbiased (mu D test);
21+
return x;
22+
}
23+
proc g() : t = {
24+
var l: t;
25+
var garbage: int;
26+
garbage <- 0;
27+
l <$ D;
28+
return l;
29+
}
30+
}.
31+
32+
op c : (bool * t) distr = dmap D (fun x => (test x, x)).
33+
34+
(* prove the lemma by directly providing a coupling *)
35+
lemma B_coupling:
36+
equiv [ B.f ~ B.g : true ==> res{1} = test res{2} ].
37+
proof.
38+
proc.
39+
coupling c.
40+
wp; skip => /=; split; last first.
41+
+ move => a b.
42+
by rewrite /c => /supp_dmap [x] />.
43+
rewrite /iscoupling; split; last first.
44+
+ by rewrite /c dmap_comp /(\o) /= dmap_id.
45+
rewrite /c dmap_comp /(\o) /=.
46+
apply eq_distr => b.
47+
rewrite dbiased1E.
48+
rewrite clamp_id.
49+
+ exact mu_bounded.
50+
rewrite dmap1E /pred1 /(\o) /=.
51+
case b => _.
52+
+ by apply mu_eq => x /#.
53+
rewrite (mu_eq _ _ (predC test)).
54+
+ by smt().
55+
by rewrite mu_not D_ll.
56+
qed.
57+
58+
(* prove the lemma by bypr, which means we compute the probability at eash side separately. *)
59+
lemma B_coupling_bypr:
60+
equiv [ B.f ~ B.g : true ==> res{1} = test res{2} ].
61+
proof.
62+
bypr res{1} (test res{2}).
63+
+ smt().
64+
+ move => &1 &2 b.
65+
have ->: Pr[B.f() @ &1 : res = b] = mu1 (dbiased (mu D test)) b.
66+
+ byphoare => //.
67+
proc; rnd; wp; skip => />.
68+
have -> //: Pr[B.g() @ &2 : test res = b] = mu1 (dbiased (mu D test)) b.
69+
+ byphoare => //.
70+
proc; rnd; sp; skip => />.
71+
rewrite dbiased1E.
72+
rewrite clamp_id.
73+
+ exact mu_bounded.
74+
case b => _.
75+
+ apply mu_eq => x /#.
76+
rewrite (mu_eq _ _ (predC test)).
77+
+ by smt().
78+
by rewrite mu_not D_ll.
79+
qed.
80+
81+
op f (x : t) : bool = test x.
82+
83+
lemma B_compress:
84+
equiv [ B.g ~ B.f : true ==> test res{1} = res{2} ].
85+
proof.
86+
proc.
87+
coupling{1} f.
88+
wp; skip => @/predT /=.
89+
+ split.
90+
+ apply eq_distr => b.
91+
rewrite dbiased1E.
92+
rewrite clamp_id.
93+
+ exact mu_bounded.
94+
rewrite dmap1E /pred1 /(\o) /=.
95+
case b => _.
96+
+ by apply mu_eq => x /#.
97+
rewrite (mu_eq _ _ (predC test)).
98+
+ by smt().
99+
by rewrite mu_not D_ll.
100+
+ by move => @/f //.
101+
qed.

theories/distributions/Distr.ec

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1437,6 +1437,26 @@ move=> a1 a2 _ /=; apply/iscpl_dlet; first by apply: cpl_d=> /#.
14371437
by move=> b1 b2 _ /=; apply/iscpl_dunit.
14381438
qed.
14391439

1440+
(* -------------------------------------------------------------------- *)
1441+
lemma iscpl_dmap1 ['a 'b] (d1 : 'a distr) (d2 : 'b distr) (f : 'a -> 'b) :
1442+
dmap d1 f = d2 => iscoupling<:'a, 'b> d1 d2 (dmap d1 (fun x => (x, f x))).
1443+
proof.
1444+
rewrite /iscoupling => ?.
1445+
split.
1446+
+ by rewrite dmap_comp /(\o) /= dmap_id //.
1447+
+ by rewrite dmap_comp /(\o) /=.
1448+
qed.
1449+
1450+
(* -------------------------------------------------------------------- *)
1451+
lemma iscpl_dmap2 ['a 'b] (d1 : 'a distr) (d2 : 'b distr) (f : 'b -> 'a) :
1452+
dmap d2 f = d1 => iscoupling<:'a, 'b> d1 d2 (dmap d2 (fun x => (f x, x))).
1453+
proof.
1454+
rewrite /iscoupling => ?.
1455+
split.
1456+
+ by rewrite dmap_comp /(\o) /=.
1457+
+ by rewrite dmap_comp /(\o) /= dmap_id //.
1458+
qed.
1459+
14401460
(* -------------------------------------------------------------------- *)
14411461
abbrev [-printing] dapply (F: 'a -> 'b) : 'a distr -> 'b distr =
14421462
fun d => dmap d F.

0 commit comments

Comments
 (0)