Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/ecHiTacticals.ml
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ and process1_phl (_ : ttenv) (t : phltactic located) (tc : tcenv1) =
| Psetmatch info -> EcPhlCodeTx.process_set_match info
| Pweakmem info -> EcPhlCodeTx.process_weakmem info
| Prnd (side, pos, info) -> EcPhlRnd.process_rnd side pos info
| Pcoupling (side, f) -> EcPhlRnd.process_coupling side f
| Prndsem (red, side, pos) -> EcPhlRnd.process_rndsem ~reduce:red side pos
| Pconseq (opt, info) -> EcPhlConseq.process_conseq_opt opt info
| Pconseqauto cm -> process_conseqauto cm
Expand Down
1 change: 1 addition & 0 deletions src/ecLexer.mll
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@
"cfold" , CFOLD ; (* KW: tactic *)
"rnd" , RND ; (* KW: tactic *)
"rndsem" , RNDSEM ; (* KW: tactic *)
"coupling" , COUPLING ; (* KW: tactic *)
"pr_bounded" , PRBOUNDED ; (* KW: tactic *)
"bypr" , BYPR ; (* KW: tactic *)
"byphoare" , BYPHOARE ; (* KW: tactic *)
Expand Down
4 changes: 4 additions & 0 deletions src/ecParser.mly
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@
%token CONSEQ
%token CONST
%token COQ
%token COUPLING
%token CHECK
%token EDIT
%token FIX
Expand Down Expand Up @@ -3003,6 +3004,9 @@ interleave_info:
| RNDSEM red=boption(STAR) s=side? c=codepos1
{ Prndsem (red, s, c) }

| COUPLING s=side? f=sform
{ Pcoupling (s, f) }

| INLINE s=side? u=inlineopt? o=occurences?
{ Pinline (`ByName(s, u, ([], o))) }

Expand Down
1 change: 1 addition & 0 deletions src/ecParsetree.ml
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,7 @@ type phltactic =
| Pasgncase of (oside * pcodepos)
| Prnd of oside * psemrndpos option * rnd_tac_info_f
| Prndsem of bool * oside * pcodepos1
| Pcoupling of oside * pformula
| Palias of (oside * pcodepos * osymbol_r)
| Pweakmem of (oside * psymbol * fun_params)
| Pset of (oside * pcodepos * bool * psymbol * pexpr)
Expand Down
100 changes: 100 additions & 0 deletions src/phl/ecPhlRnd.ml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ open EcSubst

open EcMatching.Position
open EcCoreGoal
open EcCoreFol
open EcLowGoal
open EcLowPhlGoal

Expand Down Expand Up @@ -442,6 +443,84 @@ module Core = struct
| `Right -> f_equivS (snd es.es_ml) mt (es_pr es) es.es_sl s (es_po es) in
FApi.xmutate1 tc (`RndSem pos) [concl]

(* -------------------------------------------------------------------- *)
let t_equiv_coupling_r side g tc =
(* process the following pRHL goal, where g is a coupling of g1 and g2 *)
(* {phi} c1; x <$ g1 ~ c2; y <$ g2 {psi} *)
let env = FApi.tc1_env tc in
let es = tc1_as_equivS tc in
let ml = fst es.es_ml in
let mr = fst es.es_mr in
let (lvL, muL), sl' = tc1_last_rnd tc es.es_sl in
let (lvR, muR), sr' = tc1_last_rnd tc es.es_sr in
let tyL = proj_distr_ty env (e_ty muL) in
let tyR = proj_distr_ty env (e_ty muR) in
let muL = (EcFol.ss_inv_of_expr ml muL).inv in
let muR = (EcFol.ss_inv_of_expr mr muR).inv in
let g = g.inv in

let goal =
match side with
| None ->
(* Goal: {phi} c1 ~ c2 {iscoupling g muL muR /\ (forall u v, (u, v) \in supp(g) => psi[x -> u, y -> v])} *)
(* Generate two free variables u and v and the pair (u, v) *)
let u_id = EcIdent.create "u" in
let v_id = EcIdent.create "v" in
let u = f_local u_id tyL in
let v = f_local v_id tyR in
let ab = f_tuple [u; v] in

(* Generate the coupling distribution type: (tyL * tyR) distr *)
let coupling_ty = ttuple [tyL; tyR] in
let g_app = f_app_simpl g [] (tdistr coupling_ty) in

let iscoupling_op = EcPath.extend EcCoreLib.p_top ["Distr"; "iscoupling"] in
let iscoupling_ty = tfun (tdistr tyL) (tfun (tdistr tyR) (tfun (tdistr coupling_ty) tbool)) in
let iscoupling_pred = f_app (f_op iscoupling_op [tyL; tyR] iscoupling_ty) [muL; muR; g_app] tbool in

(* Substitute in the postcondition *)
let post = (es_po es) in
let post_subst = subst_form_lv_left env lvL {ml=ml;mr=mr;inv=u} post in
let post_subst = subst_form_lv_right env lvR {ml=ml;mr=mr;inv=v} post_subst in

let goal = map_ts_inv1 (f_imp (f_in_supp ab g_app)) post_subst in
let goal = map_ts_inv1 (f_forall_simpl [(u_id, GTty tyL); (v_id, GTty tyR)]) goal in
map_ts_inv1 (f_and iscoupling_pred) goal
| Some side ->
(* Goal (left): {phi} c1 ~ c2 {dmap d1 g = d2 /\
forall u, u \in supp(d1) -> psi[x -> u, y -> g(u)]} *)
(* Goal (right): {phi} c1 ~ c2 {d1 = dmap d2 g /\
forall v, v \in supp(d2) -> psi[x -> g(v), y -> v]} *)
let is_left = (side = `Left) in
let (mu_main, mu_other) = if is_left then (muL, muR) else (muR, muL) in
let (ty_main, ty_other) = if is_left then (tyL, tyR) else (tyR, tyL) in
let (lv_main, lv_other) = if is_left then (lvL, lvR) else (lvR, lvL) in
let (subst_main, subst_other) =
if is_left then (subst_form_lv_left, subst_form_lv_right)
else (subst_form_lv_right, subst_form_lv_left) in

let dmap_op = EcPath.extend EcCoreLib.p_top ["Distr"; "dmap"] in
let dmap_ty = tfun (tdistr ty_main) (tfun (tfun ty_main ty_other) (tdistr ty_other)) in
let dmap_pred = f_app (f_op dmap_op [ty_main; ty_other] dmap_ty) [mu_main ; g] (tdistr ty_other) in
let dmap_eq = f_eq dmap_pred mu_other in

let var_name = if is_left then "u" else "v" in
let var_ty = if is_left then tyL else tyR in
let var_id = EcIdent.create var_name in
let var = f_local var_id var_ty in
let g_applied = f_app_simpl g [var] (if is_left then tyR else tyL) in

let post = es_po es in
let post_subst = subst_main env lv_main {ml;mr;inv=var} post in
let post_subst = subst_other env lv_other {ml;mr;inv=g_applied} post_subst in
let goal = map_ts_inv1 (f_imp (f_in_supp var mu_main)) post_subst in
let goal = map_ts_inv1 (f_forall_simpl [(var_id, GTty var_ty)]) goal in

map_ts_inv1 (f_and dmap_eq) goal
in
let goal = f_equivS (snd es.es_ml) (snd es.es_mr) (es_pr es) sl' sr' goal in

FApi.xmutate1 tc `Rnd [goal]
end (* Core *)

(* -------------------------------------------------------------------- *)
Expand Down Expand Up @@ -713,3 +792,24 @@ let process_rndsem ~reduce side pos tc =
| Some side when is_equivS concl ->
t_equiv_rndsem reduce side pos tc
| _ -> tc_error !!tc "invalid arguments"

let process_coupling side g tc =
let concl = FApi.tc1_goal tc in

if not (is_equivS concl) then
tc_error !!tc "coupling can only be used on pRHL goals"
else
let env = FApi.tc1_env tc in
let es = tc1_as_equivS tc in
let (_, muL), _ = tc1_last_rnd tc es.es_sl in
let (_, muR), _ = tc1_last_rnd tc es.es_sr in
let tyL = proj_distr_ty env (e_ty muL) in
let tyR = proj_distr_ty env (e_ty muR) in

let coupling_ty =
match side with
| None -> tdistr (ttuple [tyL; tyR])
| Some `Left -> tfun tyL tyR
| Some `Right -> tfun tyR tyL in
let g_form = TTC.tc1_process_prhl_form tc coupling_ty g in
Core.t_equiv_coupling_r side g_form tc
3 changes: 3 additions & 0 deletions src/phl/ecPhlRnd.mli
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,8 @@ val t_equiv_rnd : ?pos:semrndpos -> oside -> (mkbij_t option) pair -> backward
(* -------------------------------------------------------------------- *)
val process_rnd : oside -> psemrndpos option -> rnd_infos_t -> backward

(* -------------------------------------------------------------------- *)
val process_coupling : oside -> pformula -> backward

(* -------------------------------------------------------------------- *)
val process_rndsem : reduce:bool -> oside -> pcodepos1 -> backward
121 changes: 121 additions & 0 deletions tests/coupling-rnd.ec
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
require import AllCore List.
require import Distr DBool DList.
(*---*) import Biased.
require import StdBigop.
(*---*) import Bigreal Bigreal.BRA.

type t.

op test : t -> bool.
op D : t distr.

axiom D_ll : is_lossless D.

module B = {
proc f() : bool = {
var x: bool;
var garbage: int;
(* put some garbage to avoid the random sampling being the only instruction in a procedure *)
garbage <- 0;
x <$ dbiased (mu D test);
return x;
}
proc g() : t = {
var l: t;
var garbage: int;
garbage <- 0;
l <$ D;
return l;
}
}.

op c : (bool * t) distr = dmap D (fun x => (test x, x)).

(* prove the lemma by directly providing a coupling *)
lemma B_coupling:
equiv [ B.f ~ B.g : true ==> res{1} = test res{2} ].
proof.
proc.
coupling c.
wp; skip => /=; split; last first.
+ move => a b.
by rewrite /c => /supp_dmap [x] />.
rewrite /iscoupling; split; last first.
+ by rewrite /c dmap_comp /(\o) /= dmap_id.
rewrite /c dmap_comp /(\o) /=.
apply eq_distr => b.
rewrite dbiased1E.
rewrite clamp_id.
+ exact mu_bounded.
rewrite dmap1E /pred1 /(\o) /=.
case b => _.
+ by apply mu_eq => x /#.
rewrite (mu_eq _ _ (predC test)).
+ by smt().
by rewrite mu_not D_ll.
qed.

(* prove the lemma by bypr, which means we compute the probability at each side separately. *)
lemma B_coupling_bypr:
equiv [ B.f ~ B.g : true ==> res{1} = test res{2} ].
proof.
bypr res{1} (test res{2}).
+ smt().
+ move => &1 &2 b.
have ->: Pr[B.f() @ &1 : res = b] = mu1 (dbiased (mu D test)) b.
+ byphoare => //.
proc; rnd; wp; skip => />.
have -> //: Pr[B.g() @ &2 : test res = b] = mu1 (dbiased (mu D test)) b.
+ byphoare => //.
proc; rnd; sp; skip => />.
rewrite dbiased1E.
rewrite clamp_id.
+ exact mu_bounded.
case b => _.
+ apply mu_eq => x /#.
rewrite (mu_eq _ _ (predC test)).
+ by smt().
by rewrite mu_not D_ll.
qed.

op f (x : t) : bool = test x.

lemma B_compress:
equiv [ B.g ~ B.f : true ==> test res{1} = res{2} ].
proof.
proc.
coupling{1} f.
wp; skip => @/predT /=.
+ split.
+ apply eq_distr => b.
rewrite dbiased1E.
rewrite clamp_id.
+ exact mu_bounded.
rewrite dmap1E /pred1 /(\o) /=.
case b => _.
+ by apply mu_eq => x /#.
rewrite (mu_eq _ _ (predC test)).
+ by smt().
by rewrite mu_not D_ll.
+ by move => @/f * //=.
qed.

lemma B_compress_reverse:
equiv [ B.f ~ B.g : true ==> test res{2} = res{1} ].
proof.
proc.
coupling{2} f.
wp; skip => @/predT /=.
+ split.
+ apply eq_distr => b.
rewrite dbiased1E.
rewrite clamp_id.
+ exact mu_bounded.
rewrite dmap1E /pred1 /(\o) /=.
case b => _.
+ by apply mu_eq => x /#.
rewrite (mu_eq _ (fun (x : t) => f x = false) (predC test)).
+ by smt().
by rewrite mu_not D_ll.
+ by move => @/f //.
qed.
20 changes: 20 additions & 0 deletions theories/distributions/Distr.ec
Original file line number Diff line number Diff line change
Expand Up @@ -1437,6 +1437,26 @@ move=> a1 a2 _ /=; apply/iscpl_dlet; first by apply: cpl_d=> /#.
by move=> b1 b2 _ /=; apply/iscpl_dunit.
qed.

(* -------------------------------------------------------------------- *)
lemma iscpl_dmap1 ['a 'b] (d1 : 'a distr) (d2 : 'b distr) (f : 'a -> 'b) :
dmap d1 f = d2 => iscoupling<:'a, 'b> d1 d2 (dmap d1 (fun x => (x, f x))).
proof.
rewrite /iscoupling => ?.
split.
+ by rewrite dmap_comp /(\o) /= dmap_id //.
+ by rewrite dmap_comp /(\o) /=.
qed.

(* -------------------------------------------------------------------- *)
lemma iscpl_dmap2 ['a 'b] (d1 : 'a distr) (d2 : 'b distr) (f : 'b -> 'a) :
dmap d2 f = d1 => iscoupling<:'a, 'b> d1 d2 (dmap d2 (fun x => (f x, x))).
proof.
rewrite /iscoupling => ?.
split.
+ by rewrite dmap_comp /(\o) /=.
+ by rewrite dmap_comp /(\o) /= dmap_id //.
qed.

(* -------------------------------------------------------------------- *)
abbrev [-printing] dapply (F: 'a -> 'b) : 'a distr -> 'b distr =
fun d => dmap d F.
Expand Down