(* Source: https://github.com/EasyCrypt/easycrypt (raw file view)
 * Tip revision: 955e909402cf7a5dc3dc55e4de13bbf373edd920 (short: 955e909),
 * authored by Pierre-Yves Strub on 30 July 2015, 08:20:28 UTC.
 * Commit message: NewList: last_ -> last.
 * File: Monoid.ec *)
(* --------------------------------------------------------------------
 * Copyright (c) - 2012-2015 - IMDEA Software Institute and INRIA
 * Distributed under the terms of the CeCILL-B licence.
 * -------------------------------------------------------------------- *)

(* Abstract theory of commutative monoids.
 *
 * Carrier [t], unit [Z] and binary operation [+] (theory Base), with:
 *  - derived rearrangement lemmas (addmCA, addmAC, addmACA);
 *  - [sum f s], the sum of [f x] over a finite set [s], and its lemmas;
 *  - [sum_ij i j f], the sum over the closed integer interval [i, j];
 *  - NatMul, an integer scalar multiplication [n * x] axiomatised as
 *    iterated addition for n >= 0.
 * Concrete instances are obtained by cloning; this file later clones it
 * for bool (Mbor), int (Miplus) and real (Mrplus). *)
theory Comoid.

(* The monoid structure: [+] is commutative and associative, and [Z] is a
 * right unit (a left unit then follows from addmC). *)
theory Base.
  type t.

  op Z: t.
  op (+): t -> t -> t.

  axiom addmC (x y:t): x + y = y + x.
  axiom addmA (x y z:t): x + (y + z) = x + y + z.
  axiom addmZ (x:t): x + Z = x.
end Base.
export Base.

(* Swap the last two summands of a left-nested sum.
 * NOTE(review): MathComp's naming would call this statement addrAC (and
 * the next lemma addrCA); the names are kept for existing users, e.g.
 * sum_rm below uses addmCA. *)
lemma addmCA (x y z:t): (x + y) + z = (x + z) + y.
proof strict.
by rewrite -addmA (addmC y) addmA.
qed.

(* Swap the first two summands of a right-nested sum (see the naming note
 * on addmCA). *)
lemma addmAC (x y z:t): x + (y + z) = y + (x + z).
proof strict.
by rewrite addmA (addmC x) -addmA.
qed.

(* Exchange the middle summands of a sum of two pairs.  The variable named
 * [t] has type [t]; it shadows the type name only as a term variable. *)
lemma addmACA (x y z t:t):
  (x + y) + (z + t) = (x + z) + (y + t).
proof strict.
by rewrite addmA -(addmA x) (addmC y) !addmA.
qed.

require import FSet.

(* [sum f s] is the sum of [f x] over every element [x] of the finite set
 * [s], computed by folding from [Z].  Inside the lambda, [s] names the
 * accumulator and shadows the set argument. *)
op sum (f:'a -> t) (s:'a set) =
  fold (fun x s, s + (f x)) Z s.

(* The sum over the empty set is the unit. *)
lemma sum_empty (f:'a -> t): sum f empty = Z.
proof strict.
by rewrite /sum fold_empty.
qed.

(* Pull one member [x] of [s] out of the sum. *)
lemma sum_rm (f:'a -> t) (s:'a set) (x:'a):
  mem x s =>
  sum f s = (f x) + (sum f (rm x s)).
proof strict.
rewrite /sum=> x_in_s.
(* foldC extracts [x] from the fold; the remaining side goal (presumably
   the order-independence of the fold step) is closed with addmCA. *)
rewrite (foldC x) // /=; last by rewrite addmC.
by intros=> a b X; rewrite addmCA.
qed.

(* Adding a fresh element [x] to [s] adds [f x] to the sum. *)
lemma sum_add (f:'a -> t) (s:'a set) (x:'a):
  (!mem x s) =>
  sum f (add x s) = (f x) + (sum f s).
proof strict.
intros=> x_nin_s;
rewrite (sum_rm _ _ x); first by rewrite mem_add.
by rewrite rm_add_eq -rm_nin_id.
qed.

(* Variant of sum_add where [x] may already be in [s], provided [f x] is
 * then the unit (in that case [add x s = s] and [f x + sum f s] collapses
 * back to [sum f s]). *)
lemma sum_add0 (f:'a -> t) (s:'a set) (x:'a):
  (mem x s => f x = Z) =>
  sum f (add x s) = (f x) + (sum f s).
proof strict.
case (mem x s) => /= Hin.
  by rewrite -add_in_id // => ->;rewrite addmC addmZ.
by apply sum_add.  
qed.

(* The sum over a union of disjoint sets splits into the two sums.
 * Proved by induction on [s1]. *)
lemma sum_disj (f:'a -> t) (s1 s2:'a set) :
  disjoint s1 s2 =>
  sum f (union s1 s2) = sum f s1 + sum f s2.
proof -strict.
 elim/set_ind s1.
   by intros Hd;rewrite union0s sum_empty addmC addmZ.
 intros x s Hx Hrec Hd;rewrite union_add sum_add.
   by generalize Hd;rewrite disjoint_spec mem_union;smt.
 rewrite sum_add.
   by generalize Hd;rewrite disjoint_spec;smt.
 rewrite Hrec.
   move: Hd; rewrite !disjoint_spec=> Hd x0.
   cut:= Hd x0; case (x0 = x).
     move=> ->; cut -> //=: mem x (add x s) by smt. smt.
   by rewrite -neqF mem_add=> ->.
 smt.
qed.

(* Two functions that agree on every member of [s] have the same sum
 * over [s]. *)
lemma sum_eq (f1 f2:'a -> t) (s: 'a set) :  
   (forall x, mem x s => f1 x = f2 x) =>
   sum f1 s = sum f2 s.
proof strict.
  elim/set_ind s.
    by rewrite !sum_empty.
  intros x s Hx Hr Hf;rewrite sum_add // sum_add // Hf;first smt.
  by rewrite Hr // => y Hin;apply Hf;smt.
qed.

(* The summand may be restricted to [s]: values of [f] outside [s] do
 * not matter. *)
lemma sum_in (f:'a -> t) (s:'a set):
  sum f s = sum (fun x, if mem x s then f x else Z) s.
proof strict.
  by apply sum_eq => x /= ->.
qed.

(* A map [f] that preserves [Z] and [+] commutes with [sum]. *)
lemma sum_comp (f: t -> t) (g:'a -> t) (s: 'a set):
  (f Z = Z) =>
  (forall x y, f (x + y) = f x + f y) =>
  sum (fun a, f (g a)) s = f (sum g s).
proof -strict.
  intros Hz Ha;elim/set_ind s.
    by rewrite !sum_empty Hz.
  by intros x s Hx Hr;rewrite sum_add // sum_add //= Hr Ha. 
qed.

(* The sum of two sums over the same set is the sum of the pointwise
 * sum. *)
lemma sum_add2 (f:'a -> t) (g:'a -> t) (s:'a set):
  (sum f s) + (sum g s) = sum (fun x, f x + g x) s.
proof strict.
elim/set_comp s;first by rewrite !sum_empty addmZ.
intros s s_nempty IH;
rewrite (sum_rm f _ (pick s)); first by rewrite mem_pick.
rewrite (sum_rm g _ (pick s)); first by rewrite mem_pick.
rewrite (sum_rm _ s (pick s)); first by rewrite mem_pick.
by rewrite -IH /= addmACA.
qed.

(* Change of index: if [g'] is a left inverse of [g] on [s], summing [f]
 * over [s] equals summing [f o g'] over the image [img g s]. *)
lemma sum_chind (f:'a -> t) (g:'a -> 'b) (g':'b -> 'a) (s:'a set):
  (forall x, mem x s => g' (g x) = x) =>
  (sum f s) = sum (fun x, f (g' x)) (img g s).
proof strict.
intros=> pcan_g'_g;
(* Generalise occurrences 1, 3 and 4 of [s] in [s <= s => goal], so that
   induction runs over subsets s' of s while the left-inverse hypothesis
   stays about the original s. *)
elim/set_comp {1 3 4}s (leq_refl s).
  by rewrite !sum_empty img_empty sum_empty.
  intros s' s'_nempty IH leq_s'_s;
  rewrite (sum_rm _ _ (pick s'));first by rewrite mem_pick.
  rewrite (sum_rm _ (img g s') (g (pick s'))) /=;
    first by rewrite mem_img // mem_pick.
  rewrite pcan_g'_g; first by apply leq_s'_s; apply mem_pick.
  rewrite IH; first apply (leq_tran s')=> //; apply rm_leq.
  rewrite img_rm;
  (cut ->: (forall x, mem x s' => g (pick s') = g x => pick s' = x) = true)=> //;
  rewrite eqT=> x x_in_s g_pick.
  rewrite -pcan_g'_g; first by apply leq_s'_s.
  by rewrite -g_pick pcan_g'_g //; apply leq_s'_s; apply mem_pick.
qed.

(* If [f] is the unit wherever [p] fails, dropping the elements that fail
 * [p] does not change the sum. *)
lemma sum_filter (f:'a -> t) (p:'a -> bool) (s:'a set):
  (forall x, (!p x) => f x = Z) =>
  sum f (filter p s) = sum f s.
proof strict.
(* Same subset induction as in sum_chind. *)
intros=> f_Z; elim/set_comp {1 3 4}s (leq_refl s).
  by rewrite FSet.filter_empty.
  intros=> s' s'_nempty IH leq_s'_s;
  rewrite (sum_rm _ s' (pick s')); first by apply mem_pick.
  rewrite -IH;first apply (leq_tran s')=> //; apply rm_leq.
  case (p (pick s'))=> p_pick.
    by rewrite (sum_rm _ (filter p s') (pick s')) ?rm_filter // mem_filter;
       split=> //; apply mem_pick.
    by rewrite f_Z // -rm_filter addmC addmZ -rm_nin_id // mem_filter -nand;
       right=> //.
qed.

require import Int.
import Interval.

(* [sum_ij i j f] sums [f] over the integer interval [interval i j].  The
 * lemmas below show this interval is the closed range [i, j]: it is
 * empty when i > j (sum_ij_gt) and a singleton when i = j (sum_ij_eq). *)
op sum_ij (i j : int) (f:int -> t)  = 
  sum f (interval i j).

(* An empty range sums to the unit. *)
lemma sum_ij_gt (i j:int) f : 
  i > j => sum_ij i j f = Z.
proof -strict.
 by intros Hlt;rewrite /sum_ij interval_neg // sum_empty.
qed.

(* Split [i, j] at [k] into [i, k-1] and [k, j]; either part may be
 * empty (k = i or k = j + 1). *)
lemma sum_ij_split (k i j:int) f:
  i <= k <= j + 1 => sum_ij i j f = sum_ij i (k-1) f + sum_ij k j f.
proof -strict. 
  intros Hbound;rewrite /sum_ij -sum_disj.
    rewrite disjoint_spec=> x;rewrite !Interval.mem_interval;smt.
  congr=> //; apply set_ext=> x; rewrite mem_union; smt.
qed.

(* A one-point range sums to the value at that point. *)
lemma sum_ij_eq i f: sum_ij i i f = f i.
proof -strict.
 rewrite /sum_ij Interval.interval_single sum_add;first apply mem_empty.
 rewrite sum_empty;apply addmZ.
qed.

(* Peel off the last term of a nonempty range. *)
lemma sum_ij_le_r (i j:int) f : 
   i <= j =>
   sum_ij i j f = sum_ij i (j-1) f + f j.
proof -strict.
  intros Hle;rewrite (sum_ij_split j); first by smt.
  by rewrite sum_ij_eq.
qed.

(* Peel off the first term of a nonempty range. *)
lemma sum_ij_le_l (i j:int) f : 
   i <= j =>
   sum_ij i j f = f i + sum_ij (i+1) j f.
proof -strict.
 intros Hle; rewrite (sum_ij_split (i+1));smt.
qed.

(* Shift the range down by [k], shifting the summand's argument up by [k]
 * to compensate.  Proved from sum_chind with n |-> n - k. *)
lemma sum_ij_shf (k i j:int) f:
   sum_ij i j f = sum_ij (i-k) (j-k) (fun n, f (k+n)).
proof strict.
  rewrite /sum_ij.
  rewrite (sum_chind f (fun n, n - k) (fun n, k + n)) /=;first smt.
  congr => //;apply set_ext => x;rewrite Interval.mem_interval img_def /=;split.
  intros [x0 [Heq ]];rewrite Interval.mem_interval;subst;smt.
  intros _;exists (x + k);smt.
qed.

(* Re-index a range so that it starts at 0 (sum_ij_shf with k = i). *)
lemma sum_ij_shf0 (i j :int) f:
   sum_ij i j f = sum_ij 0 (j-i) (fun n, f (i+n)).
proof strict.
  rewrite (sum_ij_shf i);smt.
qed.

(* Integer scalar multiplication: [n * x] behaves as x added to itself n
 * times.  The axioms only constrain n >= 0; nothing is assumed about
 * negative n. *)
theory NatMul.

  op ( * ) : int -> t -> t.

  axiom MulZ : forall (x:t), 0*x = Z.
  axiom MulS : forall n (x:t), 0 <= n => (n + 1) * x = x + n * x.

  (* Summing a function that is constantly [k] on [s] gives
   * [card s] times [k]. *)
  lemma sum_const (k:t) (f:'a->t) (s:'a set):
    (forall (x:'a), mem x s => f x = k) =>
    sum f s = (card s)*k.
  proof strict.
  (* Strengthen the goal to every subset s' of s, then induct on s'. *)
  intros=> f_x; pose s' := s.
  cut -> //: s' <= s => sum f s' = (card s') * k;
    last by rewrite /s'; apply leq_refl<:'a>. (* FIXME: the explicit type instantiation on leq_refl looks like a workaround; confirm whether it is still needed. *)
  elim/set_comp s'.
    by rewrite sum_empty card_empty MulZ.
    intros=> s' s'_nempty IH leq_s'_s.
    rewrite (sum_rm _ _ (pick s'));first by rewrite mem_pick.
    rewrite IH; first by apply (leq_tran s')=> //; apply rm_leq.
    rewrite f_x; first by apply leq_s'_s; apply mem_pick.
    rewrite card_rm_in; first by apply mem_pick.
    rewrite -MulS; smt.
  qed.
end NatMul.

end Comoid.

(* For bool *)
require Bool.

(* Boolean instance: the operation is disjunction with unit [false], so
 * [Mbor.sum f s] holds iff some member of [s] satisfies [f] (this is the
 * lemma or_exists in this file).  The scalar multiplication is defined so
 * that [n * b] is [b] when n <> 0 and [false] when n = 0.  All axioms are
 * discharged by smt. *)
clone Comoid as Mbor with 
   type Base.t <- bool,
   op Base.(+) <- (\/),
   op Base.Z   <- false,
   op NatMul.( * ) = fun (n:int) (b:bool), n <> 0 /\ b
   proof Base.* by smt, NatMul.* by smt.

(* For int *)

(* Integer instance of Comoid: integer addition with unit 0, and integer
 * multiplication as the scalar multiplication.  Also develops [sum_n],
 * the sum of consecutive integers, with its closed form. *)
theory Miplus.
  clone export Comoid as Miplus with
    type Base.t     <- int,
    op Base.(+)     <- Int.(+),
    op Base.Z       <- 0,
    op NatMul.( * ) <- Int.( * )
    proof Base.* by smt, NatMul.* by smt.

  import Int.
  (* [sum_n i j] = i + (i+1) + ... + j, and 0 when i > j. *)
  op sum_n i j = sum_ij i j (fun (n:int), n).

  (* Closed form 0 + 1 + ... + k = k(k+1)/2 for k >= 0 (the division is
   * the Int theory operator /%).  Proved by induction on k. *)
  lemma sum_n_0k (k:int) : 0 <= k => sum_n 0 k = (k*(k + 1))/%2.
  proof -strict.
    rewrite /sum_n;elim /Int.Induction.induction k.
      by rewrite sum_ij_eq => /=; smt all.
    intros k Hk Hrec;rewrite sum_ij_le_r;first smt.
    cut -> : k + 1 - 1 = k;first smt.
    rewrite Hrec /=.
    (* Rewrite (k+1)(k+2) as k(k+1) + 2(k+1) so Div_mult can split off the
       exact multiple of 2. *)
    have ->: (k + 1) * (k + 1 + 1) = k * (k + 1) + 2 * (k + 1) by smt.
    rewrite (CommutativeGroup.Comm.Comm (k * (k + 1))) Div_mult 1:smt.
    by rewrite (CommutativeGroup.Comm.Comm).
  qed.

 (* A one-point sum is that point. *)
 lemma sum_n_ii (k:int): sum_n k k = k
 by [].
 
 (* Extend a nonempty range by one term on the right. *)
 lemma sum_n_ij1 (i j:int) : i <= j => sum_n i (j+1) = sum_n i j + (j+1)
 by [].

 (* Extend a nonempty range [i+1, j] by one term on the left. *)
 lemma sum_n_i1j (i j : int) : i <= j => i + sum_n (i+1) j = sum_n i j
 by [].

 (* Auxiliary step towards sumn_ij: factor the constant offset [i] out of
  * each term, leaving a range that starts at 0.  Marked nosmt, presumably
  * to keep it out of the hint set of later smt calls. *)
 lemma nosmt sumn_ij_aux (i j:int) : i <= j =>
   sum_n i j = i*((j - i)+1) + sum_n 0 ((j - i)).
 proof -strict.
   (* Write j as i + (j - i) and induct on the nonnegative length j - i. *)
   intros Hle;rewrite {1} (_: j=i+(j-i));first smt.
   cut: 0 <= (j-i) by smt.
   elim/Int.Induction.induction (j-i)=> //=.
     by rewrite !sum_n_ii.
   intros {j Hle} j Hj; rewrite -CommutativeGroup.Assoc sum_n_ij1;smt.
 qed.

 (* Closed form of a nonempty integer range sum:
  * sum_n i j = i(j-i+1) + (j-i)(j-i+1)/2. *)
 lemma sumn_ij (i j:int) : i <= j =>
   sum_n i j = i*((j - i)+1) + (j-i)*(j-i+1)/%2.
 proof -strict.
   intros Hle; rewrite sumn_ij_aux //;smt.
 qed.

import FSet.Interval.

 (* A range sum starting at a nonnegative [i] is nonnegative; when i > j
  * the range is empty and the sum is 0. *)
 lemma sumn_pos (i j:int) : 0 <= i => 0 <= sum_n i j.
 proof -strict.
   case (i <= j) => Hle Hp.
     rewrite sumn_ij => //;smt all.
   by rewrite /sum_n sum_ij_gt; first smt.
 qed.

 (* Monotonicity in the upper bound: extending [i, j] to [i, k] only adds
  * the terms j+1 .. k, which are positive because 0 <= j. *)
 lemma sumn_le (i j k:int) : i <= j =>  0 <= j => j <= k =>
   sum_n i j <= sum_n i k.    
 proof -strict.
   (* Split [i, k] into the disjoint ranges [i, j] and [j+1, k]. *)
   intros Hij H0j Hjk;rewrite /sum_n /sum_ij.
   cut -> :interval i k = FSet.union (interval i j) (interval (j+1) k).
     by apply FSet.set_ext => x;rewrite FSet.mem_union ?mem_interval;smt.
   rewrite sum_disj.
     by rewrite FSet.disjoint_spec => x;rewrite ?mem_interval;smt.
   smt.
 qed.
   
end Miplus.
  
(* For real *)
require Real.

(* Real instance: real addition with unit 0%r.  The scalar multiplication
 * is defined as multiplication by the real embedding n%r of the integer
 * n (see NatMul_mul).  All axioms are discharged by smt. *)
clone Comoid as Mrplus with
   type Base.t <- real,
   op Base.(+) <- Real.(+),
   op Base.Z   <- 0%r,
   op NatMul.( * ) = fun n, (Real.( * ) (n%r))
  proof Base.* by smt, NatMul.* by smt.
import Int.  
import Real.

(* For nonnegative n, the scalar multiplication of the Mrplus clone is
 * ordinary real multiplication by n%r.  Proved by induction on n. *)
lemma NatMul_mul : forall (n:int) (r:real), 0 <= n => 
    Mrplus.NatMul.( * ) n r = n%r * r.
proof.    
  move => n r;elim /Int.Induction.induction n;smt.
qed.

require import FSet.
require import Distr.

(* [disj_or X]: the predicates in the finite set X are pairwise disjoint,
 * i.e. no point satisfies two distinct members of X. *)
pred disj_or (X:('a->bool) set) =
  forall x1 x2, x1 <> x2 => mem x1 X => mem x2 X =>
  forall a, x1 a => !(x2 a).

(* The boolean sum (big disjunction) of [f] over [s] holds iff some member
 * of [s] satisfies [f]. *)
lemma or_exists (f:'a->bool) s:
  (Mbor.sum f s) <=> (exists x, (mem x s /\ f x)).
proof -strict.
  (* Right-to-left: pull the witness out of the sum with sum_rm. *)
  split;last by intros=> [x [x_in_s f_x]]; rewrite (Mbor.sum_rm _ _ x) // f_x.
  (* Left-to-right, by contraposition: assuming no member of s satisfies f,
     show the sum is false by induction over subsets s' of s. *)
  intros=> sum_true; pose p := fun x, mem x s /\ f x; change (exists x, p x);
    apply ex_for; delta p=> {p}; generalize sum_true; apply absurd=> /= h.
  cut := FSet.leq_refl s; pose {1 3} s' := s;elim/set_ind s'.
    by rewrite Mbor.sum_empty.
  intros=> x s' nmem IH leq_adds'_s.
  cut leq_s'_s : s' <= s.
    by apply (FSet.leq_tran (add x s'))=> //; apply leq_add.
  rewrite Mbor.sum_add // -nor IH // /=; cut := h x; rewrite -nand.
  by case (mem x s)=> //=; cut := leq_adds'_s x; rewrite mem_add //= => ->.
qed.

(* Pointwise disjunction of a finite set of predicates: [cpOrs X x] holds
 * iff some P in X satisfies P x (by or_exists). *)
pred cpOrs (X:(('a -> bool)) set) (x:'a) = Mbor.sum (fun (P:('a -> bool)), P x) X.

(* The disjunction of no predicates is the always-false predicate. *)
lemma cpOrs0 : cpOrs (empty <:('a -> bool)>) = pred0.
proof -strict.
  by apply fun_ext => y;rewrite /cpOrs Mbor.sum_empty.
qed.

(* Adding a predicate [p] to the family takes the union of [p] with the
 * disjunction of the rest.  Membership of [p] in [s] is not required. *)
lemma cpOrs_add s (p:('a -> bool)) : 
  cpOrs (FSet.add p s) = (predU p (cpOrs s)).
proof -strict.
  (* Prove pointwise, unfolding both sides to existentials via or_exists. *)
  apply fun_ext => y.
  rewrite /cpOrs /predU /= !or_exists eq_iff;split=> /=.
    intros [x ];rewrite FSet.mem_add => [ [ ] H H0];first by right;exists x.
    by left;rewrite -H.
  intros [H | [x [H1 H2]]];first by exists p;rewrite FSet.mem_add.
  by exists x; rewrite FSet.mem_add;progress;left.
qed.

(* Finite additivity: for a finite family X of pairwise disjoint predicates
 * (disj_or X), the measure of their disjunction is the sum of their
 * measures.  Proved by induction on X. *)
lemma mu_ors d (X:('a->bool) set):
  disj_or X =>
  mu d (cpOrs X) = Mrplus.sum (fun P, mu d P) X.
proof strict.
  elim/set_ind X.
    by intros disj;rewrite Mrplus.sum_empty cpOrs0 mu_false.
  (* Inductive step: the new predicate f is disjoint from the union of the
     remaining ones, so mu_disjoint splits the measure. *)
  intros f X f_nin_X IH disj; rewrite Mrplus.sum_add // cpOrs_add mu_disjoint.
    rewrite /predI /pred0=> x' /=;
      rewrite -not_def=> [f_x']; generalize f_nin_X disj .
    rewrite /cpOrs or_exists => Hnm Hd [p [Hp]] /=.
    by apply (Hd f p) => //; smt.
  rewrite IH => //.
  by intros x y H1 H2 H3;apply (disj x y) => //;smt.
qed.

require ISet.
import Real.

(* For a distribution with finite support, the probability of [p] is the
 * sum over the support of the point masses [mu_x d x] weighted by
 * [charfun p x] (presumably 1%r when p x holds and 0%r otherwise; the
 * final case split on p x is consistent with that). *)
lemma mean (d:'a distr) (p:'a -> bool):
  ISet.Finite.finite (ISet.create (support d)) =>
  mu d p = 
    Mrplus.sum (fun x, (mu_x d x)*(charfun p x))
        (ISet.Finite.toFSet (ISet.create (support d))).
proof strict.
  intros=> fin_supp_d.
  pose sup := ISet.Finite.toFSet (ISet.create (support d)).
  (* For each x in the support, the predicate (fun y, p x /\ x = y) holds
     only at y = x, and only if p x.  These predicates are pairwise
     disjoint, and their disjunction covers p on the support. *)
  pose is  := img (fun x y, p x /\ x = y) sup.
  rewrite mu_support (mu_eq d _ (cpOrs is)).
    intros y;rewrite /predI /is /cpOrs /= or_exists eq_iff;split.
      intros [H1 H2];exists ((fun x0 y0, p x0 /\ x0 = y0) y);split => //.
      by apply mem_img;rewrite /sup ISet.Finite.mem_toFSet // ISet.mem_create.
    intros [p' []].
    by rewrite img_def; progress => //;smt. 
  (* Finite additivity over the family is, which is pairwise disjoint. *)
  rewrite mu_ors.
    rewrite /is => x1 x2 Hx; rewrite !img_def => [y1 [Heq1 Hm1]] [y2 [Heq2 Hm2]].
    subst; generalize Hx => /= Hx a [Hpa1 Heq1];rewrite -not_def => [Hpa2 Heq2].
    by subst; generalize Hx;rewrite not_def.
  (* Match the two sums term by term, by induction on the support set. *)
  rewrite /is => {is};elim/set_ind sup.
    by rewrite img_empty !Mrplus.sum_empty.
  intros x s Hnm Hrec;rewrite FSet.img_add Mrplus.sum_add0.
    rewrite img_def /= => [x0 [H1 H2]].
    by rewrite (mu_eq d _ pred0) //;smt.
  rewrite Mrplus.sum_add // -Hrec /=; congr => //.
  rewrite /charfun /mu_x;case (p x) => //= Hp.
    by apply mu_eq; rewrite pred1E.
  by rewrite -(mu_false d);apply mu_eq.
qed.
(* End of Monoid.ec (web page navigation residue removed: back to top) *)