(* Source repository: https://github.com/EasyCrypt/easycrypt
   File: Matrix.eca
   Tip revision: 3f4a0bd5596888cd8d28b97687d477942187aa5f, authored by
   Pierre-Yves Strub on 11 June 2022, 06:10:21 UTC
   Tip commit message: In loop fusion/fission, add more constraints on
   the epilog *)
(* -------------------------------------------------------------------- *)
require import AllCore List Distr Perms.
require (*--*) Subtype Monoid Ring Subtype Bigalg StdOrder StdBigop.

import StdOrder.IntOrder StdBigop.Bigreal.

(* -------------------------------------------------------------------- *)
(* Coefficient structure: an abstract clone of Ring.IDomain, imported
   under the name ZR. *)
clone import Ring.IDomain as ZR.

(* -------------------------------------------------------------------- *)
(* Big operators of Bigalg.BigComRing, instantiated with CR <- ZR. *)
clone import Bigalg.BigComRing as Big with theory CR <- ZR.

(* R names the type t in scope after the clones above; it is used as the
   type of vector and matrix coefficients throughout this file. *)
type R = t.

import BAdd.

(* -------------------------------------------------------------------- *)
(* Common dimension of all vectors and matrices of this file.  The only
   assumption is 0 <= size, available as ge0_size. *)
op size : { int | 0 <= size } as ge0_size.

hint exact : ge0_size.

(* -------------------------------------------------------------------- *)
(* Abstract type of vectors.  It is related to functions int -> R by the
   operators tofunv / offunv axiomatized in theory Vector. *)
type vector.

theory Vector.
(* Conversions between vectors and index-to-coefficient functions. *)
op tofunv : vector -> (int -> R).
op offunv : (int -> R) -> vector.

(* f is a prevector when it is zeror at every index outside [0, size). *)
op prevector (f : int -> R) =
  forall i, !(0 <= i < size) => f i = zeror.

(* vclamp v agrees with v on [0, size) and is zeror everywhere else. *)
op vclamp (v : int -> R) : int -> R =
  fun i => if 0 <= i < size then v i else zeror.

(* Axiomatization of the representation:
   - tofunv v is always a prevector;
   - offunv (tofunv v) = v;
   - tofunv (offunv f) is f clamped to [0, size). *)
axiom tofunv_prevector (v : vector) : prevector (tofunv v).
axiom tofunvK : cancel tofunv offunv.
axiom offunvK : forall v, tofunv (offunv v) = vclamp v.

(* Indexing notation v.[i]: the value of tofunv v at index i. *)
op "_.[_]" (v : vector) (i : int) = tofunv v i.

(* Reading an in-range index of offunv v returns v i. *)
lemma offunvE (v : int -> R) (i : int) : 0 <= i < size =>
  (offunv v).[i] = v i.
proof. by rewrite /("_.[_]") offunvK /vclamp => ->. qed.

(* Reading an out-of-range index of any vector returns zeror. *)
lemma getv_out (v : vector) (i : int) :
  !(0 <= i < size) => v.[i] = zeror.
proof. by move/(tofunv_prevector v). qed.

(* Extensionality: two vectors are equal iff they agree on every index
   in [0, size). *)
lemma eq_vectorP (v1 v2 : vector) :
  (v1 = v2) <=> (forall i, 0 <= i < size => v1.[i] = v2.[i]).
proof. split=> [->//|eq_vi].
(* First show that the clamped underlying functions coincide ... *)
have: vclamp (tofunv v1) = vclamp (tofunv v2).
+ by apply/fun_ext=> j @/vclamp; case: (0 <= j < size); try apply: eq_vi.
(* ... then transport this equality back to vectors through offunv. *)
by rewrite -!offunvK !tofunvK => /(congr1 offunv); rewrite !tofunvK.
qed.

(* Elimination principle: to prove P for every vector, it suffices to
   prove it for every offunv f where f is a prevector. *)
lemma vectorW (P : vector -> bool) :
  (forall f, prevector f => P (offunv f)) =>
    forall v, P v.
proof. by move=> ih v; rewrite -(tofunvK v) &(ih) tofunv_prevector. qed.

(* vectc c: the vector built from the constant function equal to c. *)
op vectc c = offunv (fun _ => c).

(* The zero vector. *)
abbrev zerov = vectc zeror.

(* In-range coefficients of vectc c are equal to c. *)
lemma offunCE c i : 0 <= i < size => (vectc c).[i] = c.
proof. by move=> ?; rewrite offunvE. qed.

(* Every coefficient of zerov is zeror; no range condition is needed. *)
lemma offun0E i : zerov.[i] = zeror.
proof. by rewrite /("_.[_]") offunvK /vclamp if_same. qed.

hint simplify offun0E, offunCE.

(* Coefficient-wise addition and opposite of vectors. *)
op (+) (v1 v2 : vector) = offunv (fun i => v1.[i] + v2.[i]).
op [-] (v     : vector) = offunv (fun i => -v.[i]).

(* Addition is computed coefficient-wise at every index, including the
   out-of-range ones (handled by the second branch of the proof). *)
lemma offunD (v1 v2 : vector) i :
  (v1 + v2).[i] = v1.[i] + v2.[i].
proof.
case: (0 <= i < size) => [|/getv_out vi_eq0].
+ by move=> rg_i; rewrite offunvE.
+ by rewrite !vi_eq0 add0r.
qed.

hint simplify offunD.

(* Opposite is computed coefficient-wise at every index. *)
lemma offunN (v : vector) i : (-v).[i] = - v.[i].
proof.
case: (0 <= i < size) => [|/getv_out vi_eq0]; last by rewrite !vi_eq0 oppr0.
by move=> rg_i; rewrite offunvE.
qed.

hint simplify offunN.

(* Vectors instantiate Ring.ZModule with zerov, + and - defined above.
   The clause proof * makes every axiom of the cloned theory a proof
   obligation; they are discharged by the realize blocks below. *)
clone import Ring.ZModule with type t <- vector,
  op zeror <- Vector.zerov,
  op (+)   <- Vector.( + ),
  op [-]   <- Vector.([-])
proof *.

(* Associativity of vector addition, from addrA on coefficients. *)
realize addrA. proof.
move=> /vectorW v1 h1 /vectorW v2 h2 /vectorW v3 h3.
by apply/eq_vectorP=> i hi; do! rewrite offunvE //=; rewrite addrA.
qed.

(* Commutativity of vector addition, from addrC on coefficients. *)
realize addrC. proof.
move=> /vectorW v1 h1 /vectorW v2 h2; apply/eq_vectorP.
by move=> i hi /=; rewrite addrC.
qed.

(* zerov is a left neutral element for vector addition. *)
realize add0r. proof.
by move=> /vectorW f hf; apply/eq_vectorP=> i hi /=; rewrite add0r.
qed.

(* -v is a left inverse of v for vector addition. *)
realize addNr. proof.
by move=> /vectorW f hf; apply/eq_vectorP=> i hi /=; rewrite addNr.
qed.

(* Subtraction is computed coefficient-wise. *)
lemma offunB (v1 v2 : vector) i : (v1 - v2).[i] = v1.[i] - v2.[i].
proof. done. qed.

(* Dot product: the big sum, for i ranging from 0 to size, of
   v1.[i] * v2.[i]. *)
op dotp (v1 v2 : vector) =
  bigi predT (fun i => v1.[i] * v2.[i]) 0 size.

(* The dot product is commutative. *)
lemma dotpC : commutative dotp.
proof.
by move=> v1 v2; apply: eq_bigr=> i _ /=; rewrite mulrC.
qed.

(* The dot product distributes over addition on the right ... *)
lemma dotpDr v1 v2 v3 : dotp v1 (v2 + v3) = dotp v1 v2 + dotp v1 v3.
proof.
by rewrite /dotp -big_split &(eq_bigr) => i _ /=; rewrite mulrDr.
qed.

(* ... and on the left (obtained from dotpDr by commutativity). *)
lemma dotpDl v1 v2 v3 : dotp (v1 + v2) v3 = dotp v1 v3 + dotp v2 v3.
proof. by rewrite dotpC dotpDr !(dotpC v3). qed.
end Vector.

export Vector.

(* -------------------------------------------------------------------- *)
theory Matrix.
(* Abstract type of matrices, with both indices ranging over
   [0, size). *)
type matrix.

(* Conversions between matrices and two-index coefficient functions. *)
op tofunm : matrix -> (int -> int -> R).
op offunm : (int -> int -> R) -> matrix.

(* mrange i j: both indices lie in [0, size). *)
abbrev mrange (i j : int) =
  0 <= i < size /\ 0 <= j < size.

(* Projections of mrange on its first and second index. *)
lemma nosmt mrangeL i j : mrange i j => 0 <= i < size.
proof. by case. qed.

lemma nosmt mrangeR i j : mrange i j => 0 <= j < size.
proof. by case. qed.

(* mrange is symmetric in its two indices. *)
lemma nosmt mrangeC i j : mrange i j = mrange j i.
proof. by apply: andbC. qed.

(* f is a prematrix when it is zeror at every index pair that is not in
   mrange. *)
op prematrix (f : int -> int -> R) =
  forall i j, !mrange i j => f i j = zeror.

(* mclamp m agrees with m on mrange and is zeror everywhere else. *)
op mclamp (m : int -> int -> R) : int -> int -> R =
  fun i j : int => if mrange i j  then m i j else zeror.

(* Axiomatization of the representation:
   - tofunm m is always a prematrix;
   - offunm (tofunm m) = m;
   - tofunm (offunm f) is f clamped to mrange. *)
axiom tofunm_prematrix (m : matrix) : prematrix (tofunm m).
axiom tofunmK : cancel tofunm offunm.
axiom offunmK : forall m, tofunm (offunm m) = mclamp m.

(* Indexing notation m.[i, j]: the index is taken as a pair whose first
   component selects the first argument of tofunm m. *)
op "_.[_]" (m : matrix) (ij : int * int) = tofunm m ij.`1 ij.`2.

(* Reading an in-range entry of offunm m returns m i j. *)
lemma offunmE (m : int -> int -> R) (i j : int) :
  mrange i j => (offunm m).[i, j] = m i j.
proof. by rewrite /"_.[_]" offunmK /mclamp => -[-> ->]. qed.

(* Reading an out-of-range entry of any matrix returns zeror. *)
lemma getm_out (m : matrix) (i j : int) :
  !mrange i j => m.[i, j] = zeror.
proof. by move/(tofunm_prematrix m). qed.

(* Variant of getm_out when only the first index is out of range. *)
lemma getm_outL (m : matrix) (i j : int) :
  !(0 <= i < size) => m.[i, j] = zeror.
proof. by move=> Nrg; rewrite getm_out // negb_and Nrg. qed.

(* Variant of getm_out when only the second index is out of range. *)
lemma getm_outR (m : matrix) (i j : int) :
  !(0 <= j < size) => m.[i, j] = zeror.
proof. by move=> Nrg; rewrite getm_out // negb_and Nrg. qed.

(* Extensionality: two matrices are equal iff they agree on every index
   pair in mrange. *)
lemma eq_matrixP (m1 m2 : matrix) :
  (m1 = m2) <=> (forall i j, mrange i j => m1.[i, j] = m2.[i, j]).
proof. split=> [->//|eq_mi].
(* First show that the clamped underlying functions coincide ... *)
have: mclamp (tofunm m1) = mclamp (tofunm m2).
+ by apply/fun_ext2 => i j @/mclamp; case _: (mrange _ _) => //= /eq_mi.
(* ... then transport this equality back to matrices through offunm. *)
by rewrite -!offunmK !tofunmK => /(congr1 offunm); rewrite !tofunmK.
qed.

(* Elimination principle: to prove P for every matrix, it suffices to
   prove it for every offunm f where f is a prematrix. *)
lemma matrixW (P : matrix -> bool) :
  (forall f, prematrix f => P (offunm f)) =>
    forall v, P v.
proof. by move=> ih m; rewrite -(tofunmK m) &(ih) tofunm_prematrix. qed.

(* matrixc c: the matrix built from the constant function equal to c. *)
op matrixc (c : R) = offunm (fun _ _ => c).

(* diagmx v: the matrix with v.[i] at position (i, i) and zeror at
   positions (i, j) with i <> j. *)
op diagmx (v : vector) =
  offunm (fun i j => if i = j then v.[i] else zeror).

(* diagc c: the diagonal matrix built from the constant vector vectc c. *)
abbrev diagc (c : R) = diagmx (vectc c).

(* The zero matrix and the identity matrix. *)
abbrev zerom = matrixc zeror.
abbrev onem  = diagc   oner .

(* In-range entries of matrixc c are equal to c. *)
lemma offunCE c i j : mrange i j => (matrixc c).[i, j] = c.
proof. by move=> -[??]; rewrite offunmE. qed.

(* Entries of diagmx v, valid for all indices: when (i, j) is out of
   range and i = j, v.[i] is itself zeror by getv_out. *)
lemma diagmxE v i j : (diagmx v).[i, j] = if i = j then v.[i] else zeror.
proof.
case: (mrange i j) => [/offunmE ->|rgN_ij] //=.
rewrite getm_out //; case: (i = j) => // ->>.
by rewrite getv_out //; apply: contra rgN_ij.
qed.

(* Every entry of zerom is zeror; no range condition is needed. *)
lemma offun0E i j : zerom.[i, j] = zeror.
proof. by rewrite /matrixc /"_.[_]" offunmK /mclamp if_same. qed.

(* In-range entries of the identity matrix. *)
lemma offun1E i j : mrange i j => onem.[i, j] =
  if i = j then oner else zeror.
proof.
by move=> rg; rewrite diagmxE offunvE // (mrangeL _ _ rg).
qed.

(* Off-diagonal entries of the identity matrix are zeror, whether or
   not the indices are in range. *)
lemma offun1_neqE i j : i <> j => onem.[i, j] = zeror.
proof.
move=> ne_ij; case: (mrange i j) => [rg|/getm_out ->//].
by rewrite offun1E // ne_ij.
qed.

hint simplify offun0E, offunCE.

(* Entry-wise addition and opposite of matrices. *)
op (+) (m1 m2 : matrix) = offunm (fun i j => m1.[i, j] + m2.[i, j]).
op [-] (m     : matrix) = offunm (fun i j => -m.[i, j]).

(* Addition is computed entry-wise at every index pair, including the
   out-of-range ones (handled by the second branch of the proof). *)
lemma offunD (m1 m2 : matrix) i j :
  (m1 + m2).[i, j] = m1.[i, j] + m2.[i, j].
proof.
case: (mrange i j) => [|/getm_out mi_eq0].
+ by move=> rg_i; rewrite offunmE.
+ by rewrite !mi_eq0 addr0.
qed.

hint simplify offunD.

(* Opposite is computed entry-wise at every index pair. *)
lemma offunN (m : matrix) i j : (-m).[i, j] = - m.[i, j].
proof.
case: (mrange i j) => [|/getm_out mi_eq0].
+ by move=> rg_i; rewrite offunmE.
+ by rewrite !mi_eq0 oppr0.
qed.

hint simplify offunN.

(* Matrices instantiate Ring.ZModule with zerom, + and - defined above.
   The clause proof * makes every axiom of the cloned theory a proof
   obligation; they are discharged by the realize blocks below. *)
clone import Ring.ZModule with type t <- matrix,
  op zeror <- Matrix.zerom,
  op (+)   <- Matrix.( + ),
  op [-]   <- Matrix.([-])
proof *.

(* Associativity of matrix addition, from addrA on entries. *)
realize addrA. proof.
move=> /matrixW v1 h1 /matrixW v2 h2 /matrixW v3 h3.
by apply/eq_matrixP=> i j hij; do! rewrite offunmE //=; rewrite addrA.
qed.

(* Commutativity of matrix addition, from addrC on entries. *)
realize addrC. proof.
move=> /matrixW v1 h1 /matrixW v2 h2; apply/eq_matrixP.
by move=> i j hij /=; rewrite addrC.
qed.

(* zerom is a left neutral element for matrix addition. *)
realize add0r. proof.
by move=> /matrixW f hf; apply/eq_matrixP=> i j hij /=; rewrite add0r.
qed.

(* -m is a left inverse of m for matrix addition. *)
realize addNr. proof.
by move=> /matrixW f hf; apply/eq_matrixP=> i j hij /=; rewrite addNr.
qed.

(* Subtraction is computed entry-wise. *)
lemma offunB (m1 m2 : matrix) i j : (m1 - m2).[i, j] = m1.[i, j] - m2.[i, j].
proof. done. qed.

(* Trace: the big sum, for i ranging from 0 to size, of the diagonal
   entries m.[i, i]. *)
op trace (m : matrix) =
  bigi predT (fun i => m.[i, i]) 0 size.

(* Transpose: entry (i, j) of trmx m is entry (j, i) of m. *)
op trmx (m : matrix) =
  offunm (fun i j => m.[j, i]).

(* Entries of the transpose, valid for all indices. *)
lemma trmxE (m : matrix) i j : (trmx m).[i, j] = m.[j, i].
proof.
case: (mrange i j) => [rg_ij|]; first by rewrite offunmE.
by move=> h; rewrite !getm_out // mrangeC.
qed.

(* Transposition is an involution. *)
lemma trmxK (m : matrix) : trmx (trmx m) = m.
proof.
by apply/eq_matrixP => i j rg_ij; do! rewrite offunmE ?(mrangeC j) //=.
qed.

(* The identity matrix is its own transpose. *)
lemma trmx1 : trmx onem = onem.
proof.
apply/eq_matrixP=> i j rgi; rewrite trmxE !offun1E 1:mrangeC //.
by rewrite (eq_sym j i).
qed.

(* Transposition commutes with addition. *)
lemma trmxD (m1 m2 : matrix) : trmx (m1 + m2) = trmx m1 + trmx m2.
proof. by apply/eq_matrixP=> i j rg; rewrite /= !trmxE. qed.

(* Transposition preserves the trace. *)
lemma trace_trmx m : trace (trmx m) = trace m.
proof. by apply: eq_bigr=> /= i _; rewrite trmxE. qed.

(* Matrix product: entry (i, j) is the big sum, for k ranging from 0 to
   size, of m1.[i, k] * m2.[k, j]. *)
op ( * ) (m1 m2 : matrix) =
  offunm (fun i j => bigi predT (fun k => m1.[i, k] * m2.[k, j]) 0 size).

(* Entries of a product, valid for all indices: when (i, j) is out of
   range, every term of the sum has a zeror factor. *)
lemma offunM (m1 m2 : matrix) i j :
  (m1 * m2).[i, j] = bigi predT (fun k => m1.[i, k] * m2.[k, j]) 0 size.
proof.
case: (mrange i j) => [rg_ij|Nrg_ij]; first by rewrite offunmE.
rewrite getm_out // big1 // => /= k _; move: Nrg_ij.
by rewrite negb_and => -[/getm_outL|/getm_outR] ->; rewrite ?(mulr0, mul0r).
qed.

hint simplify offunM.

(* onem is a right identity for the matrix product. *)
lemma mulmx1 : right_id onem Matrix.( * ).
proof.
move=> m; apply/eq_matrixP=> i j [rg_i rg_j] /=.
(* Isolate the term of index j in the sum; the remaining terms vanish
   because they involve an off-diagonal entry of onem. *)
rewrite (bigD1 _ _ j) ?(mem_range, range_uniq) //=.
rewrite offun1E //= mulr1 big1 ?addr0 //= => k @/predC1 ne_kj.
by rewrite offun1_neqE // mulr0.
qed.

(* onem is a left identity for the matrix product. *)
lemma mul1mx : left_id onem Matrix.( * ).
proof.
move=> m; apply/eq_matrixP=> i j [rg_i rg_j] /=.
(* Same argument as in mulmx1, isolating the term of index i. *)
rewrite (bigD1 _ _ i) ?(mem_range, range_uniq) //=.
rewrite offun1E //= mul1r big1 ?addr0 //= => k @/predC1 ne_kj.
by rewrite offun1_neqE 1:eq_sym // mul0r.
qed.

(* The matrix product distributes over addition on the left ... *)
lemma mulmxDl (m1 m2 m : matrix) :
  (m1 + m2) * m = (m1 * m) + (m2 * m).
proof.
apply/eq_matrixP=> i j rg_ij /=; rewrite -big_split /=.
by apply/eq_bigr => k _ /=; rewrite mulrDl.
qed.

(* ... and on the right. *)
lemma mulmxDr (m1 m2 m : matrix) :
  m * (m1 + m2) = (m * m1) + (m * m2).
proof.
apply/eq_matrixP=> i j rg_ij /=; rewrite -big_split /=.
by apply/eq_bigr => k _ /=; rewrite mulrDr.
qed.

(* The matrix product is associative. *)
lemma mulmxA : associative Matrix.( * ).
proof.
move=> m1 m2 m3; apply/eq_matrixP=> i j rg_ij.
(* E is the double sum of the triple products M k l; each side of the
   goal is shown equal to E. *)
pose M k l := m1.[i, k] * m2.[k, l] * m3.[l, j].
pose E := bigi predT (fun k => bigi predT (M k) 0 size) 0 size.
apply: (@eq_trans _ E); rewrite offunM.
+ apply: eq_bigr=> /= k _; rewrite mulr_sumr &(eq_bigr) /=.
  by move=> l _; rewrite mulrA.
+ rewrite /E exchange_big &(eq_bigr) => /= l _.
  by rewrite mulr_suml &(eq_bigr) => /= k _.
qed.

(* Transposing a product reverses the order of the factors. *)
lemma trmxM (m1 m2 : matrix) : trmx (m1 * m2) = trmx m2 * trmx m1.
proof.
apply/eq_matrixP=> i j rg; rewrite trmxE /=.
by apply: eq_bigr => /= k _; rewrite !trmxE mulrC.
qed.

(* Matrix-vector product m *^ v: coefficient i is the big sum over j of
   m.[i, j] * v.[j]. *)
op ( *^ ) (m : matrix) (v : vector) : vector =
  offunv (fun i => bigi predT (fun j => m.[i, j] * v.[j]) 0 size).

(* Vector-matrix product v ^* m: coefficient j is the big sum over i of
   v.[i] * m.[i, j]. *)
op ( ^* ) (v : vector) (m : matrix) : vector =
  offunv (fun j => bigi predT (fun i => v.[i] * m.[i, j]) 0 size).

(* Multiplying by the transpose on the left is the same as multiplying
   by the matrix on the right. *)
lemma mulmxTv (m : matrix) (v : vector) : (trmx m) *^ v = v ^* m.
proof.
apply/eq_vectorP=> i rgi; rewrite !offunvE //= &(eq_bigr) /=.
by move=> j _; rewrite mulrC trmxE.
qed.

(* Any matrix maps the zero vector to the zero vector. *)
lemma mulmxv0 (m : matrix) : m *^ zerov = zerov.
proof.
apply/eq_vectorP => i rgi; rewrite !offunvE //=.
by rewrite big1 // => /= j _; rewrite mulr0.
qed.

(* The identity matrix acts as the identity on vectors. *)
lemma mulmx1v (v : vector) : onem *^ v = v.
proof.
apply/eq_vectorP => i rgi; rewrite !offunvE //=.
(* Isolate the term of index i; the other terms involve an
   off-diagonal entry of onem and vanish. *)
rewrite (bigD1 _ _ i) ?(mem_range, range_uniq) //=.
rewrite offun1E //= mul1r big1 ?addr0 //= => j @/predC1.
by move=> ne_ji; rewrite offun1_neqE // ?mul0r // eq_sym.
qed.

(* The matrix-vector product distributes over matrix addition ... *)
lemma mulmxvDl (m1 m2 : matrix) (v : vector) :
  (m1 + m2) *^ v = m1 *^ v + m2 *^ v.
proof.
apply/eq_vectorP=> i rgi /=; rewrite !offunvE //=.
by rewrite -big_split &(eq_bigr) => /= j _; rewrite mulrDl.
qed.

(* ... and over vector addition. *)
lemma mulmxvDr (m : matrix) (v1 v2 : vector) :
  m *^ (v1 + v2) = m *^ v1 + m *^ v2.
proof.
apply/eq_vectorP=> i rgi /=; rewrite !offunvE //=.
by rewrite -big_split &(eq_bigr) => /= j _; rewrite mulrDr.
qed.

(* Compatibility of the matrix product with the matrix-vector product. *)
lemma mulmxvA (m1 m2 : matrix) (v : vector) :
  (m1 * m2) *^ v = m1 *^ (m2 *^ v).
proof.
apply/eq_vectorP=> i rgi; rewrite !offunvE //=.
(* E is the double sum of the triple products F j k; each side of the
   goal is shown equal to E. *)
pose F j k := m1.[i,k] * m2.[k,j] * v.[j].
pose E := bigi predT (fun j => bigi predT (F j) 0 size) 0 size.
apply: (@eq_trans _ E).
+ by rewrite /E /F &(eq_bigr) => /= j _; rewrite mulr_suml.
rewrite /E /F exchange_big !big_seq &(eq_bigr) /= => j.
move/mem_range => rgj; rewrite offunvE //= mulr_sumr. 
by apply: eq_bigr => /= k _; rewrite mulrA.
qed.

(* The lemmas on the vector-matrix product below are derived from their
   matrix-vector counterparts through mulmxTv. *)

(* Multiplying by the transpose on the right is the same as multiplying
   by the matrix on the left. *)
lemma mulvmxT (v : vector) (m : matrix) : v ^* (trmx m) = m *^ v.
proof. by rewrite -{2}(trmxK m) eq_sym &(mulmxTv). qed.

(* The zero vector times any matrix is the zero vector. *)
lemma mulv0mx (m : matrix) : zerov ^* m = zerov.
proof. by rewrite -mulmxTv mulmxv0. qed.

(* The identity matrix acts as the identity on the right. *)
lemma mulvmx1 (v : vector) : v ^* onem = v.
proof. by rewrite -mulmxTv trmx1 mulmx1v. qed.

(* The vector-matrix product distributes over matrix addition ... *)
lemma mulvmxDr (v : vector) (m1 m2 : matrix) :
  v ^* (m1 + m2) = v ^* m1 + v ^* m2.
proof. by rewrite -mulmxTv trmxD mulmxvDl !mulmxTv. qed.

(* ... and over vector addition. *)
lemma mulvmxDl (v1 v2 : vector) (m : matrix) :
  (v1 + v2) ^* m = v1 ^* m + v2 ^* m.
proof. by rewrite -mulmxTv mulmxvDr !mulmxTv. qed.

(* Compatibility of the matrix product with the vector-matrix product. *)
lemma mulvmxA (v : vector) (m1 m2 : matrix) :
  (v ^* m1) ^* m2 = v ^* (m1 * m2).
proof. by rewrite -!mulmxTv trmxM mulmxvA. qed.

(* Coefficients of a matrix-vector product, valid for all indices. *)
lemma mulmxvE (m : matrix) (v : vector) i :
  (m *^ v).[i] = bigi predT (fun j => m.[i, j] * v.[j]) 0 size.
proof.
case: (0 <= i < size) => [rg_i|rgN_i]; first by rewrite offunvE.
by rewrite getv_out // big1 //= => j _; rewrite getm_outL // mul0r.
qed.

(* Coefficients of a vector-matrix product, valid for all indices. *)
lemma mulvmxE (m : matrix) (v : vector) j :
  (v ^* m).[j] = bigi predT (fun i => v.[i] * m.[i, j]) 0 size.
proof.
by rewrite -mulmxTv mulmxvE; apply: eq_bigr => /= i _; rewrite trmxE mulrC.
qed.

(* colmx v: the matrix whose entry (i, j) is v.[i], for every column j.
   rowmx v: the matrix whose entry (i, j) is v.[j], for every row i. *)
op colmx (v : vector) : matrix = offunm (fun i _ => v.[i]).
op rowmx (v : vector) : matrix = offunm (fun _ j => v.[j]).

(* Transposition exchanges colmx and rowmx. *)
lemma colmxT (v : vector) : trmx (colmx v) = rowmx v.
proof.
apply: eq_matrixP => i j rg_ij; rewrite trmxE.
by rewrite !offunmE ?(mrangeC j i).
qed.

lemma rowmxT (v : vector) : trmx (rowmx v) = colmx v.
proof. by rewrite -colmxT trmxK. qed.

(* Entries of colmx v: only the column index needs to be in range, since
   v.[i] is zeror when the row index i is out of range. *)
lemma colmxE (v : vector) i j : 0 <= j < size => (colmx v).[i, j] = v.[i].
proof.
move=> rg_j; case: (0 <= i < size) => [rg_i|rgN_i]; last first.
+ by rewrite getv_out // getm_outL.
+ by rewrite offunmE.
qed.

(* Entries of rowmx v: only the row index needs to be in range. *)
lemma rowmxE (v : vector) i j : 0 <= i < size => (rowmx v).[i, j] = v.[j].
proof. by move=> rg_i; rewrite -colmxT trmxE colmxE. qed.

(* colmx turns a matrix-vector product into a matrix product. *)
lemma colmx_mulmxv (m : matrix) (v : vector) : colmx (m *^ v) = m * colmx v.
proof.
apply/eq_matrixP=> i j [rg_i rg_j]; rewrite colmxE // offunM.
by rewrite mulmxvE; apply: eq_bigr => /= k _; rewrite colmxE.
qed.

(* rowmx turns a vector-matrix product into a matrix product. *)
lemma rowmx_mulvmx (v : vector) (m : matrix) : rowmx (v ^* m) = rowmx v * m.
proof. by rewrite -colmxT -mulmxTv colmx_mulmxv trmxM trmxK colmxT. qed.

(* The product of two diagonal matrices is the diagonal matrix of the
   coefficient-wise product of their diagonals. *)
lemma mulmx_diag (v1 v2 : vector) :
  diagmx v1 * diagmx v2 = diagmx (offunv (fun i => v1.[i] * v2.[i])).
proof.
apply: eq_matrixP=> i j [rg_i rg_j]; rewrite diagmxE /=.
case: (i = j) => [<-|ne_ij]; last first.
(* Off-diagonal entry: every term of the sum has a zeror factor. *)
+ rewrite big1 // => /= k _; rewrite !diagmxE.
  case: (i = k) => [->>|]; last by rewrite mul0r.
  by case: (k = j) => [<<-|]; last by rewrite mulr0.
(* Diagonal entry: only the term of index i contributes to the sum. *)
rewrite offunvE //= (bigD1 _ _ i) ?(mem_range, range_uniq) //=.
rewrite !diagmxE /= big1 -1:addr0 //= => k @/predC1.
by rewrite !diagmxE => -> /=; rewrite mulr0.
qed.

(* The dot product is the trace of the product of the corresponding
   diagonal matrices. *)
lemma dotp_tr v1 v2 : dotp v1 v2 = trace (diagmx v1 * diagmx v2).
proof.
rewrite /trace /dotp !big_seq &(eq_bigr) => /= i /mem_range rg_i.
rewrite (bigD1 _ _ i) ?(mem_range, range_uniq) //=.
rewrite !diagmxE /= big1 -1:addr0 //= => @/predC1 j ne_ji.
by rewrite diagmxE (eq_sym i j) ne_ji mul0r.
qed.

(* A matrix can be moved from one side of a dot product to the other,
   turning a matrix-vector product into a vector-matrix product. *)
lemma dotp_mulmxv m v1 v2 : dotp (m *^ v1) v2 = dotp v1 (v2 ^* m).
proof.
(* Both sides are rewritten as a double sum of the triple products
   F i j, up to an exchange of the two summations. *)
pose F i j := m.[i, j] * v1.[j] * v2.[i].
rewrite /dotp -(eq_bigr _ (fun i => bigi predT (F i) 0 size)) /=.
+ by move=> i _ @/F; rewrite -mulr_suml mulmxvE.
rewrite exchange_big &(eq_bigr) => /= i _ @/F.
rewrite mulvmxE mulr_sumr &(eq_bigr) => /= j _.
by rewrite mulrA (mulrC _ m.[_]) mulrA.
qed.

(* -------------------------------------------------------------------- *)
(* dvector d: distribution over vectors obtained as the image (dmap),
   by the function sending a list xs to offunv (nth witness xs), of the
   djoin of a list of size copies of d. *)
op dvector (d : R distr) =
  dmap (djoin (nseq size d)) (fun xs => offunv (nth witness xs)).

(* Point probability of a vector under dvector d, expressed as a BRM big
   operator over the indices from 0 to size of the point probabilities
   mu1 d v.[i].
   NOTE(review): BRM is presumably the multiplicative big operator on
   reals from StdBigop.Bigreal, which would make this a product --
   confirm against that theory. *)
lemma dvector1E (d : R distr) (v : vector) : mu1 (dvector d) v =
  BRM.bigi predT (fun i => mu1 d v.[i]) 0 size.
proof.
(* g maps a vector to the list of its first size coefficients; the two
   bullets below discharge the side conditions opened by dmap1E_can. *)
pose g v := mkseq (tofunv v) size; rewrite (@dmap1E_can _ _ g).
+ move=> {v} @/g v; apply/eq_vectorP=> /= i rgi.
  by rewrite offunvE -1:nth_mkseq.
+ move=> xs xs_d @/g; rewrite offunvK -(eq_in_mkseq (nth witness xs)).
  * by move=> i rgi @/vclamp; rewrite rgi.
  rewrite (_ : size = size xs) -1:mkseq_nth //.
  by have := supp_djoin_size _ _ xs_d; rewrite size_nseq ler_maxr // eq_sym.
(* Unfold the point probability of djoin and compare the two big
   operators index by index. *)
rewrite djoin1E size_nseq size_mkseq /= (BRM.big_nth witness) predTofV.
pose sz := size _; have szE: size = sz; first rewrite /sz.
+ by rewrite size_zip size_nseq size_mkseq minrE /= ler_maxr.
rewrite szE !BRM.big_seq &(BRM.eq_bigr) => /= i /mem_range rgi @/(\o); congr.
+ rewrite (nth_change_dfl (witness, witness)) // -szE //.
  by rewrite nth_zip ?(size_nseq, size_mkseq) //= nth_nseq // szE.
+ congr; rewrite (nth_change_dfl (witness, witness)) // -szE //.
  by rewrite nth_zip ?(size_nseq, size_mkseq) //= nth_mkseq // szE.
qed.

(* dvector d is uniform whenever d is. *)
lemma dvector_uni d : is_uniform d => is_uniform (dvector d).
proof.
move=> uni_d; apply/dmap_uni_in_inj/djoin_uni.
(* First bullet: the list-to-vector map is injective on the support,
   where all lists have length size. *)
- move=> xs ys /supp_djoin[+ _] /supp_djoin[+ _] /=.
  rewrite !size_nseq !ler_maxr // => ??.
  move/(congr1 tofunv); rewrite !offunvK => eqv.
  apply/(eq_from_nth witness) => [/#|i rg_i].
  by move/fun_ext/(_ i): eqv => /#.
(* Second bullet: every distribution of the joined list is d. *)
- by move=> d'; rewrite mem_nseq => -[_ <-].
qed.

(* dvector d is lossless whenever d is. *)
lemma dvector_ll d : is_lossless d => is_lossless (dvector d).
proof.
move=> ll_d; apply/dmap_ll/djoin_ll.
by move=> d'; rewrite mem_nseq => -[_ <-].
qed.

(* dvector d is full whenever d is. *)
lemma dvector_fu d : is_full d => is_full (dvector d).
proof.
move=> full_d; apply/dmap_fu_in => v.
(* The witness list is the list of coefficients of v. *)
exists (map (fun i => v.[i]) (range 0 size)); split=> /=.
- apply/supp_djoin; rewrite size_nseq ler_maxr // size_map.
  rewrite size_range /= ler_maxr //=; apply/allP.
  case=> d' x /mem_zip []; rewrite mem_nseq => -[_ <<-] /= _.
  by apply: full_d.  
- apply/eq_vectorP=> i rg_i; rewrite offunvE //.
  by rewrite (nth_map witness) -1:nth_range // size_range ler_maxr.
qed.
 
(* dvector d is funiform whenever d is both full and uniform. *)
lemma dvector_funi d : is_full d => is_uniform d =>
  is_funiform (dvector d).
proof.
move=> ??; apply/is_full_funiform => //.
- by apply/dvector_fu.
- by apply/dvector_uni.
qed.

(* -------------------------------------------------------------------- *)
(* dmatrix d: distribution over matrices obtained as the image (dmap) of
   the djoin of size copies of (djoin of size copies of d).  In the
   sampled list of lists xs, the outer index is the column: entry (i, j)
   of the matrix is element i of element j of xs. *)
op dmatrix (d : R distr) =
  dmap
    (djoin (nseq size (djoin (nseq size d))))
    (fun xs : R list list => offunm (fun i j =>
      nth witness (nth witness xs j) i
    )).

(* Point probability of a matrix under dmatrix d, expressed as a double
   BRM big operator over both indices of the point probabilities
   mu1 d m.[i, j].
   NOTE(review): BRM is presumably the multiplicative big operator on
   reals from StdBigop.Bigreal -- confirm against that theory. *)
lemma dmatrix1E (d : R distr) (m : matrix) : mu1 (dmatrix d) m =
  BRM.bigi predT (fun i => BRM.bigi predT (fun j => mu1 d m.[i, j]) 0 size) 0 size.
proof.
(* g maps a matrix to its list of columns; the two bullets below
   discharge the side conditions opened by dmap1E_can. *)
pose g m := mkseq (fun j => mkseq (fun i => tofunm m i j) size) size.
rewrite BRM.exchange_big (@dmap1E_can _ _ g).
+ move=> {m} @/g m; apply/eq_matrixP=> /= i j rg.
  rewrite offunmE //= nth_mkseq 1:(mrangeR _ _ rg) /=.
  by rewrite nth_mkseq 1:(mrangeL _ _ rg).
+ move=> xs xs_d @/g; rewrite offunmK; have szE: size = size xs.
  * by rewrite (supp_djoin_size _ _ xs_d) size_nseq ler_maxr.
  rewrite szE -(eq_in_mkseq (nth witness xs)) -1:mkseq_nth //=.
  move=> i rgi; rewrite -(eq_in_mkseq (nth witness (nth witness xs i))).
  * by move=> j rgj; rewrite /mclamp szE rgi /= rgj.
  pose s := nth _ xs i; rewrite -{1}(mkseq_nth witness s).
  congr => @/s; rewrite -szE; move/supp_djoin: xs_d.
  case=> _ /allP /(_ (djoin (nseq size d), s) _) /=.
  * pose r := zip _ _; pose x := (_, _); have szrE: size r = size xs.
    - by rewrite size_zip size_nseq ler_maxr // szE minrE.
    suff ->: x = nth witness r i by apply: mem_nth; rewrite szrE.
    rewrite (nth_change_dfl (witness, witness)) 1:szrE //.
    rewrite nth_zip;first by rewrite size_nseq ler_maxr.
    by rewrite nth_nseq 1:szE.
  by move/supp_djoin_size => ->; rewrite size_nseq ler_maxr.
(* Outer level: unfold the point probability of the outer djoin. *)
rewrite djoin1E size_nseq size_mkseq /= (BRM.big_nth witness) predTofV.
pose sz := size _; have szE: size = sz; first rewrite /sz.
+ by rewrite size_zip size_nseq size_mkseq minrE /= ler_maxr.
rewrite !szE !BRM.big_seq &(BRM.eq_bigr) => /= i /mem_range rgi @/(\o).
rewrite (nth_change_dfl (witness, witness)) // -szE //.
rewrite nth_zip /= ?(size_nseq, size_mkseq) //=.
rewrite nth_nseq 1:szE // nth_mkseq 1:szE //=.
(* Inner level: same steps for the inner djoin. *)
rewrite djoin1E size_nseq size_mkseq /= (BRM.big_nth witness) predTofV.
pose sz' := size _; have sz'E: size = sz'; first rewrite /sz'.
+ by rewrite size_zip size_nseq size_mkseq minrE //= ler_maxr.
rewrite !sz'E !BRM.big_seq &(BRM.eq_bigr) => /= j /mem_range rgj @/(\o).
rewrite (nth_change_dfl (witness, witness)) // -sz'E //.
rewrite nth_zip ?(size_nseq, size_mkseq) //= nth_nseq 1:sz'E //.
by rewrite nth_mkseq 1:sz'E.
qed.

(* Alternative description of dmatrix d: take the djoin of size copies
   of dvector d and use the sampled vectors as columns, so that entry
   (i, j) is coefficient i of the j-th vector. *)
lemma dmatrix_dvector d : dmatrix d =
  dmap (djoin (nseq size (dvector d))) 
    (fun vs : vector list => offunm (fun i j => (nth witness vs j).[i])).
proof.
apply/eq_distr=> m; rewrite !dmap1E.
(* F: list of lists to matrix (as in dmatrix);
   G: list of vectors to matrix (as in the statement);
   H: list of vectors to list of coefficient lists. *)
pose F (xs : R list list) :=
  offunm (fun i j => nth witness (nth witness xs j) i).
pose G (vs : vector list) :=
  offunm (fun i j => (nth witness vs j).[i]).
pose H (vs : vector list) :=
  map (fun v : vector => map (fun i => v.[i]) (range 0 size)) vs.
(* On lists of length size, G coincides with the composition of F
   and H. *)
have eq: forall xs, size xs = size => G xs = (F \o H) xs.
- move=> xs eq_sz @/F @/G @/H @/(\o) /=; apply/eq_matrixP.
  move=> i j rg_ij; rewrite !offunmE //=.
  rewrite (nth_map witness) /= 1:/# (nth_map witness).
  - by rewrite size_range ler_maxr //#.
  by rewrite nth_range 1:/#.
rewrite -(@mu_eq_support _ (pred1 m \o (F \o H))).
- move=> xs /supp_djoin [+ _]; rewrite size_nseq ler_maxr //.
  by move/eq_sym => /eq @/(\o) <-.
rewrite -(dmap1E _ (F \o H)) -dmap_comp dmap1E; congr.
rewrite djoin_dmap_nseq; do 2! congr; rewrite /dvector.
rewrite dmap_comp /(\o) /= dmap_id_eq_in // => xs.
move=> /supp_djoin [+ _]; rewrite size_nseq ler_maxr //= => sz_xs.
apply/(eq_from_nth witness).
- by rewrite size_map size_range ler_maxr.
move=> i; rewrite size_map size_range ler_maxr //= => rg_i.
rewrite (nth_map witness) 1:!(size_range, ler_maxr) //.
by rewrite nth_range //= offunvE.
qed.

(* dmatrix d is uniform whenever d is. *)
lemma dmatrix_uni d : is_uniform d => is_uniform (dmatrix d).
proof.
move=> uni_d; apply/dmap_uni_in_inj/djoin_uni.
(* First bullet: the map from lists of lists to matrices is injective
   on the support, where the outer list and each of its elements have
   length size. *)
- move=> xs ys /supp_djoin [+ xsin] /supp_djoin[+ ysin] /=.
  rewrite !size_nseq !ler_maxr // => sz_xs sz_ys.
  move/(congr1 tofunm); rewrite !offunmK => eqv.
  apply/(eq_from_nth witness) => [/#|j rg_j].
  have gt0_size: 0 < size by smt().
  pose xss := nth witness xs j; pose yss := nth witness ys j.
  pose ds  := djoin (nseq size d).
  move/allP/(_ (ds, xss) _): xsin => /= => [|xss_ds].
  - by rewrite {1}sz_xs; apply/mem_zip_nseqL/mem_nth.
  move/allP/(_ (ds, yss) _): ysin => /= => [|yss_ds].
  - by rewrite {1}sz_ys; apply/mem_zip_nseqL/mem_nth => /#.
  move/supp_djoin: xss_ds => [+ /allP ?].
  move/supp_djoin: yss_ds => [+ /allP ?].
  rewrite !size_nseq !ler_maxr // => sz_yss sz_xss.
  apply/(eq_from_nth witness) => [/#|i rg_i].
  move/fun_ext2/(_ i j): eqv => @/mclamp /=.
  rewrite -sz_xs in rg_j; rewrite -sz_xss in rg_i.
  by rewrite rg_i rg_j /=.
(* Second bullet: every distribution of the outer list is itself a
   djoin of copies of d, hence uniform. *)
- move=> ds; rewrite mem_nseq => -[_ <-].
  by apply/djoin_uni=> ?; rewrite mem_nseq => -[_ <-].
qed.

(* dmatrix d is lossless whenever d is. *)
lemma dmatrix_ll d : is_lossless d => is_lossless (dmatrix d).
proof.
move=> ll_d; apply/dmap_ll/djoin_ll.
move=> ds; rewrite mem_nseq => -[_ <-].
by apply/djoin_ll=> ?; rewrite mem_nseq => -[_ <-].
qed.

(* dmatrix d is full whenever d is. *)
lemma dmatrix_fu d : is_full d => is_full (dmatrix d).
proof.
move=> full_d; apply/dmap_fu_in => m; pose ns := range 0 size.
(* The witness is the list of columns of m. *)
exists (map (fun j => map (fun i => m.[i, j]) ns) ns); split=> /=.
- apply/supp_djoin; rewrite size_nseq ler_maxr // size_map.
  rewrite size_range /= ler_maxr //=; apply/allP.
  case=> d' x /mem_zip []; rewrite mem_nseq => -[_ <<-] /=.
  case/mapP=> /= i [/mem_range rg_i ->]; apply/djoin_fu.
  - by rewrite size_map size_nseq ler_maxr // size_range ler_maxr.
  - by move=> ? y; rewrite mem_nseq => -[_ <-]; apply: full_d.
- apply/eq_matrixP=> i j [rg_i rg_j]; rewrite offunmE //=.
  rewrite (nth_map witness) -1:nth_range //=.
  - by rewrite size_range ler_maxr.
  by rewrite (nth_map witness) -1:nth_range //= size_range ler_maxr.
qed.
 
(* dmatrix d is funiform whenever d is both full and uniform. *)
lemma dmatrix_funi d : is_full d => is_uniform d =>
  is_funiform (dmatrix d).
proof.
move=> ??; apply/is_full_funiform => //.
- by apply/dmatrix_fu.
- by apply/dmatrix_uni.
qed.

end Matrix.

export Matrix.
(* End of Matrix.eca. *)