diff --git a/src/ecHiTacticals.ml b/src/ecHiTacticals.ml index 30e0a7355..472054f64 100644 --- a/src/ecHiTacticals.ml +++ b/src/ecHiTacticals.ml @@ -200,6 +200,7 @@ and process1_phl (_ : ttenv) (t : phltactic located) (tc : tcenv1) = | Psetmatch info -> EcPhlCodeTx.process_set_match info | Pweakmem info -> EcPhlCodeTx.process_weakmem info | Prnd (side, pos, info) -> EcPhlRnd.process_rnd side pos info + | Pcoupling (side, f) -> EcPhlRnd.process_coupling side f | Prndsem (red, side, pos) -> EcPhlRnd.process_rndsem ~reduce:red side pos | Pconseq (opt, info) -> EcPhlConseq.process_conseq_opt opt info | Pconseqauto cm -> process_conseqauto cm diff --git a/src/ecLexer.mll b/src/ecLexer.mll index 0e6f9a02e..6170535ff 100644 --- a/src/ecLexer.mll +++ b/src/ecLexer.mll @@ -145,6 +145,7 @@ "cfold" , CFOLD ; (* KW: tactic *) "rnd" , RND ; (* KW: tactic *) "rndsem" , RNDSEM ; (* KW: tactic *) + "coupling" , COUPLING ; (* KW: tactic *) "pr_bounded" , PRBOUNDED ; (* KW: tactic *) "bypr" , BYPR ; (* KW: tactic *) "byphoare" , BYPHOARE ; (* KW: tactic *) diff --git a/src/ecParser.mly b/src/ecParser.mly index 81c38f90b..c152ed101 100644 --- a/src/ecParser.mly +++ b/src/ecParser.mly @@ -416,6 +416,7 @@ %token CONSEQ %token CONST %token COQ +%token COUPLING %token CHECK %token EDIT %token FIX @@ -3003,6 +3004,9 @@ interleave_info: | RNDSEM red=boption(STAR) s=side? c=codepos1 { Prndsem (red, s, c) } +| COUPLING s=side? f=sform + { Pcoupling (s, f) } + | INLINE s=side? u=inlineopt? o=occurences? { Pinline (`ByName(s, u, ([], o))) } diff --git a/src/ecParsetree.ml b/src/ecParsetree.ml index bc090f693..c741fc94d 100644 --- a/src/ecParsetree.ml +++ b/src/ecParsetree.ml @@ -748,6 +748,7 @@ type phltactic = | Pasgncase of (oside * pcodepos) | Prnd of oside * psemrndpos option * rnd_tac_info_f | Prndsem of bool * oside * pcodepos1 + | Pcoupling of oside * pformula | Palias of (oside * pcodepos * osymbol_r) | Pweakmem of (oside * psymbol * fun_params) | Pset of (oside * pcodepos * bool * psymbol * pexpr) diff --git a/src/phl/ecPhlRnd.ml b/src/phl/ecPhlRnd.ml index 5ed923c2b..9c5b345f1 100644 --- a/src/phl/ecPhlRnd.ml +++ b/src/phl/ecPhlRnd.ml @@ -10,6 +10,7 @@ open EcSubst open EcMatching.Position open EcCoreGoal +open EcCoreFol open EcLowGoal open EcLowPhlGoal @@ -442,6 +443,84 @@ module Core = struct | `Right -> f_equivS (snd es.es_ml) mt (es_pr es) es.es_sl s (es_po es) in FApi.xmutate1 tc (`RndSem pos) [concl] + (* -------------------------------------------------------------------- *) + let t_equiv_coupling_r side g tc = + (* process the following pRHL goal, where g is a coupling of g1 and g2 *) + (* {phi} c1; x <$ g1 ~ c2; y <$ g2 {psi} *) + let env = FApi.tc1_env tc in + let es = tc1_as_equivS tc in + let ml = fst es.es_ml in + let mr = fst es.es_mr in + let (lvL, muL), sl' = tc1_last_rnd tc es.es_sl in + let (lvR, muR), sr' = tc1_last_rnd tc es.es_sr in + let tyL = proj_distr_ty env (e_ty muL) in + let tyR = proj_distr_ty env (e_ty muR) in + let muL = (EcFol.ss_inv_of_expr ml muL).inv in + let muR = (EcFol.ss_inv_of_expr mr muR).inv in + let g = g.inv in + + let goal = + match side with + | None -> + (* Goal: {phi} c1 ~ c2 {iscoupling g muL muR /\ (forall u v, (u, v) \in supp(g) => psi[x -> u, y -> v])} *) + (* Generate two free variables u and v and the pair (u, v) *) + let u_id = EcIdent.create "u" in + let v_id = EcIdent.create "v" in + let u = f_local u_id tyL in + let v = f_local v_id tyR in + let ab = f_tuple [u; v] in + + (* Generate the coupling distribution type: (tyL * tyR) distr *) + let coupling_ty = ttuple [tyL; tyR] in + let g_app = f_app_simpl g [] (tdistr coupling_ty) in + + let iscoupling_op = EcPath.extend EcCoreLib.p_top ["Distr"; "iscoupling"] in + let iscoupling_ty = tfun (tdistr tyL) (tfun (tdistr tyR) (tfun (tdistr coupling_ty) tbool)) in + let iscoupling_pred = f_app (f_op iscoupling_op [tyL; tyR] iscoupling_ty) [muL; muR; g_app] tbool in + + (* Substitute in the postcondition *) + let post = (es_po es) in + let post_subst = subst_form_lv_left env lvL {ml=ml;mr=mr;inv=u} post in + let post_subst = subst_form_lv_right env lvR {ml=ml;mr=mr;inv=v} post_subst in + + let goal = map_ts_inv1 (f_imp (f_in_supp ab g_app)) post_subst in + let goal = map_ts_inv1 (f_forall_simpl [(u_id, GTty tyL); (v_id, GTty tyR)]) goal in + map_ts_inv1 (f_and iscoupling_pred) goal + | Some side -> + (* Goal (left): {phi} c1 ~ c2 {dmap d1 g = d2 /\ + forall u, u \in supp(d1) -> psi[x -> u, y -> g(u)]} *) + (* Goal (right): {phi} c1 ~ c2 {d1 = dmap d2 g /\ + forall v, v \in supp(d2) -> psi[x -> g(v), y -> v]} *) + let is_left = (side = `Left) in + let (mu_main, mu_other) = if is_left then (muL, muR) else (muR, muL) in + let (ty_main, ty_other) = if is_left then (tyL, tyR) else (tyR, tyL) in + let (lv_main, lv_other) = if is_left then (lvL, lvR) else (lvR, lvL) in + let (subst_main, subst_other) = + if is_left then (subst_form_lv_left, subst_form_lv_right) + else (subst_form_lv_right, subst_form_lv_left) in + + let dmap_op = EcPath.extend EcCoreLib.p_top ["Distr"; "dmap"] in + let dmap_ty = tfun (tdistr ty_main) (tfun (tfun ty_main ty_other) (tdistr ty_other)) in + let dmap_pred = f_app (f_op dmap_op [ty_main; ty_other] dmap_ty) [mu_main ; g] (tdistr ty_other) in + let dmap_eq = f_eq dmap_pred mu_other in + + let var_name = if is_left then "u" else "v" in + let var_ty = if is_left then tyL else tyR in + let var_id = EcIdent.create var_name in + let var = f_local var_id var_ty in + let g_applied = f_app_simpl g [var] (if is_left then tyR else tyL) in + + let post = es_po es in + let post_subst = subst_main env lv_main {ml;mr;inv=var} post in + let post_subst = subst_other env lv_other {ml;mr;inv=g_applied} post_subst in + let goal = map_ts_inv1 (f_imp (f_in_supp var mu_main)) post_subst in + let goal = map_ts_inv1 (f_forall_simpl [(var_id, GTty var_ty)]) goal in + + map_ts_inv1 (f_and dmap_eq) goal + in + let goal = f_equivS (snd es.es_ml) (snd es.es_mr) (es_pr es) sl' sr' goal in + + FApi.xmutate1 tc `Rnd [goal] end (* Core *) (* -------------------------------------------------------------------- *) @@ -713,3 +792,24 @@ let process_rndsem ~reduce side pos tc = | Some side when is_equivS concl -> t_equiv_rndsem reduce side pos tc | _ -> tc_error !!tc "invalid arguments" + +let process_coupling side g tc = + let concl = FApi.tc1_goal tc in + + if not (is_equivS concl) then + tc_error !!tc "coupling can only be used on pRHL goals" + else + let env = FApi.tc1_env tc in + let es = tc1_as_equivS tc in + let (_, muL), _ = tc1_last_rnd tc es.es_sl in + let (_, muR), _ = tc1_last_rnd tc es.es_sr in + let tyL = proj_distr_ty env (e_ty muL) in + let tyR = proj_distr_ty env (e_ty muR) in + + let coupling_ty = + match side with + | None -> tdistr (ttuple [tyL; tyR]) + | Some `Left -> tfun tyL tyR + | Some `Right -> tfun tyR tyL in + let g_form = TTC.tc1_process_prhl_form tc coupling_ty g in + Core.t_equiv_coupling_r side g_form tc diff --git a/src/phl/ecPhlRnd.mli b/src/phl/ecPhlRnd.mli index e636ac6d9..b7188ebc0 100644 --- a/src/phl/ecPhlRnd.mli +++ b/src/phl/ecPhlRnd.mli @@ -24,5 +24,8 @@ val t_equiv_rnd : ?pos:semrndpos -> oside -> (mkbij_t option) pair -> backward (* -------------------------------------------------------------------- *) val process_rnd : oside -> psemrndpos option -> rnd_infos_t -> backward +(* -------------------------------------------------------------------- *) +val process_coupling : oside -> pformula -> backward + (* -------------------------------------------------------------------- *) val process_rndsem : reduce:bool -> oside -> pcodepos1 -> backward diff --git a/tests/coupling-rnd.ec b/tests/coupling-rnd.ec new file mode 100644 index 000000000..8f4dd7169 --- /dev/null +++ b/tests/coupling-rnd.ec @@ -0,0 +1,121 @@ +require import AllCore List. +require import Distr DBool DList. +(*---*) import Biased. +require import StdBigop. +(*---*) import Bigreal Bigreal.BRA. + +type t. + +op test : t -> bool. +op D : t distr. + +axiom D_ll : is_lossless D. + +module B = { + proc f() : bool = { + var x: bool; + var garbage: int; + (* put some garbage to avoid the random sampling being the only instruction in a procedure *) + garbage <- 0; + x <$ dbiased (mu D test); + return x; + } + proc g() : t = { + var l: t; + var garbage: int; + garbage <- 0; + l <$ D; + return l; + } +}. + +op c : (bool * t) distr = dmap D (fun x => (test x, x)). + +(* prove the lemma by directly providing a coupling *) +lemma B_coupling: + equiv [ B.f ~ B.g : true ==> res{1} = test res{2} ]. +proof. +proc. +coupling c. +wp; skip => /=; split; last first. ++ move => a b. + by rewrite /c => /supp_dmap [x] />. +rewrite /iscoupling; split; last first. ++ by rewrite /c dmap_comp /(\o) /= dmap_id. +rewrite /c dmap_comp /(\o) /=. +apply eq_distr => b. +rewrite dbiased1E. +rewrite clamp_id. ++ exact mu_bounded. +rewrite dmap1E /pred1 /(\o) /=. +case b => _. ++ by apply mu_eq => x /#. +rewrite (mu_eq _ _ (predC test)). ++ by smt(). +by rewrite mu_not D_ll. +qed. + +(* prove the lemma by bypr, which means we compute the probability at each side separately. *) +lemma B_coupling_bypr: + equiv [ B.f ~ B.g : true ==> res{1} = test res{2} ]. +proof. +bypr res{1} (test res{2}). ++ smt(). ++ move => &1 &2 b. + have ->: Pr[B.f() @ &1 : res = b] = mu1 (dbiased (mu D test)) b. + + byphoare => //. + proc; rnd; wp; skip => />. + have -> //: Pr[B.g() @ &2 : test res = b] = mu1 (dbiased (mu D test)) b. + + byphoare => //. + proc; rnd; sp; skip => />. + rewrite dbiased1E. + rewrite clamp_id. + + exact mu_bounded. + case b => _. + + apply mu_eq => x /#. + rewrite (mu_eq _ _ (predC test)). + + by smt(). + by rewrite mu_not D_ll. +qed. + +op f (x : t) : bool = test x. + +lemma B_compress: + equiv [ B.g ~ B.f : true ==> test res{1} = res{2} ]. +proof. +proc. +coupling{1} f. +wp; skip => @/predT /=. ++ split. + + apply eq_distr => b. + rewrite dbiased1E. + rewrite clamp_id. + + exact mu_bounded. + rewrite dmap1E /pred1 /(\o) /=. + case b => _. + + by apply mu_eq => x /#. + rewrite (mu_eq _ _ (predC test)). + + by smt(). + by rewrite mu_not D_ll. + + by move => @/f * //=. +qed. + +lemma B_compress_reverse: + equiv [ B.f ~ B.g : true ==> test res{2} = res{1} ]. +proof. +proc. +coupling{2} f. +wp; skip => @/predT /=. ++ split. + + apply eq_distr => b. + rewrite dbiased1E. + rewrite clamp_id. + + exact mu_bounded. + rewrite dmap1E /pred1 /(\o) /=. + case b => _. + + by apply mu_eq => x /#. + rewrite (mu_eq _ (fun (x : t) => f x = false) (predC test)). + + by smt(). + by rewrite mu_not D_ll. + + by move => @/f //. +qed. diff --git a/theories/distributions/Distr.ec b/theories/distributions/Distr.ec index f4c8ddf94..6738bbc63 100644 --- a/theories/distributions/Distr.ec +++ b/theories/distributions/Distr.ec @@ -1437,6 +1437,26 @@ move=> a1 a2 _ /=; apply/iscpl_dlet; first by apply: cpl_d=> /#. by move=> b1 b2 _ /=; apply/iscpl_dunit. qed. +(* -------------------------------------------------------------------- *) +lemma iscpl_dmap1 ['a 'b] (d1 : 'a distr) (d2 : 'b distr) (f : 'a -> 'b) : + dmap d1 f = d2 => iscoupling<:'a, 'b> d1 d2 (dmap d1 (fun x => (x, f x))). +proof. +rewrite /iscoupling => ?. +split. ++ by rewrite dmap_comp /(\o) /= dmap_id //. ++ by rewrite dmap_comp /(\o) /=. +qed. + +(* -------------------------------------------------------------------- *) +lemma iscpl_dmap2 ['a 'b] (d1 : 'a distr) (d2 : 'b distr) (f : 'b -> 'a) : + dmap d2 f = d1 => iscoupling<:'a, 'b> d1 d2 (dmap d2 (fun x => (f x, x))). +proof. +rewrite /iscoupling => ?. +split. ++ by rewrite dmap_comp /(\o) /=. ++ by rewrite dmap_comp /(\o) /= dmap_id //. +qed. + (* -------------------------------------------------------------------- *) abbrev [-printing] dapply (F: 'a -> 'b) : 'a distr -> 'b distr = fun d => dmap d F.