https://github.com/EasyCrypt/easycrypt
Tip revision: 30bfa950afa3806948c073d3c9ec4468d33ea940 authored by Pierre-Yves Strub on 11 December 2023, 10:58:49 UTC
New tactic: "proc change"
New tactic: "proc change"
Tip revision: 30bfa95
DynMatrix.eca
(* This theory contains a formalization of vectors and matrices with *)
(* arbitrary and dynamic size. *)
(* Many operations conventionally have requirements on the sizes of their *)
(* operands. Matrix addition, for example, only makes sense if the matrices *)
(* have the same size. The user is responsible for making sure that the *)
(* parameters are valid. Many operators return the empty matrix/vector on *)
(* malformed input, but not all, so checking the output is not a reliable way *)
(* to check that the parameters are valid. *)
require import AllCore List Distr DBool DList.
import StdBigop.Bigreal StdOrder.IntOrder.
(* -------------------------------------------------------------------------- *)
require (*--*) Quotient StdOrder Bigalg.
(* -------------------------------------------------------------------------- *)
(* The scalar ring. It is cloned without a with clause, so the ring and *)
(* its operations stay abstract in this theory. *)
clone import Ring.ComRing as ZR.
type R = t.
(* -------------------------------------------------------------------------- *)
(* Big operators over ZR. BAdd provides the iterated sums (bigi) used below *)
(* by the inner product dotp, and hence by matrix multiplication. *)
clone import Bigalg.BigComRing as Big with theory CR <- ZR proof *.
import BAdd.
(* -------------------------------------------------------------------------- *)
(* Truncating the upper bound of an integer-range sum. If j' <= j and every *)
(* summand in [j', j) that satisfies P is zeror, then summing F over [i, j) *)
(* gives the same result as summing over [i, j'). Used to align summation *)
(* bounds in dotpDr and mulmxA. *)
lemma big_range0r (j' j i : int) P F : j' <= j =>
(forall x, j' <= x < j => P x => F x = zeror) =>
bigi P F i j = bigi P F i j'.
proof.
(* If i <= j', split the range at j' and drop the all-zero tail. *)
(* Otherwise both sums range only over summands that are zeror. *)
move => le_j'_j F0E; case (i <= j') => [le_i_j'|gt_i_j'].
- rewrite (big_cat_int j') 1,2:/# [bigi P F j' j]big1_seq ?addr0 //.
smt(mem_range).
rewrite !big1_seq; smt(mem_range).
qed.
theory Vectors.
(* A prevector is a raw representation: an index function paired with a *)
(* declared length. Several prevectors may denote the same vector. *)
type prevector = (int -> R) * int.
(* Canonical form of a prevector. Entries outside [0, length) become zeror *)
(* and a negative length is clamped to 0. *)
op vclamp (pv: prevector): prevector =
((fun i => if 0 <= i < pv.`2 then pv.`1 i else zeror), max 0 pv.`2).
(* Clamping an already clamped prevector changes nothing. *)
lemma nosmt vclamp_idemp pv: vclamp (vclamp pv) = vclamp pv.
proof. rewrite /vclamp /#. qed.
(* Two prevectors are equivalent when they have the same canonical form. *)
op eqv (pv1 pv2: prevector) =
vclamp pv1 = vclamp pv2.
(* Every prevector is equivalent to its canonical form. *)
lemma nosmt eqv_vclamp pv: eqv pv (vclamp pv).
proof. by rewrite /eqv vclamp_idemp. qed.
(* Vectors are the quotient of prevectors by eqv. The equivalence-relation *)
(* obligations of the clone are discharged by smt. *)
clone import Quotient.EquivQuotient as QuotientVec with
type T <- prevector,
op eqv <- eqv
rename [type] "qT" as "vector"
proof * by smt().
type vector = QuotientVec.vector.
(* tofunv returns the canonical (clamped) prevector of a vector. *)
(* offunv builds a vector from any prevector, clamping it first. *)
op tofunv v = vclamp (repr v).
op offunv pv = pi (vclamp pv).
(* Going from a vector to its prevector and back yields the same vector. *)
lemma tofunvK: cancel tofunv offunv.
proof.
rewrite /tofunv /offunv /cancel => v; rewrite vclamp_idemp.
by rewrite -{2}[v]reprK -eqv_pi /eqv vclamp_idemp.
qed.
(* Going from a prevector to a vector and back yields its canonical form. *)
lemma offunvK pv: tofunv (offunv pv) = vclamp pv.
proof. by rewrite /tofunv /offunv eqv_repr vclamp_idemp. qed.
(* Elimination principle: a property that holds for every offunv pv *)
(* holds for every vector. *)
lemma vectorW (P : vector -> bool):
(forall pv, P (offunv pv)) => forall v, P v by smt(tofunvK).
(* Dimension of the vector *)
op size v = (tofunv v).`2.
(* Sizes are never negative, since tofunv always returns a clamped length. *)
lemma nosmt size_ge0 v: 0 <= size v by smt().
lemma nosmt max0size v: max 0 (size v) = size v by smt().
hint simplify max0size.
(* A negative length given to offunv behaves like length 0. *)
lemma offunv_max f n: offunv (f, max 0 n) = offunv (f, n).
proof. rewrite /offunv -eqv_pi /eqv /vclamp /#. qed.
lemma size_offunv f n: size (offunv (f, n)) = max 0 n.
proof. by rewrite /size offunvK /#. qed.
hint simplify size_offunv.
(* Getting the i-th element of a vector *)
(* Indices outside [0, size v) read as zeror (see getv0E). *)
op get (v: vector) (i: int) = (tofunv v).`1 i.
abbrev "_.[_]" (v: vector) (i : int) = get v i.
(* In range, entries of offunv (f, n) are given by f. *)
lemma get_offunv f n (i : int) : 0 <= i < n =>
(offunv (f, n)).[i] = f i.
proof. by rewrite /get /= offunvK /vclamp /= => ->. qed.
(* Out-of-range reads return zeror. *)
lemma getv0E v i: !(0 <= i < size v) => v.[i] = zeror by smt().
lemma offunv0E f n (i : int) : !(0 <= i < n) =>
(offunv (f, n)).[i] = zeror.
proof. move => ioutn. rewrite getv0E /= /#. qed.
(* Extensionality: vectors are equal iff they have the same size and the *)
(* same entries at every in-range index. *)
lemma eq_vectorP (v1 v2 : vector) : (v1 = v2) <=>
(size v1 = size v2 /\ forall i, 0 <= i < size v1 => v1.[i] = v2.[i]).
proof.
split => [->//|[eq_size eq_vi]].
have: tofunv v1 = tofunv v2 by rewrite /tofunv /vclamp /#.
smt(tofunvK).
qed.
(* Constant valued vector of dimension n *)
(* A negative n yields a vector of size 0 (size_vectc). *)
op vectc n c = offunv ((fun _ => c), n).
lemma size_vectc n c: size (vectc n c) = max 0 n by done.
hint simplify size_vectc.
lemma get_vectc n c i: 0 <= i < n => (vectc n c).[i] = c by smt(get_offunv).
(* Zero-vector of dimension n *)
op zerov n = vectc n zeror.
lemma size_zerov n: size (zerov n) = max 0 n.
proof. by rewrite size_offunv. qed.
(* Every entry of a zero vector is zeror, in range or not. *)
lemma get_zerov n i: (zerov n).[i] = zeror.
proof.
case (0 <= i < n) => i_bound; first by rewrite get_offunv.
by rewrite getv0E ?size_zerov /#.
qed.
hint simplify size_zerov, get_zerov.
(* The unique 0-size vector *)
op emptyv = zerov 0.
lemma size_emptyv: size emptyv = 0 by done.
lemma get_emptyv i: emptyv.[i] = zeror by done.
hint simplify size_emptyv, get_emptyv.
lemma emptyv_unique v: size v = 0 => v = emptyv.
proof. move => size_eq. apply eq_vectorP => /#. qed.
(* Rebuilding a constant vector from the size of another constant vector *)
(* of dimension n gives a constant vector of dimension n. *)
lemma size_vectc2 n c1 c2: vectc (size (vectc n c1)) c2 = vectc n c2.
proof. case (0 < n) => [/# | bd]. rewrite 2!(emptyv_unique (vectc _ _)) /#. qed.
lemma size_zerov2 n: zerov (size (zerov n)) = zerov n.
proof. rewrite /zerov. apply size_vectc2. qed.
(* lifting functions to vectors *)
(* mapv f v applies f to each in-range entry of v and keeps size v. *)
(* Out-of-range entries of the result are zeror, not f zeror, because *)
(* offunv clamps to size v. *)
op mapv (f : R -> R) (v : vector) : vector =
offunv (fun i => f v.[i], size v).
lemma size_mapv f v : size (mapv f v) = size v by [].
lemma get_mapv f (v : vector) i :
(0 <= i < size v) => (mapv f v).[i] = f v.[i].
proof. by move => ?; rewrite get_offunv. qed.
(* Vector addition *)
(* Entry-wise sum. The result has the larger of the two sizes, so the *)
(* shorter operand behaves as if padded with zeror. The operator is *)
(* opaque: reason about it through size_addv and get_addv. *)
op[opaque] (+) (v1 v2 : vector) =
offunv ((fun i => v1.[i] + v2.[i]), max (size v1) (size v2)).
lemma size_addv v1 v2: size (v1 + v2) = max (size v1) (size v2).
proof. rewrite /(+) size_offunv /#. qed.
(* Holds at every index, since out-of-range entries are zeror on both sides. *)
lemma get_addv (v1 v2 : vector) i: (v1 + v2).[i] = v1.[i] + v2.[i].
proof.
case (0 <= i < max (size v1) (size v2)) => bound.
- by rewrite /(+) /= get_offunv.
(* NOTE(review): size_eq is not defined in this part of the file; the *)
(* optional rewrite ?size_eq relies on whatever lemma of that name is in *)
(* scope - confirm it is still needed. *)
- rewrite !getv0E ?size_addv ?size_eq 4:addr0 /#.
qed.
(* Additive inverse of vector *)
(* Entry-wise negation; opaque, use size_oppv and getvN. *)
op[opaque] [-] (v: vector) = offunv ((fun i => -v.[i]), size v).
lemma nosmt size_oppv v: size (-v) = size v by rewrite /([-]).
lemma getvN (v : vector) i : (-v).[i] = - v.[i].
proof.
case: (0 <= i < size v) => bound; 1: by rewrite /([-]) get_offunv.
by rewrite !getv0E /([-]) //= oppr0.
qed.
hint simplify getvN, size_oppv.
(* Simplifications for some special vectors *)
lemma oppv_zerov n: -(zerov n) = zerov n.
proof. apply eq_vectorP => /= i bound. by rewrite oppr0. qed.
lemma oppv_emptyv: -emptyv = emptyv by apply oppv_zerov.
hint simplify oppv_zerov, oppv_emptyv.
(* Module-like properties of vectors *)
(* Vector addition is associative and commutative. *)
lemma addvA: associative Vectors.(+).
proof.
move => v1 v2 v3; apply eq_vectorP.
rewrite !size_addv.
split => [/# | i _].
by rewrite !get_addv addrA.
qed.
lemma addvC: commutative Vectors.(+).
proof.
move => v1 v2; apply eq_vectorP.
rewrite !size_addv.
split => [/# | i _].
by rewrite !get_addv addrC.
qed.
(* A zero vector of the same size is a left identity. *)
lemma add0v v: zerov (size v) + v = v.
proof.
rewrite eq_vectorP size_addv /= => i bound.
by rewrite get_addv //= add0r.
qed.
hint simplify add0v.
(* Variants of the identity laws where the zero-vector size is supplied *)
(* through a hypothesis n = size v. *)
lemma lin_add0v v n: n = size v => zerov n + v = v by done.
lemma addv0 v: v + zerov (size v) = v by rewrite addvC.
lemma lin_addv0 v n: n = size v => v + zerov n = v by move => ->; rewrite addv0.
(* Subtraction is addition of the additive inverse. *)
abbrev (-) v1 v2 = v1 + (-v2).
(* v - v is the zero vector of size v (not emptyv in general). *)
lemma addvN v: v - v = zerov (size v).
proof.
rewrite eq_vectorP size_addv /= => i bound.
by rewrite get_addv //= addrN.
qed.
hint simplify addv0.
lemma oppvD (v1 v2: vector): -(v1 + v2) = -v1 + -v2.
proof.
apply eq_vectorP.
split => [| i _].
- by rewrite /= 2!size_addv /=.
by rewrite /= 2!get_addv opprD.
qed.
lemma oppvK (v: vector): - (- v) = v.
proof.
rewrite eq_vectorP.
split => [|i bound]; 1: by rewrite 2!size_oppv.
by rewrite 2!getvN opprK.
qed.
(* Moving a subtracted vector across an equation; requires all three *)
(* vectors to have the same size. *)
lemma sub_eqv (x y z : vector) :
size x = size z => size y = size z =>
x - z = y <=> x = y + z.
proof. smt(addvA addvC addvN addv0). qed.
(* Inner product *)
(* Sum of entry-wise products over [0, max (size v1) (size v2)). Because *)
(* out-of-range entries are zeror, vectors of different sizes behave as if *)
(* the shorter one were zero-padded. Opaque: unfold with dotpE. *)
op[opaque] dotp (v1 v2 : vector) =
bigi predT (fun i => v1.[i] * v2.[i]) 0 (max (size v1) (size v2)).
lemma dotpE v1 v2: dotp v1 v2 =
bigi predT (fun i => v1.[i] * v2.[i]) 0 (max (size v1) (size v2)).
proof. by rewrite /dotp. qed.
lemma dotpC : commutative dotp.
proof.
move => v1 v2; rewrite 2!dotpE maxrC.
apply eq_bigr => i _ /=.
by rewrite mulrC.
qed.
(* Negation can be pulled out of either argument. *)
lemma dotpNv v1 v2: dotp (-v1) v2 = - dotp v1 v2.
proof.
rewrite 2!dotpE /= sumrN.
apply eq_bigr => i _ /=.
by rewrite mulNr.
qed.
lemma dotpvN v1 v2: dotp v1 (-v2) = - dotp v1 v2 by rewrite dotpC dotpNv dotpC.
hint simplify dotpNv, dotpvN.
(* Linearity in the second argument, with no size requirement. *)
lemma dotpDr v1 v2 v3: dotp v1 (v2 + v3) = dotp v1 v2 + dotp v1 v3.
proof.
(* The three sums have different upper bounds. Both right-hand sums are *)
(* extended to the common bound m with big_range0r (the extra summands are *)
(* zeror), then merged with sumrD. *)
rewrite !dotpE size_addv; pose m := max _ (max _ _).
(* FIXME: (max _ _) should be precise enough, but this matches m! *)
rewrite -[bigi _ _ 0 (max _ (size v2))](big_range0r _ m); 1,2: smt(getv0E mul0r).
rewrite -[bigi _ _ 0 (max _ (size v3))](big_range0r _ m); 1,2: smt(getv0E mul0r).
by rewrite sumrD /= &(eq_bigr) /= => i _; rewrite get_addv mulrDr.
qed.
lemma dotpDl v1 v2 v3: dotp (v1 + v2) v3 = dotp v1 v3 + dotp v2 v3.
proof. by rewrite dotpC (dotpC v1) (dotpC v2) dotpDr. qed.
(* A zero vector of any size has inner product zeror with any vector. *)
lemma dotp0v v n: dotp (zerov n) v = zeror.
proof.
rewrite dotpE big1 // => i _ /=.
by rewrite mul0r.
qed.
hint simplify dotp0v.
lemma dotpv0 v n: dotp v (zerov n) = zeror by rewrite dotpC.
hint simplify dotpv0.
(* Vector concatenation *)
(* Defined as the sum of v1 and v2 shifted right by size v1. At each index *)
(* at most one of the two terms is in range, so the other reads zeror *)
(* (see get_catv). *)
op catv (v1 v2: vector) =
offunv ((fun i => v1.[i] + v2.[i-size v1]), size v1 + size v2).
abbrev ( || ) v1 v2 = catv v1 v2.
lemma size_catv (v1 v2: vector): size (v1 || v2) = (size v1 + size v2).
proof. rewrite /catv /= /#. qed.
(* Case-split form of indexing into a concatenation. *)
lemma get_catv (v1 v2: vector) i :
(v1 || v2).[i] = if i < size v1 then v1.[i] else v2.[i - size v1].
proof.
case (0 <= i < size v1 + size v2) => range; last first.
- rewrite !getv0E ?size_catv //; smt(size_ge0).
rewrite get_offunv //=; case (i < size v1) => ?.
- by rewrite [v2.[_]]getv0E //= 1:/# addr0.
- by rewrite [v1.[_]]getv0E //= 1:/# add0r.
qed.
(* Sum form of indexing into a concatenation, valid at every index. *)
lemma get_catv' (v1 v2: vector) i: (v1 || v2).[i] = v1.[i] + v2.[i-size v1].
proof.
rewrite /catv.
case (0 <= i < size v1 + size v2) => range.
- rewrite get_offunv //=.
- rewrite !getv0E 4:addr0 /#.
qed.
lemma get_catv_l (v1 v2: vector) i :
i < size v1 => (v1 || v2).[i] = v1.[i].
proof. smt(get_catv). qed.
lemma get_catv_r (v1 v2: vector) i :
size v1 <= i => (v1 || v2).[i] = v2.[i - size v1].
proof. smt(get_catv). qed.
lemma catvA : associative Vectors.(||).
proof.
move => v1 v2 v3. rewrite eq_vectorP !size_catv addrA /= => i bound.
by rewrite !get_catv' size_catv opprD !addrA.
qed.
(* Inner product of concatenations splits blockwise, provided the left *)
(* blocks have equal sizes. *)
lemma dotp_catv v1 v2 v3 v4: size v1 = size v3 =>
dotp (v1 || v2) (v3 || v4) = (dotp v1 v3) + (dotp v2 v4).
proof.
move => size_eq; rewrite !dotpE !size_catv size_eq /=.
rewrite (range_cat (size v3)) 1:/# 1:/# big_cat; congr.
- by apply eq_big_seq => i /mem_range ? /=; smt(get_catv_l).
have ->:max (size v3 + size v2) (size v3 + size v4) =
size v3 + max (size v2) (size v4) by smt().
have {1}->: size v3 = 0 + size v3 by smt().
rewrite big_addn.
rewrite addrC addrA [_ + size v3]Ring.IntID.addrC subrr.
apply eq_big_seq => i /mem_range [l_bound u_bound] /=.
by rewrite !get_catv size_eq gtr_addr ler_gtF //= -addrA subrr.
qed.
(* Addition of concatenations is blockwise when the left blocks have *)
(* equal sizes. *)
lemma addv_catv (v1 v2 v3 v4: vector):
size v1 = size v3 =>
(v1 || v2) + (v3 || v4) = ((v1 + v3) || (v2 + v4)).
proof.
move => size_eq; apply eq_vectorP.
rewrite size_addv 3!size_catv 2!size_addv.
split => [/# | i bound].
rewrite get_addv 3!get_catv' 2!get_addv size_addv size_eq /max /=.
rewrite 2!addrA.
congr.
rewrite -2!addrA.
congr.
by rewrite addrC.
qed.
lemma oppv_catv (v1 v2: vector): - (v1 || v2) = (-v1 || -v2).
proof. by apply eq_vectorP => i bound /=; rewrite 2!get_catv' opprD. qed.
(* Splitting apart vectors *)
(* subv v n m is the vector of size max 0 (m - n) whose i-th entry is *)
(* v.[i + n], i.e. the entries of v at indices n .. m - 1 (indices outside *)
(* v read as zeror). *)
op subv (v: vector) (n m: int) = offunv ((fun i => v.[i+n]), m - n).
lemma size_subv (v: vector) (n m: int): size (subv v n m) = max 0 (m - n).
proof. by done. qed.
hint simplify size_subv.
lemma get_subv (v: vector) (n m i: int): 0 <= i < m - n =>
(subv v n m).[i] = v.[i+n].
proof. move => i_range; by rewrite /subv get_offunv. qed.
(* Interactions between subv and concatenation *)
(* Taking the first size v1 entries of a concatenation recovers v1. *)
lemma subv_catvCl v1 v2: subv (v1 || v2) 0 (size v1) = v1.
proof.
rewrite eq_vectorP size_subv /= => i bound.
by rewrite get_subv // get_catv' /= (getv0E v2) 1:/# addr0.
qed.
(* Taking the remaining entries of a concatenation recovers v2. *)
lemma subv_catvCr v1 v2: subv (v1 || v2) (size v1) (size v1 + size v2) = v2.
proof.
rewrite eq_vectorP size_subv.
split => [/# | i bound].
rewrite get_subv 1:/# get_catv' (getv0E v1) 1:/# add0r /#.
qed.
(* Splitting v at any in-range position i and concatenating gives back v. *)
lemma catv_subvC v i: 0 <= i <= size v =>
(subv v 0 i || subv v i (size v)) = v.
proof.
move => i_bound; rewrite eq_vectorP size_catv !size_subv /= -andaE.
split => [/#|-> k k_bound]; case (k < i) => lt_i_k.
- rewrite get_catv_l; smt(get_subv).
- rewrite get_catv_r; smt(get_subv).
qed.
(* subv distributes over addition of vectors of equal size. *)
lemma subv_addv v1 v2 a b: size v1 = size v2 =>
subv (v1 + v2) a b = subv v1 a b + subv v2 a b.
proof.
move => size_eq; rewrite eq_vectorP ?size_addv /= => /= i bound.
rewrite get_subv 1:/# !get_addv // !get_subv /#.
qed.
(* Updating one entry in a vector *)
(* updv v n m replaces entry n of v with m. The size stays size v, so an *)
(* index n outside [0, size v) leaves the vector unchanged. *)
op updv (v: vector) (n: int) (m: t) =
offunv (fun i => if i = n then m else v.[i], size v).
lemma size_updv (v: vector) (n: int) (m: t): size (updv v n m) = size v by done.
lemma get_updv (v: vector) (n i: int) (m: t): (0 <= i < size v) =>
(updv v n m).[i] = if i = n then m else v.[i].
proof. rewrite /updv => i_bound. by rewrite get_offunv. qed.
(* Multiplying vector by a scalar *)
(* Each entry is multiplied on the left by s; the size is unchanged. *)
op scalarv (s: t) (v: vector) = offunv(fun i => s * v.[i], size v).
abbrev ( ** ) s v = scalarv s v.
lemma size_scalarv (v: vector) (s: t): size (s ** v) = size v by done.
(* Holds at every index, since out-of-range entries are zeror and s times *)
(* zeror is zeror. *)
lemma get_scalarv (v: vector) (s: t) (i: int): (s ** v).[i] = s * v.[i].
proof.
case (0 <= i < size v) => bound; 1: by rewrite get_offunv.
by rewrite offunv0E // getv0E // mulr0.
qed.
(* Scaling by zeror gives the zero vector of the same size. *)
lemma scalar0v (v: vector): zeror ** v = zerov (size v).
proof.
rewrite eq_vectorP.
split => [| i bound]; 1: rewrite size_scalarv /#.
by rewrite get_zerov get_scalarv mul0r.
qed.
(* Scaling by oner is the identity. *)
lemma scalar1v (v: vector): oner ** v = v.
proof.
rewrite eq_vectorP.
split => [| i bound]; 1: by rewrite size_scalarv.
by rewrite get_scalarv mul1r.
qed.
lemma scalarvN a (v : vector): a ** -v = - (a ** v).
proof.
rewrite eq_vectorP size_oppv 2!size_scalarv size_oppv /= => i bound.
by rewrite 2!get_scalarv getvN mulrN.
qed.
lemma scalarNv (a : R) (v : vector): (-a) ** v = - (a ** v).
proof.
rewrite eq_vectorP size_oppv 2!size_scalarv /= => i bound.
by rewrite !get_scalarv mulNr.
qed.
(* Scalars can be pulled out of both arguments of the inner product. *)
lemma dotp_scalarv (v1 v2: vector) (s1 s2: t):
dotp (s1 ** v1) (s2 ** v2) = (s1 * s2) * (dotp v1 v2).
proof.
rewrite 2!dotpE 2!size_scalarv !mulr_sumr &(eq_bigr) /= => i _.
rewrite 2!get_scalarv; smt(mulrA mulrC).
qed.
lemma dotp_scalarv_l (v1 v2: vector) (s : t):
dotp (s ** v1) v2 = s * dotp v1 v2.
proof. by rewrite -[v2]scalar1v dotp_scalarv scalar1v mulr1. qed.
lemma dotp_scalarv_r (v1 v2: vector) (s: t):
dotp v1 (s ** v2) = s * dotp v1 v2.
proof. by rewrite -[v1]scalar1v dotp_scalarv scalar1v mul1r. qed.
(* Scalar multiplication distributes over vector addition. *)
lemma scalarvDr s (v1 v2 : vector) :
s ** (v1 + v2) = s ** v1 + s ** v2.
proof.
apply eq_vectorP; rewrite 5!(size_addv,size_scalarv) /= => i Hi.
by rewrite !(get_scalarv,get_addv) mulrDr.
qed.
(* List-vector isomorphism *)
(* oflist s builds a vector of size size s whose i-th entry is the i-th *)
(* element of s. The nth default is irrelevant in range (get_oflist). *)
op oflist (s : R list): vector = offunv (nth witness s, size s).
lemma size_oflist l: size (oflist l) = size l.
proof. rewrite size_offunv; 1: smt(List.size_ge0). qed.
lemma get_oflist w (l: R list) i: 0 <= i < size l => (oflist l).[i] = nth w l i.
proof. by move => bound; rewrite get_offunv // (nth_change_dfl w). qed.
(* tolist v lists the entries of v at indices 0 .. size v - 1 in order. *)
op tolist (v : vector): R list = map (fun i => v.[i]) (range 0 (size v)).
lemma size_tolist v: size (tolist v) = size v.
proof. rewrite size_map size_range /#. qed.
lemma nth_tolist w v i: 0 <= i < size v => nth w (tolist v) i = v.[i].
proof.
move => bound; rewrite (nth_map witness w) 1:size_range; 1: smt(List.size_ge0).
by rewrite /= nth_range.
qed.
lemma mem_tolist x v : x \in tolist v <=> exists i, 0 <= i < size v /\ x = v.[i].
proof. by rewrite mapP; smt(mem_range). qed.
(* oflist and tolist are mutually inverse, hence both injective. *)
lemma oflistK: cancel oflist tolist.
proof.
move => l; apply (eq_from_nth witness); rewrite size_tolist size_oflist //.
by move => i bound; rewrite nth_tolist 1:size_oflist // (get_oflist witness).
qed.
lemma tolistK: cancel tolist oflist.
proof.
move => l; apply eq_vectorP.
rewrite size_oflist size_tolist /= => i bound.
rewrite (get_oflist witness) 1:size_tolist // nth_tolist //.
qed.
lemma oflist_inj: injective oflist by smt(oflistK).
lemma tolist_inj: injective tolist by smt(tolistK).
(* Vector concatenation corresponds to list concatenation. *)
lemma tolist_catv (v1 v2 : vector) :
tolist (v1 || v2) = (tolist v1) ++ (tolist v2).
proof.
rewrite /tolist size_catv (range_cat (size v1)); 1,2: smt(size_ge0).
rewrite map_cat; congr.
- apply/eq_in_map => i /mem_range /=; smt(get_catv_l).
rewrite addzC range_addr /= -map_comp /(\o).
apply/eq_in_map => i /mem_range /=; smt(get_catv_r).
qed.
(* A constant vector corresponds to a list of n copies of c. *)
lemma tolist_vectc n c : tolist (vectc n c) = nseq n c.
proof.
apply oflist_inj; rewrite tolistK; apply eq_vectorP.
rewrite size_vectc size_oflist size_nseq /= => i i_bound.
by rewrite get_vectc 1:/# (get_oflist c) ?size_nseq // ?nth_nseq /#.
qed.
(* Distribution of length n vectors sampled using d element-wise *)
(* Samples a list of length max 0 n from dlist d and converts it with *)
(* oflist, so a negative n yields vectors of size 0. *)
op dvector (d: R distr) (n: int) = dmap (dlist d (max 0 n)) oflist.
(* Support: exactly the vectors of size k whose entries all lie in the *)
(* support of d. *)
lemma supp_dvector d v k :
0 <= k =>
(v \in dvector d k) <=>
size v = k /\ forall i, 0 <= i < k => v.[i] \in d.
proof.
move=> k_ge0; rewrite supp_dmap; split => [[s []]|[s_v] supp_v].
- rewrite ler_maxr // supp_dlist 1:/# => -[s_s /allP supp_s ->].
rewrite size_oflist s_s /= -s_s => i bound.
rewrite (get_oflist witness) // supp_s mem_nth //.
exists (tolist v); rewrite ler_maxr // tolistK /=.
apply/supp_dlist => //; rewrite size_tolist s_v /=; apply/allP => x x_v.
rewrite -(nth_index x _ _ x_v) nth_tolist.
- by rewrite index_ge0 -size_tolist index_mem x_v.
by rewrite supp_v index_ge0 -s_v -size_tolist index_mem.
qed.
lemma size_dvector d n v: v \in dvector d n => size v = max 0 n.
proof.
move => /supp_dmap[a [a_in ->]]; rewrite size_oflist.
rewrite supp_dlist /# in a_in.
qed.
(* Vectors of the wrong size have probability 0. *)
lemma get_dvector0E (d: R distr) (v: vector) n: size v <> max 0 n =>
mu1 (dvector d n) v = 0%r.
proof.
move => size_ineq; apply supportPn.
rewrite -implybF => /size_dvector v_in.
by apply size_ineq.
qed.
(* The probability of v is the product of the probabilities of its entries. *)
lemma dvector1E (d : R distr) (v : vector) : mu1 (dvector d (size v)) v =
BRM.bigi predT (fun i => mu1 d v.[i]) 0 (size v).
proof.
rewrite -{2}[v]tolistK dmapE /(\o) /pred1.
rewrite (@mu_eq _ _ (pred1 (tolist v))); 1: smt(oflist_inj).
rewrite dlist1E 1:/# size_tolist max0size /=.
by rewrite BRM.big_mapT /(\o) &BRM.eq_big.
qed.
(* Sampling a vector equals sampling its prefix of length i and its *)
(* remaining suffix separately. *)
lemma mu1_dvector_split d i (v: vector): 0 <= i <= size v =>
mu1 (dvector d i) (subv v 0 i) *
mu1 (dvector d (size v - i)) (subv v i (size v)) =
mu1 (dvector d (size v)) v.
proof.
move => i_bound; have last_size: size v - i = size (subv v i (size v)).
- rewrite size_subv /#.
have first_size: i = size (subv v 0 i) by rewrite size_subv /#.
rewrite last_size {1}first_size !dvector1E.
rewrite (BRM.big_cat_int i 0 (size v)) 3:2!size_subv; first 2 smt().
have ->: max 0 (size v - i) = size v - i by smt().
have ->: max 0 i = i by smt().
congr.
- apply BRM.eq_big_int => k cont /=.
by rewrite get_subv.
- have ->: i = 0 + i by algebra.
rewrite (BRM.big_addn 0) /=.
apply BRM.eq_big_int => k cont /=.
by rewrite get_subv.
qed.
(* dvector inherits uniformity, losslessness and fullness from d. *)
lemma dvector_uni d n: is_uniform d => is_uniform (dvector d n).
proof.
move => uni_d; apply dmap_uni; 1: apply/(can_inj _ tolist)/oflistK.
by apply dlist_uni.
qed.
lemma dvector_ll d n: is_lossless d => is_lossless (dvector d n).
proof. by move => ll_d; apply/dmap_ll/dlist_ll. qed.
lemma dvector_fu d (v: vector): is_full d => v \in dvector d (size v).
proof.
move=> full_d; rewrite supp_dmap.
exists (tolist v).
rewrite tolistK /= -size_tolist dlist_fu => x /#.
qed.
(* For a full-uniform d, every vector of size size v has the same *)
(* probability, namely the single-element probability to the power size v. *)
lemma mu1_dvector_fu (d: R distr) (v: vector): is_funiform d =>
mu1 (dvector d (size v)) v = (mu1 d witness)^(size v).
proof.
rewrite dvector1E => d_funi.
have ->: (fun (i : int) => mu1 d v.[i]) = fun (_: int) => mu1 d witness.
- rewrite fun_ext => i.
apply d_funi.
have: 0 <= size v by exact size_ge0.
move: (size v); elim/ge0ind => [/# | _ | n bound IH _].
- by rewrite range_geq // BRM.big_nil RField.expr0.
- by rewrite BRM.big_int_recr // RField.exprS // RField.mulrC IH.
qed.
end Vectors.
export Vectors.
(* -------------------------------------------------------------------------- *)
theory Matrices.
(* A prematrix is a raw representation: an entry function together with a *)
(* declared number of rows and columns. *)
type prematrix = (int -> int -> R) * int * int.
(* Canonical form of a prematrix. Entries outside [0, rows) x [0, cols) *)
(* become zeror and negative dimensions are clamped to 0. *)
op mclamp (pm: prematrix): prematrix =
((fun i j => if 0 <= i < pm.`2 /\ 0 <= j < pm.`3 then pm.`1 i j else zeror),
max 0 pm.`2, max 0 pm.`3).
lemma nosmt mclamp_idemp pm: mclamp (mclamp pm) = mclamp pm.
proof. by rewrite /mclamp /#. qed.
hint simplify mclamp_idemp.
(* Two prematrices are equivalent when they have the same canonical form. *)
op eqv (pm1 pm2: prematrix) = mclamp pm1 = mclamp pm2.
(* Matrices are the quotient of prematrices by eqv. *)
clone import Quotient.EquivQuotient as QuotientMat with
type T <- prematrix,
op eqv <- eqv
rename [type] "qT" as "matrix"
proof * by smt().
type matrix = QuotientMat.matrix.
(* tofunm returns the canonical (clamped) prematrix of a matrix. *)
(* offunm builds a matrix from any prematrix, clamping it first. *)
op tofunm m = mclamp (repr m).
op offunm pm = pi (mclamp pm).
lemma tofunmK : cancel tofunm offunm.
proof.
rewrite /tofunm /offunm /cancel => m /=.
have ->: pi (mclamp (repr m)) = pi (repr m) by rewrite -eqv_pi /eqv.
apply reprK.
qed.
lemma offunmK pm: tofunm (offunm pm) = mclamp pm.
proof. by rewrite /tofunm /offunm eqv_repr. qed.
hint simplify offunmK.
(* Elimination principle: a property that holds for every offunm pm *)
(* holds for every matrix. *)
lemma matrixW (P : matrix -> bool) : (forall pm, P (offunm pm)) =>
forall m, P m by smt(tofunmK).
(* Number of rows and columns of matrices *)
op rows m = (tofunm m).`2.
op cols m = (tofunm m).`3.
(* size of a matrix is the pair (rows, cols); it shares its name with *)
(* Vectors.size, and the two are told apart by type. *)
abbrev size m = (rows m, cols m).
lemma nosmt rows_ge0 m: 0 <= rows m by smt().
lemma nosmt cols_ge0 m: 0 <= cols m by smt().
lemma rows_offunm f r c: rows (offunm (f, r, c)) = max 0 r by done.
lemma cols_offunm f r c: cols (offunm (f, r, c)) = max 0 c by done.
lemma size_offunm f r c: size (offunm (f, r, c)) = (max 0 r, max 0 c) by done.
hint simplify rows_offunm, cols_offunm.
lemma nosmt max0rows m: max 0 (rows m) = rows m by smt().
lemma nosmt max0cols m: max 0 (cols m) = cols m by smt().
hint simplify max0rows, max0cols.
(* Getting the element at position i, j *)
op get (m : matrix) (ij : int * int) = (tofunm m).`1 ij.`1 ij.`2.
abbrev "_.[_]" m ij = get m ij.
(* mrange m i j holds when (i, j) is an in-range position of m. *)
abbrev mrange m (i j : int) = 0 <= i < rows m /\ 0 <= j < cols m.
lemma get_offunm f r c (i j : int) : mrange (offunm (f, r, c)) i j =>
(offunm (f, r, c)).[i, j] = f i j.
proof. rewrite /get /= /mclamp /= /#. qed.
(* Out-of-range reads return zeror. *)
lemma nosmt getm0E (m : matrix) (i j : int) : !mrange m i j => m.[i, j] = zeror.
proof. by smt(). qed.
lemma offunm0E f r c (i j: int) : !(0 <= i < r /\ 0 <= j < c) =>
(offunm (f, r, c)).[i, j] = zeror.
proof. move => idx_out. rewrite getm0E /#. qed.
(* Extensionality: matrices are equal iff they have the same dimensions *)
(* and the same entries at every in-range position. *)
lemma eq_matrixP (m1 m2 : matrix) : (m1 = m2) <=>
size m1 = size m2 /\ (forall i j, mrange m1 i j => m1.[i, j] = m2.[i, j]).
proof.
split=> [-> // | @/get /= eq_mi].
have: tofunm m1 = tofunm m2 by rewrite /tofunm /mclamp /#.
smt(tofunmK).
qed.
(* Special matrices *)
(* Constant valued matrix with r rows and c columns *)
(* The parameters named rows and cols shadow the rows and cols operators *)
(* inside this definition only. *)
op matrixc (rows cols: int) (c : R) = offunm ((fun _ _ => c), rows, cols).
lemma nosmt rows_matrixc cst r c: rows (matrixc r c cst) = max 0 r by done.
lemma nosmt cols_matrixc cst r c: cols (matrixc r c cst) = max 0 c by done.
hint simplify rows_matrixc, cols_matrixc.
lemma nosmt size_matrixc cst r c: size (matrixc r c cst) = (max 0 r, max 0 c).
proof. by done. qed.
lemma get_matrixc cst c r i j: mrange (matrixc r c cst) i j =>
(matrixc r c cst).[i, j] = cst.
proof. move => bound. by apply get_offunm. qed.
(* Matrix with the values of v on the diagonal and zeror off the diagonal *)
(* It is square, of dimension size v. *)
op diagmx (v : vector) =
offunm ((fun i j => if i = j then v.[i] else zeror), size v, size v).
lemma rows_diagmx v: rows (diagmx v) = size v by rewrite /diagmx /#.
lemma cols_diagmx v: cols (diagmx v) = size v by rewrite /diagmx /#.
hint simplify rows_diagmx, cols_diagmx.
lemma size_diagmx v: size (diagmx v) = (size v, size v) by done.
lemma get_diagmx v i j: (diagmx v).[i, j] = if i = j then v.[i] else zeror.
proof.
case (mrange (diagmx v) i j) => /= [[i_bound j_bound] | not_bound].
- by rewrite get_offunm.
rewrite getm0E //.
case (i = j) => [idx_eq | //].
rewrite getv0E /#.
qed.
hint simplify get_diagmx.
(* Matrix with constant values on the diagonal *)
abbrev diagc n (c : R) = diagmx (vectc n c).
lemma get_diagc n c i j:
(diagc n c).[i, j] = if i = j /\ 0 <= i < n then c else zeror.
proof.
case (i = j) => /= -> //=.
case (0 <= j < n) => j_bound; 1: by rewrite get_vectc.
apply getv0E => /= /#.
qed.
hint simplify get_diagc.
(* n by n identity matrix *)
abbrev onem n = diagc n oner.
lemma get_onem i j n: mrange (onem n) i j => (onem n).[i, j] =
if i = j then oner else zeror.
proof.
rewrite get_diagc => bound.
suff: (0 <= i < n) by move => ->.
rewrite rows_diagmx cols_diagmx /= in bound => /#.
qed.
(* Off-diagonal entries of the identity are zeror, whatever the indices. *)
lemma onem0E i j n: i <> j => (onem n).[i, j] = zeror.
proof. move=> ne_ij. by rewrite get_diagmx ne_ij. qed.
(* r by c zero matrix *)
abbrev zerom r c = matrixc r c zeror.
lemma get_zerom r c i j: (zerom r c).[i, j] = zeror.
proof.
case (mrange (zerom r c) i j) => range; last by rewrite getm0E.
by rewrite get_matrixc.
qed.
hint simplify get_zerom.
(* 0 by 0 matrix also used as error state when there is a size mismatch *)
op emptym = zerom 0 0.
lemma rows_emptym: rows emptym = 0 by done.
lemma cols_emptym: cols emptym = 0 by done.
lemma size_emptym: size emptym = (0, 0) by done.
lemma get_emptym i j: emptym.[i,j] = zeror by done.
hint simplify rows_emptym, cols_emptym, get_emptym.
(* Every 0 by 0 matrix is emptym. *)
lemma emptym_unique m: size m = (0, 0) => m = emptym.
proof.
move => [rows_m cols_m].
apply eq_matrixP => /=.
rewrite rows_m cols_m /= => i j /#.
qed.
(* lifting functions to matrices *)
(* mapm f m applies f to each in-range entry of m and keeps its dimensions. *)
(* Out-of-range entries of the result are zeror, because offunm clamps. *)
op mapm (f : R -> R) (m : matrix) : matrix =
offunm (fun i j => f m.[i, j], rows m, cols m).
(* mapm preserves the number of rows and columns. *)
lemma size_mapm f m : size (mapm f m) = size m by [].
(* At an in-range position, the entry of mapm f m is f applied to the *)
(* corresponding entry of m. *)
lemma get_mapm f (m : matrix) i j :
mrange m i j => (mapm f m).[i, j] = f m.[i, j].
proof. by move => bound; rewrite get_offunm /#. qed.
(* Matrix addition *)
(* Entry-wise sum. Each dimension of the result is the larger of the two, *)
(* so a smaller operand behaves as if padded with zeror. Unlike vector *)
(* addition, this operator is not declared opaque. *)
op (+) (m1 m2 : matrix) =
offunm (fun i j => m1.[i, j] + m2.[i, j],
max (rows m1) (rows m2),
max (cols m1) (cols m2)).
lemma rows_addm (m1 m2: matrix): rows (m1 + m2) = max (rows m1) (rows m2).
proof. rewrite /(+) rows_offunm /#. qed.
lemma cols_addm (m1 m2: matrix): cols (m1 + m2) = max (cols m1) (cols m2).
proof. rewrite /(+) cols_offunm /#. qed.
(* For operands of equal size the sum keeps that size. *)
lemma size_addm (m1 m2: matrix): size m1 = size m2 => size (m1 + m2) = size m1.
proof. move => size_eq; rewrite rows_addm cols_addm /#. qed.
(* Holds at every position, since out-of-range entries are zeror. *)
lemma get_addm (m1 m2 : matrix) i j: (m1 + m2).[i, j] = m1.[i, j] + m2.[i, j].
proof.
case: (mrange (m1 + m2) i j) => rg_i.
- rewrite /(+) /= get_offunm 2:/#; 1:smt(rows_addm cols_addm).
- rewrite !getm0E; first 3 smt(rows_addm cols_addm).
by rewrite addr0.
qed.
(* Matrix additive inverse *)
(* Entry-wise negation; opaque, use rows_neg, cols_neg and getmN. *)
op[opaque] [-] (m: matrix) = offunm (fun i j => -m.[i, j], rows m, cols m).
lemma rows_neg (m: matrix): rows (-m) = rows m by rewrite /([-]).
lemma cols_neg (m: matrix): cols (-m) = cols m by rewrite /([-]).
hint simplify rows_neg, cols_neg.
lemma size_neg (m: matrix): size (-m) = size m by rewrite /([-]).
lemma getmN (m : matrix) i j : (-m).[i, j] = - m.[i, j].
proof.
case: (mrange m i j) => rg_i; 1: by rewrite /([-]) get_offunm.
by rewrite getm0E // getm0E // oppr0.
qed.
hint simplify getmN.
lemma emptymN: -emptym = emptym by apply emptym_unique.
hint simplify emptymN.
(* Module like properties of matrices *)
(* Matrix addition is associative and commutative. *)
lemma addmA: associative Matrices.( + ).
proof.
move => m1 m2 m3; apply eq_matrixP.
split => [| i j bound].
- rewrite !(rows_addm, cols_addm) 1:/#.
by rewrite 4!get_addm addrA.
qed.
lemma addmC: commutative Matrices.( + ).
proof.
move => m1 m2.
apply eq_matrixP.
rewrite !rows_addm !cols_addm.
split => [/# | i j _].
by rewrite 2!get_addm addrC.
qed.
(* A zero matrix of the same dimensions is a left identity. *)
lemma add0m m: zerom (rows m) (cols m) + m = m.
proof.
apply eq_matrixP.
rewrite size_addm /= 1://.
move => i j bound; rewrite get_addm //=.
apply add0r.
qed.
(* Variants of the identity laws where the dimensions of the zero matrix *)
(* are supplied through hypotheses r = rows m and c = cols m. *)
lemma lin_add0m m r c:
r = rows m => c = cols m => zerom r c + m = m.
proof. move => -> ->; apply add0m. qed.
lemma addm0 m: m + zerom (rows m) (cols m) = m by rewrite addmC add0m.
lemma lin_addm0 m r c:
r = rows m => c = cols m => m + zerom r c = m.
proof. move => -> ->; apply addm0. qed.
(* Subtraction is addition of the additive inverse. *)
abbrev (-) (m1 m2: matrix) = m1 + (-m2).
(* m - m is the zero matrix with the dimensions of m. *)
lemma addmN (m: matrix) : m - m = zerom (rows m) (cols m).
proof.
apply eq_matrixP.
split => [| i j bound]; 1: by rewrite size_addm.
by rewrite get_addm //= addrN.
qed.
lemma addNm (m: matrix): (-m) + m = zerom (rows m) (cols m).
proof. by rewrite addmC addmN. qed.
lemma oppmD (m1 m2: matrix): -(m1 + m2) = -m1 + -m2.
proof.
rewrite eq_matrixP; split => [| i j bound].
- by rewrite /= 2!rows_addm 2!cols_addm.
- by rewrite /= 2!get_addm opprD.
qed.
lemma oppmK (m: matrix): - (- m) = m.
proof.
rewrite eq_matrixP.
split => [| i j bound]; 1: by rewrite !size_neg.
by rewrite !getmN opprK.
qed.
(* Moving a subtracted matrix across an equation; requires all three *)
(* matrices to have the same dimensions. *)
lemma sub_eqm (x y z : matrix):
size x = size z => size y = size z =>
x - z = y <=> x = y + z.
proof. smt(addmA addmC addNm add0m). qed.
(* matrix transposition *)
(* trmx m swaps rows and columns: entry (i, j) of trmx m is entry (j, i) of m. *)
op trmx (m : matrix) = offunm (fun i j => m.[j, i], cols m, rows m).
lemma rows_tr m: rows (trmx m) = cols m by done.
lemma cols_tr m: cols (trmx m) = rows m by done.
hint simplify rows_tr, cols_tr.
lemma size_tr m: size (trmx m) = (cols m, rows m) by done.
lemma trmxE (m : matrix) i j : (trmx m).[i, j] = m.[j, i].
proof.
case: (mrange m j i) => bound.
- rewrite get_offunm /#.
- rewrite getm0E /#.
qed.
hint simplify trmxE.
(* Transposing twice gives back the original matrix. *)
lemma trmxK: involutive trmx.
proof.
move => m.
by apply eq_matrixP.
qed.
lemma trmx_inj: injective trmx by apply inv_inj; apply trmxK.
(* Transposition commutes with addition and negation. *)
lemma trmxD (m1 m2 : matrix) : trmx (m1 + m2) = trmx m1 + trmx m2.
proof.
apply eq_matrixP.
split => [/= | i j bound]; 1: smt(rows_addm cols_addm cols_tr rows_tr).
rewrite trmxE 2!get_addm /#.
qed.
hint simplify trmxD.
lemma trmxN (m: matrix): trmx (-m) = - trmx m by apply eq_matrixP.
hint simplify trmxN.
(* Transposition fixes the empty matrix, the identity and swaps the *)
(* dimensions of a constant matrix. *)
lemma trmx_empty: trmx emptym = emptym by apply emptym_unique.
lemma trmx1 n: trmx (onem n) = (onem n).
proof.
apply eq_matrixP => i j bound /= /#.
qed.
hint simplify trmx1, trmx_empty.
lemma trmx_matrixc c n m: trmx (matrixc n m c) = matrixc m n c.
proof.
rewrite eq_matrixP /= => i j bound.
rewrite !get_matrixc /#.
qed.
hint simplify trmx_matrixc.
(* Gets the n-th row of m as a vector *)
(* The result always has size cols m, even when n is out of range. *)
op row m n = offunv (fun i => m.[n, i], cols m).
lemma size_row m i: size (row m i) = cols m by done.
lemma get_row m i j: (row m i).[j] = m.[i, j].
proof.
case (0 <= j < cols m) => bound; 1: by apply get_offunv.
rewrite getv0E // getm0E /#.
qed.
hint simplify size_row, get_row.
(* An out-of-range row is the zero vector of size cols m. *)
lemma row0E m i: !(0 <= i < rows m) => row m i = zerov (cols m).
proof. move => n_bound; rewrite eq_vectorP => j bd /=. rewrite getm0E /#. qed.
lemma rowN m i: - (row m i) = row (-m) i by rewrite eq_vectorP.
lemma rowD m1 m2 n: (row (m1 + m2) n) = row m1 n + row m2 n.
proof.
apply eq_vectorP.
rewrite size_row cols_addm size_addv 2!size_row /= => i _.
by rewrite get_addv get_addm /=.
qed.
lemma row_matrixc m n c i: 0 <= i < m => row (matrixc m n c) i = vectc n c.
proof.
move => i_bound; apply eq_vectorP => /= j j_bound.
rewrite get_vectc 1:/# get_matrixc /#.
qed.
(* Gets the n-th column of m as a vector *)
(* The result always has size rows m, even when n is out of range. *)
op col m n = offunv (fun i => m.[i, n], rows m).
lemma size_col m n: size (col m n) = rows m by done.
lemma get_col m i j: (col m j).[i] = m.[i, j].
proof.
case (0 <= i < rows m) => bound; 1: by apply get_offunv.
rewrite getv0E // getm0E /#.
qed.
hint simplify size_col, get_col.
(* Rows of the transpose are columns of the original, and vice versa. *)
lemma row_trmx m n: row (trmx m) n = col m n by rewrite eq_vectorP.
lemma col_trmx m n: col (trmx m) n = row m n by rewrite eq_vectorP.
hint simplify row_trmx, col_trmx.
(* An out-of-range column is the zero vector of size rows m. *)
lemma col0E m n: !(0 <= n < cols m) => col m n = zerov (rows m).
proof.
rewrite -trmxK cols_tr col_trmx => bound.
rewrite rows_tr.
by apply row0E.
qed.
lemma colD m1 m2 i: (col (m1 + m2) i) = col m1 i + col m2 i.
proof. smt(trmxK col_trmx trmxD rowD). qed.
lemma colN (m : matrix) i : - (col m i) = col (-m) i by rewrite eq_vectorP.
lemma col_matrixc m n c i: 0 <= i < n => col (matrixc m n c) i = vectc m c.
proof.
move => bound; rewrite -(trmxK (matrixc _ _ _)) col_trmx /=.
by apply row_matrixc.
qed.
(* Matrix multiplication: entry (i, j) of m1 * m2 is the dot product of row i
   of m1 with column j of m2. The result has rows m1 rows and cols m2
   columns; the definition itself does not require cols m1 = rows m2. *)
op[opaque] ( * ) (m1 m2 : matrix) =
offunm (fun i j => dotp (row m1 i) (col m2 j), rows m1, cols m2).
lemma rows_mulmx m1 m2: rows (m1 * m2) = rows m1 by rewrite /( * ).
lemma cols_mulmx m1 m2: cols (m1 * m2) = cols m2 by rewrite /( * ).
hint simplify rows_mulmx, cols_mulmx.
lemma size_mulmx m1 m2: size (m1 * m2) = (rows m1, cols m2) by done.
(* Entry characterization of the product, valid for all indices. Outside the
   range of the product both sides vanish, because the selected row of m1 or
   column of m2 is a zero vector (row0E, col0E). *)
lemma get_mulmx m1 m2 i j: (m1 * m2).[i,j] = dotp (row m1 i) (col m2 j).
proof.
case (mrange (m1 * m2) i j) => /= [bound | /negb_and bound].
- by rewrite /( * ) /= get_offunm.
rewrite getm0E 1:/#.
elim bound => boundN.
- suff: row m1 i = zerov (cols m1) by move => ->.
rewrite row0E /#.
- suff: col m2 j = zerov (rows m2) by move => ->.
rewrite col0E /#.
qed.
lemma mulmx_emptym: emptym * emptym = emptym by apply/emptym_unique/size_mulmx.
hint simplify mulmx_emptym.
(* Transposition reverses the order of the factors of a product. *)
lemma trmxM (m1 m2 : matrix): trmx (m1 * m2) = trmx m2 * trmx m1.
proof.
apply eq_matrixP.
rewrite size_mulmx /= => i j bound.
by rewrite !get_mulmx dotpC.
qed.
hint simplify trmxM.
(* The product distributes over matrix addition on the left ... *)
lemma mulmxDl (m1 m2 m : matrix): (m1 + m2) * m = (m1 * m) + (m2 * m).
proof.
apply eq_matrixP.
split => [| i j bound]; 1: smt(cols_addm rows_addm rows_mulmx cols_mulmx).
by rewrite get_addm 3!get_mulmx rowD dotpDl.
qed.
(* ... and on the right, obtained from mulmxDl by transposition. *)
lemma mulmxDr (m1 m2 m : matrix): m * (m1 + m2) = (m * m1) + (m * m2).
proof. apply trmx_inj => /=. apply mulmxDl => /#. qed.
(* Associativity of the product, with no size side conditions. The proof
   truncates each summation range with big_range0r: the dropped terms
   involve out-of-range entries, which getm0E shows are zero. *)
lemma mulmxA m1 m2 m3: m1 * (m2 * m3) = m1 * m2 * m3.
proof.
apply eq_matrixP.
split => [| i j bound]; 1: smt(rows_mulmx cols_mulmx).
rewrite !get_mulmx dotpE /= (big_range0r (rows m2)) 1:/#.
- smt(getm0E mulr0 mul0r rows_mulmx cols_mulmx row0E).
rewrite (eq_bigr _ _ (fun k =>
bigi predT (fun (l : int) => m1.[i,k] * (m2.[k,l] * m3.[l,j])) 0 (rows m3))).
- move => k /= _.
rewrite get_mulmx dotpE mulr_sumr /= (big_range0r (rows m3)) // 1:/#.
smt(getm0E mulr0 mul0r).
rewrite exchange_big /= dotpE /=.
rewrite [bigi _ _ _ (max _ _)](big_range0r (rows m3)) // 1:/#.
- smt(getm0E mulr0 mul0r).
apply eq_bigr => k _ /=; rewrite get_mulmx dotpE mulr_suml /=.
rewrite [bigi _ _ _ (max _ _)](big_range0r (rows m2)) 1:/#.
- smt(getm0E mulr0 mul0r).
by apply eq_bigr => l _ /=; rewrite mulrA.
qed.
(* Multiplying by a zero matrix of matching inner dimension gives a zero
   matrix of the corresponding outer dimensions. *)
lemma mulmxm0 n m: m * zerom (cols m) n = zerom (rows m) n.
proof.
rewrite eq_matrixP /= => i j /= bound.
by rewrite get_mulmx /= col_matrixc 1:/# dotpv0.
qed.
(* Left-hand counterpart of mulmxm0, obtained by transposition.
   NOTE(review): named mul0m rather than mul0mx, unlike mulmxm0 and mul1mx;
   kept as is because renaming would break any existing callers. *)
lemma mul0m n m: zerom n (rows m) * m = zerom n (cols m).
proof. apply trmx_inj => /=. apply mulmxm0. qed.
(* onem (cols m) is a right identity for m. *)
lemma mulmx1 m: m * onem (cols m) = m.
proof.
apply eq_matrixP.
rewrite size_mulmx 1:// /= => i j bound.
rewrite get_mulmx dotpE (bigD1 _ _ j) 1:mem_range 1:/# 1:range_uniq.
rewrite big1 /= => [k |]; last by rewrite get_vectc 1:/# mulr1 addr0.
rewrite /predC1 => -> /=.
by rewrite mulr0.
qed.
(* onem (rows m) is a left identity for m, by transposing mulmx1. *)
lemma mul1mx m: onem (rows m) * m = m by apply trmx_inj => /=; apply mulmx1.
(* Turns a vector into a matrix with a single row: rowmx v has 1 row and
   size v columns, and its row 0 is v. *)
op rowmx (v: vector) = offunm ((fun _ i => v.[i]), 1, size v).
lemma rows_rowmx v: rows (rowmx v) = 1 by done.
lemma cols_rowmx v: cols (rowmx v) = size v by done.
hint simplify rows_rowmx, cols_rowmx.
lemma size_rowmx v: size (rowmx v) = (1, size v) by done.
(* Row 0 of rowmx v agrees with v at every column index, in range or not. *)
lemma get_rowmx v i: (rowmx v).[0,i] = v.[i].
proof.
case (0 <= i < size v) => bound.
- rewrite get_offunm //= /#.
- by rewrite getm0E //= getv0E.
qed.
hint simplify get_rowmx.
(* rowmx and row 0 are mutually inverse on one-row matrices. *)
lemma rowK v: row (rowmx v) 0 = v by rewrite eq_vectorP.
hint simplify rowK.
lemma rowmx_row m: rows m = 1 => rowmx (row m 0) = m.
proof.
move => rws; rewrite eq_matrixP /= rws => i j bound.
case (i = 0) => [->//| i_neq0].
rewrite !getm0E /#.
qed.
(* rowmx commutes with vector addition and negation, and maps constant
   vectors to constant one-row matrices. *)
lemma rowmxD (v1 v2: vector): rowmx (v1 + v2) = rowmx v1 + rowmx v2.
proof.
rewrite -(rowmx_row (rowmx (v1 + v2))) 1://.
by rewrite -(rowmx_row (rowmx v1 + rowmx v2)) 1:rows_addm //= rowD.
qed.
lemma rowmxN (v : vector) : - (rowmx v) = rowmx (-v).
proof. rewrite eq_matrixP /= => i j; case (i = 0) => [->// | /#]. qed.
lemma rowmxc n c: rowmx (vectc n c) = matrixc 1 n c.
proof.
apply eq_matrixP => /= i j bound.
case (i = 0) => [->/=|i_neq0 /#].
rewrite get_matrixc /= 1:/# get_vectc /#.
qed.
hint simplify rowmxc.
(* Turns a vector into a matrix with a single column: colmx v has size v
   rows and 1 column, and its column 0 is v. *)
op colmx (v: vector) = offunm ((fun i _ => v.[i]), size v, 1).
lemma rows_colmx v: rows (colmx v) = size v by done.
lemma cols_colmx v: cols (colmx v) = 1 by done.
hint simplify rows_colmx, cols_colmx.
lemma size_colmx v: size (colmx v) = (size v, 1) by done.
(* Column 0 of colmx v agrees with v at every row index, in range or not. *)
lemma get_colmx v i: (colmx v).[i,0] = v.[i].
proof.
case (0 <= i < size v) => bound; first rewrite get_offunm //= /#.
by rewrite getm0E //= getv0E.
qed.
hint simplify get_colmx.
(* colmx and col 0 are mutually inverse on one-column matrices. *)
lemma colK v: col (colmx v) 0 = v by rewrite eq_vectorP.
hint simplify colK.
lemma colmx_col m: cols m = 1 => colmx (col m 0) = m.
proof.
move => cls; rewrite eq_matrixP /= cls => i j bound.
case (j = 0) => [->// | /#].
qed.
(* Transposition exchanges rowmx and colmx. *)
lemma trmx_rowmx v: trmx (rowmx v) = colmx v.
proof.
rewrite eq_matrixP /= => i j bound.
case (j = 0) => [->// | /#].
qed.
hint simplify trmx_rowmx.
lemma trmx_colmx v: trmx (colmx v) = rowmx v.
proof. by apply trmx_inj => /=; rewrite trmxK. qed.
hint simplify trmx_colmx.
(* For a one-column (resp. one-row) matrix, rowmx of its only column
   (resp. colmx of its only row) is its transpose. *)
lemma rowmx_col m: cols m = 1 => rowmx (col m 0) = trmx m.
proof. move => rws; apply trmx_inj => /=. by rewrite trmxK colmx_col. qed.
lemma colmx_row m: rows m = 1 => colmx (row m 0) = trmx m.
proof. move => rws; apply trmx_inj => /=. by rewrite trmxK rowmx_row. qed.
(* colmx commutes with vector addition and negation, and maps constant
   vectors to constant one-column matrices. *)
lemma colmxD (v1 v2: vector):
colmx (v1 + v2) = colmx v1 + colmx v2.
proof. apply trmx_inj => /=. exact rowmxD. qed.
lemma colmxN (v : vector) : - (colmx v) = colmx (-v).
proof. rewrite eq_matrixP /= => i j; case (j = 0) => [->// | /#]. qed.
lemma colmxc n c: colmx (vectc n c) = matrixc n 1 c by apply trmx_inj.
hint simplify colmxc.
(* Matrix-vector multiplication: m *^ v multiplies m by v viewed as a
   one-column matrix and returns column 0 of the result, so the result has
   size rows m. *)
op[opaque] mulmxv m v = col (m * colmx v) 0.
abbrev ( *^ ) (m : matrix) (v : vector) : vector = mulmxv m v.
lemma mulmxvE m v: m *^ v = col (m * colmx v) 0 by rewrite /mulmxv.
lemma size_mulmxv m (v: vector): size (m *^ v) = rows m by rewrite /mulmxv.
(* Entry i of m *^ v is the dot product of row i of m with v. *)
lemma get_mulmxv m v i: (m *^ v).[i] = dotp (row m i) v.
proof. by rewrite mulmxvE /= get_mulmx. qed.
lemma colmx_mulmxv (m : matrix) (v : vector) :
colmx (m *^ v) = m * colmx v.
proof. by rewrite mulmxvE colmx_col. qed.
(* Distributivity over addition in either argument. *)
lemma mulmxvDl (m1 m2 : matrix) (v : vector) :
(m1 + m2) *^ v = m1 *^ v + m2 *^ v.
proof. by rewrite !mulmxvE -colD mulmxDl. qed.
lemma mulmxvDr (m : matrix) (v1 v2 : vector) :
m *^ (v1 + v2) = m *^ v1 + m *^ v2.
proof. by rewrite !mulmxvE -colD colmxD 1:// mulmxDr. qed.
(* Applying m2 then m1 is the same as applying the product m1 * m2. *)
lemma mulmxvA (m1 m2 : matrix) (v : vector) :
m1 *^ (m2 *^ v) = (m1 * m2) *^ v.
proof. by rewrite !mulmxvE colmx_col // mulmxA. qed.
lemma mulmxv0 (m : matrix) : m *^ (zerov (cols m)) = zerov (rows m).
proof. by rewrite mulmxvE colmxc /= mulmxm0 col_matrixc. qed.
(* onem (size v) acts as the identity on v. *)
lemma mulmx1v (v : vector): onem (size v) *^ v = v.
proof.
rewrite -{3}colK mulmxvE.
congr.
have ->: size v = rows (colmx v) by done.
by rewrite mul1mx.
qed.
(* Multiplying the one-column matrix colmx v by the length-1 constant
   vector vectc 1 c scales v by c. *)
lemma mul_colmxc (v:vector) c: (colmx v) *^ vectc 1 c = c ** v.
proof.
rewrite eq_vectorP size_scalarv size_mulmxv /= /max 1:// => i bound.
rewrite get_scalarv get_mulmxv dotpE /= /max /=.
by rewrite Big.BAdd.big_int1 /= get_vectc // mulrC.
qed.
(* A scalar factor on the vector can be pulled out of the product. *)
lemma mulmx_scalarv (m : matrix) (s : t) (v : vector) :
m *^ (s ** v) = s ** (m *^ v).
proof.
apply eq_vectorP; rewrite size_mulmxv size_scalarv size_mulmxv => /= i Hi.
by rewrite get_scalarv !get_mulmxv dotp_scalarv_r.
qed.
(* Vector-matrix multiplication: v ^* m multiplies v viewed as a one-row
   matrix by m and returns row 0 of the result, so the result has size
   cols m. *)
op[opaque] mulvmx v m = row (rowmx v * m) 0.
abbrev ( ^* ) (v : vector) (m : matrix) : vector = mulvmx v m.
lemma mulvmxE v m: v ^* m = row (rowmx v * m) 0 by rewrite /mulvmx.
lemma size_mulvmx (v: vector) m: size (v ^* m) = cols m by rewrite /mulvmx.
(* Entry i of v ^* m is the dot product of v with column i of m. *)
lemma get_mulvmx v m i: (v ^* m).[i] = dotp v (col m i).
proof. by rewrite /mulvmx /= get_mulmx. qed.
(* Vector-matrix and matrix-vector products are related by transposition. *)
lemma mulmxTv (m : matrix) (v : vector) : (trmx m) *^ v = v ^* m.
proof. by rewrite mulvmxE -col_trmx trmxM mulmxvE. qed.
hint simplify mulmxTv.
lemma mulvmxT (v : vector) (m : matrix) : v ^* (trmx m) = m *^ v.
proof. by rewrite mulvmxE-{2}(trmxK m) /#. qed.
hint simplify mulvmxT.
(* Distributivity over addition in either argument. *)
lemma mulvmxDr (v : vector) (m1 m2 : matrix) :
v ^* (m1 + m2) = v ^* m1 + v ^* m2.
proof. rewrite -mulmxTv trmxD mulmxvDl /#. qed.
lemma mulvmxDl (v1 v2 : vector) (m : matrix) :
(v1 + v2) ^* m = v1 ^* m + v2 ^* m.
proof. by rewrite -mulmxTv mulmxvDr. qed.
(* Applying m1 then m2 on the right is the same as applying m1 * m2. *)
lemma mulvmxA (v : vector) (m1 m2 : matrix) :
v ^* (m1 * m2) = (v ^* m1) ^* m2.
proof. rewrite -(trmxK (m1 * m2)) trmxM mulvmxT -mulmxvA /#. qed.
lemma mulv0mx (m : matrix): zerov (rows m) ^* m = zerov (cols m).
proof. by rewrite -mulmxTv -{2}trmxK rows_tr mulmxv0 -cols_tr trmxK. qed.
lemma rowmx_mulvmx (v : vector) (m : matrix) :
rowmx (v ^* m) = rowmx v * m.
proof. by rewrite -trmx_colmx -mulmxTv colmx_mulmxv 1:// trmxM trmxK. qed.
(* onem (size v) acts as the identity on v from the right. *)
lemma mulvmx1 (v : vector) : v ^* onem (size v) = v.
proof. by rewrite -mulmxTv trmx1 mulmx1v. qed.
(* dotp v1 v2 is the single entry of the product rowmx v1 * colmx v2. *)
lemma dotp_eqv_mul v1 v2 : dotp v1 v2 = (rowmx v1 * colmx v2).[0,0].
proof. by rewrite get_mulmx. qed.
(* A matrix can be moved from the right argument of dotp to the left one. *)
lemma dotp_mulmxv m (v1 v2: vector): dotp v1 (m *^ v2) = dotp (v1 ^* m) v2.
proof.
by rewrite 2!dotp_eqv_mul mulmxvE colmx_col 1:// mulvmxE rowmx_row 1:// mulmxA.
qed.
(* Multiplying the length-1 constant vector vectc 1 c by rowmx v scales v
   by c; this is the transposed form of mul_colmxc. *)
lemma mul_rowmxc (v:vector) c: vectc 1 c ^* (rowmx v) = c ** v.
proof. by rewrite -(trmxK (rowmx v)) mulvmxT /= mul_colmxc. qed.
(* Sideways matrix concatenation - aka row block matrices.
   m1 || m2 has max (rows m1) (rows m2) rows and cols m1 + cols m2 columns.
   Entry (i, j) is m1.[i, j] + m2.[i, j - cols m1], so m1 occupies the left
   column block and m2, shifted by cols m1, the right one. The lemmas below
   use getm0E to discard the summand whose indices are out of range. *)
op catmr (m1 m2: matrix) =
offunm ((fun i j => m1.[i, j] + m2.[i, j-cols m1]),
max (rows m1) (rows m2), cols m1 + cols m2).
abbrev ( || ) m1 m2 = catmr m1 m2.
lemma rows_catmr (m1 m2: matrix): rows (m1 || m2) = max (rows m1) (rows m2).
proof. rewrite rows_offunm /#. qed.
lemma cols_catmr (m1 m2: matrix): cols (m1 || m2) = cols m1 + cols m2.
proof. rewrite cols_offunm /#. qed.
lemma size_catmr (m1 m2: matrix):
size (m1 || m2) = (max (rows m1) (rows m2), cols m1 + cols m2).
proof. rewrite rows_offunm cols_offunm /#. qed.
(* Entry formula, valid for all indices i and j. *)
lemma get_catmr (m1 m2: matrix) i j:
(m1 || m2).[i, j] = m1.[i, j] + m2.[i, j-cols m1].
proof.
rewrite /catmr /=.
case (mrange (m1 || m2) i j) => range.
- rewrite get_offunm //.
- rewrite !getm0E /=; first 3 smt(size_catmr).
by rewrite addr0.
qed.
(* When m1 and m2 have the same number of rows, column i of m1 || m2 is
   column i of m1 for i < cols m1 ... *)
lemma col_catmrL m1 m2 i: rows m1 = rows m2 => i < cols m1 =>
col (m1 || m2) i = col m1 i.
proof.
move => rows_eq bound1; rewrite eq_vectorP /=.
split => [| j bound2]; 1: by rewrite rows_catmr /#.
rewrite get_catmr // (getm0E m2) 2:addr0; smt(rows_catmr).
qed.
(* ... and column i - cols m1 of m2 otherwise. *)
lemma col_catmrR m1 m2 i: rows m1 = rows m2 => cols m1 <= i =>
col (m1 || m2) i = col m2 (i - cols m1).
proof.
move => rows_eq bound; rewrite eq_vectorP /=.
split => [| j bound2]; 1: by rewrite rows_catmr /#.
rewrite get_catmr // (getm0E m1) 2:add0r; smt(rows_catmr).
qed.
(* Each row of m1 || m2 is the vector concatenation of the corresponding
   rows of m1 and m2 (vector || is characterized by size_catv/get_catv). *)
lemma row_catmr m1 m2 i: row (m1 || m2) i = (row m1 i || row m2 i).
proof.
rewrite eq_vectorP /=.
rewrite cols_catmr // size_catv /= => j bound.
by rewrite get_catmr // get_catv'.
qed.
(* Sideways concatenation is associative. *)
lemma catmrA (m1 m2 m3: matrix): ((m1 || m2) || m3) = (m1 || (m2 || m3)).
proof.
rewrite eq_matrixP.
split => [| i j bound]; 1: smt(size_catmr).
rewrite 4!get_catmr cols_catmr // addrA.
algebra.
qed.
(* Left multiplication distributes over sideways concatenation, with no size
   side conditions. For each output column the proof keeps the block it
   falls in, and splits the dot product at max (cols m1) (rows m2) (resp.
   rows m3) to show that the remaining part is zero. *)
lemma catmrDr (m1 m2 m3: matrix): m1 * (m2 || m3) = ((m1 * m2) || (m1 * m3)).
proof.
rewrite eq_matrixP.
rewrite rows_mulmx cols_mulmx cols_catmr.
split => [| i j bound]; 1: smt(size_mulmx size_catmr).
rewrite get_catmr 3!get_mulmx.
case (j < cols m2) => bound2.
(* Column j lies in the m2 block, so the selected column of m3 is zero. *)
- rewrite [col m3 _]col0E 1:/# dotpv0 addr0 !dotpE 2!size_col rows_catmr.
rewrite (big_cat_int (max (cols m1) (rows m2))) 1:/# size_row 1:/#.
rewrite [bigi _ _ (max _ _) _]big1_seq => [k [_ /mem_range bound3] /=|].
+ by rewrite get_catmr [_ m2 _]getm0E 1:/# add0r [_ m3 _]getm0E 1:/# mulr0.
rewrite addr0.
apply eq_bigr => k _ /=.
by rewrite get_catmr [_ m3 _]getm0E 1:/# addr0.
(* Column j lies in the m3 block, so the selected column of m2 is zero. *)
- rewrite [col m2 _]col0E 1:/# dotpv0 add0r !dotpE 2!size_col rows_catmr.
rewrite (big_cat_int (max (cols m1) (rows m3))) 1:/# size_row 1:/#.
rewrite [bigi _ _ (max _ _) _]big1_seq => [k [_ /mem_range bound3] /=|].
+ by rewrite get_catmr [_ m3 _]getm0E 1:/# addr0 [_ m2 _]getm0E 1:/# mulr0.
rewrite addr0.
apply eq_bigr => k _ /=.
by rewrite get_catmr [_ m2 _]getm0E 1:/# add0r.
qed.
(* rowmx turns vector concatenation into sideways matrix concatenation. *)
lemma rowmx_catmr (v1 v2: vector): rowmx (v1 || v2) = (rowmx v1 || rowmx v2).
proof.
apply eq_matrixP.
rewrite size_catmr 1:rows_rowmx 1://.
rewrite !rows_rowmx !cols_rowmx size_catv /= => i j bound.
rewrite get_catmr //.
have ->: (i = 0) by smt().
by rewrite !get_rowmx get_catv' cols_rowmx.
qed.
(* Blockwise addition, provided the left blocks have the same width. *)
lemma addm_catmr (m1 m2 m3 m4: matrix): cols m1 = cols m3 =>
(m1 || m2) + (m3 || m4) = ((m1 + m3) || (m2 + m4)).
proof.
move => cols_eq.
rewrite eq_matrixP. rewrite size_catmr 3!rows_addm 3!cols_addm 2!rows_catmr.
rewrite 2!cols_catmr.
split => [/# | i j bound].
rewrite get_addm 3!get_catmr 2!get_addm cols_addm /max cols_eq /=.
smt(addrA addrC).
qed.
(* Negation acts blockwise. *)
lemma oppm_catmr (m1 m2: matrix): - (m1 || m2) = (-m1 || -m2).
proof.
apply eq_matrixP => /=.
split => [// | i j bound].
by rewrite 2!get_catmr // /= opprD.
qed.
(* Block matrix-vector product: if cols m1 = size v1, multiplying m1 || m2
   by v1 || v2 gives m1 *^ v1 + m2 *^ v2. *)
lemma mulmxv_cat (m1 m2 : matrix) (v1 v2 : vector): cols m1 = size v1 =>
(m1 || m2) *^ (v1 || v2) = (m1 *^ v1) + (m2 *^ v2).
proof.
move => cols_eq_size; apply eq_vectorP.
rewrite size_addv 3!size_mulmxv rows_catmr /= => i bound.
by rewrite get_addv 3!get_mulmxv row_catmr dotp_catv.
qed.
(* Downwards matrix concatenation - aka column block matrices.
   m1 / m2 has rows m1 + rows m2 rows and max (cols m1) (cols m2) columns.
   Entry (i, j) is m1.[i, j] + m2.[i - rows m1, j], so m1 occupies the top
   row block and m2, shifted by rows m1, the bottom one. Most lemmas below
   are derived from their sideways counterparts by transposition, using
   catmcT and catmrT. *)
op catmc (m1 m2: matrix) =
offunm ((fun i j => m1.[i, j] + m2.[i-rows m1, j]),
rows m1 + rows m2, max (cols m1) (cols m2)).
abbrev ( / ) m1 m2 = catmc m1 m2.
lemma cols_catmc (m1 m2: matrix): cols (m1 / m2) = max (cols m1) (cols m2).
proof. rewrite cols_offunm /#. qed.
lemma rows_catmc (m1 m2: matrix): rows (m1 / m2) = rows m1 + rows m2.
proof. rewrite rows_offunm /#. qed.
lemma size_catmc (m1 m2: matrix):
size (m1 / m2) = (rows m1 + rows m2, max (cols m1) (cols m2)).
proof. rewrite cols_offunm rows_offunm /#. qed.
(* Entry formula, valid for all indices i and j. *)
lemma get_catmc (m1 m2: matrix) i j:
(m1 / m2).[i, j] = m1.[i, j] + m2.[i-rows m1, j].
proof.
case (mrange (m1 / m2) i j) => range.
- rewrite get_offunm /=; smt(size_catmc).
- rewrite !getm0E /= 4:addr0; smt(size_catmc).
qed.
(* Transposition exchanges downward and sideways concatenation. *)
lemma catmcT (m1 m2: matrix): trmx (m1 / m2) = (trmx m1 || trmx m2).
proof.
rewrite eq_matrixP /= cols_catmc rows_catmc rows_catmr cols_catmr /= => i j bnd.
by rewrite get_catmr get_catmc.
qed.
hint simplify catmcT.
lemma catmrT (m1 m2: matrix): trmx (m1 || m2) = trmx m1 / trmx m2.
proof. apply trmx_inj. by rewrite /= 3!trmxK. qed.
hint simplify catmrT.
(* With equal column counts, row i of m1 / m2 comes from m1 if i < rows m1
   and from m2, shifted by rows m1, otherwise. *)
lemma row_catmcL m1 m2 i:
cols m1 = cols m2 => i < rows m1 => row (m1 / m2) i = row m1 i.
proof. move => cols_eq bound; by rewrite -col_trmx /= col_catmrL. qed.
lemma row_catmcR m1 m2 i: cols m1 = cols m2 => rows m1 <= i =>
row (m1 / m2) i = row m2 (i - rows m1).
proof. move => cols_eq bound; by rewrite -col_trmx /= col_catmrR. qed.
(* Each column of m1 / m2 is the vector concatenation of the corresponding
   columns of m1 and m2. *)
lemma col_catmc m1 m2 i:
col (m1 / m2) i = (col m1 i || col m2 i).
proof. by rewrite -row_trmx /= row_catmr. qed.
(* Downward concatenation is associative. *)
lemma catmcA (m1 m2 m3: matrix):
((m1 / m2) / m3) = (m1 / (m2 / m3)).
proof. apply trmx_inj => /=. exact catmrA. qed.
(* Right multiplication distributes over downward concatenation. *)
lemma catmcDl m1 m2 m3:
(m1 / m2) * m3 = (m1 * m3) / (m2 * m3).
proof. apply trmx_inj => /=. exact catmrDr. qed.
(* colmx turns vector concatenation into downward matrix concatenation. *)
lemma colmx_catmr (v1 v2: vector): colmx (v1 || v2) = (colmx v1 / colmx v2).
proof. apply trmx_inj => /=. exact rowmx_catmr. qed.
(* Blockwise addition, provided the top blocks have the same height.
   NOTE(review): the hypothesis rows m1 = rows m3 is introduced under the
   name cols_eq; this is only a local name in the proof. *)
lemma addm_catmc (m1 m2 m3 m4: matrix): rows m1 = rows m3 =>
(m1 / m2) + (m3 / m4) = ((m1 + m3) / (m2 + m4)).
proof.
move => cols_eq; apply trmx_inj => /=.
rewrite addm_catmr /#.
qed.
(* Negation acts blockwise. *)
lemma oppm_catmc (m1 m2: matrix): - (m1 / m2) = ((-m1) / -m2).
proof. apply trmx_inj => /=. by apply oppm_catmr. qed.
(* Takes the submatrix of m from row r1 (inclusive) to r2 (exclusive) and
   from column c1 (inclusive) to c2 (exclusive). Entry (i, j) of the result
   is m.[i + r1, j + c1], and its dimensions are clamped at zero when
   r2 < r1 or c2 < c1. For example, m = subm m 0 (rows m) 0 (cols m)
   (lemma subm_id). *)
op subm (m: matrix) (r1 r2 c1 c2: int) =
offunm ((fun i j => m.[i+r1,j+c1]), r2-r1, c2-c1).
lemma rows_subm m r1 r2 c1 c2:
rows (subm m r1 r2 c1 c2) = max 0 (r2 - r1) by done.
lemma cols_subm m r1 r2 c1 c2:
cols (subm m r1 r2 c1 c2) = max 0 (c2 - c1) by done.
hint simplify rows_subm, cols_subm.
lemma size_subm m r1 r2 c1 c2:
size (subm m r1 r2 c1 c2) = (max 0 (r2 - r1), max 0 (c2 - c1)) by done.
(* Entry formula for indices inside the requested ranges. *)
lemma get_subm (m: matrix) (r1 r2 c1 c2 i j: int):
0 <= i < r2 - r1 => 0 <= j < c2 -c1 =>
(subm m r1 r2 c1 c2).[i,j] = m.[i+r1,j+c1].
proof. by move => i_range j_range; rewrite /subm get_offunm /#. qed.
lemma subm_id m: subm m 0 (rows m) 0 (cols m) = m.
proof. by rewrite eq_matrixP /= => i j bound; rewrite get_subm /#. qed.
(* Transposing a submatrix swaps the row and column ranges. *)
lemma submT m r1 r2 c1 c2: trmx (subm m r1 r2 c1 c2) = subm (trmx m) c1 c2 r1 r2.
proof.
apply eq_matrixP => i j bound /=.
rewrite !get_subm; smt(size_tr size_subm).
qed.
hint simplify submT.
(* Recovering the blocks of a concatenation: the left block of m1 || m2 ... *)
lemma subm_catmrCl m1 m2: subm (m1 || m2) 0 (rows m1) 0 (cols m1) = m1.
proof.
rewrite eq_matrixP => i j bound.
rewrite get_subm; first 2 smt(size_subm).
rewrite get_catmr // (getm0E m2) /= 2:addr0; smt(cols_subm).
qed.
(* ... the top block of m1 / m2 ... *)
lemma subm_catmcCl m1 m2: subm (m1 / m2) 0 (rows m1) 0 (cols m1) = m1.
proof. apply trmx_inj => /=. exact subm_catmrCl. qed.
(* ... the right block of m1 || m2 ... *)
lemma subm_catmrCr m1 m2:
subm (m1 || m2) 0 (rows m2) (cols m1) (cols m1 + cols m2) = m2.
proof.
rewrite eq_matrixP.
split => [/# | i j bound].
rewrite get_subm; first 2 smt(size_subm).
rewrite get_catmr //= (getm0E m1) /= 2:add0r; smt(cols_subm).
qed.
(* ... and the bottom block of m1 / m2. *)
lemma subm_catmcCr m1 m2:
subm (m1 / m2) (rows m1) (rows m1 + rows m2) 0 (cols m2) = m2.
proof. apply trmx_inj => /=. exact subm_catmrCr. qed.
(* A single row (resp. column) of m, as a matrix, is the one-row
   (resp. one-column) submatrix at that position. *)
lemma rowmx_row_eq_subm r m: rowmx (row m r) = subm m r (r+1) 0 (cols m).
proof.
rewrite eq_matrixP /=.
split => [/# | i j bound].
have ->: i = 0 by smt().
rewrite get_rowmx get_subm 3:get_row /#.
qed.
lemma colmx_col_eq_subm c m: colmx (col m c) = subm m 0 (rows m) c (c+1).
proof. apply trmx_inj => /=. rewrite -row_trmx. apply rowmx_row_eq_subm. qed.
(* Splitting m at column n (0 <= n < cols m) into a left and a right
   submatrix and concatenating them sideways gives m back. *)
lemma catmr_subm m n: 0 <= n < cols m =>
(subm m 0 (rows m) 0 n || subm m 0 (rows m) n (cols m)) = m.
proof.
move => n_bound; rewrite eq_matrixP /=.
split => [| i j bound].
- smt(rows_catmr cols_catmr size_subm).
rewrite get_catmr // cols_subm /=.
case (j < n) => j_bound.
- rewrite get_subm /=; first 2 smt(size_catmr size_subm).
rewrite (getm0E (subm _ _ _ _ _)).
+ smt(size_catmr size_subm).
by rewrite addr0.
- rewrite getm0E; 1: smt(size_catmr size_subm).
rewrite add0r get_subm; smt(size_catmr size_subm).
qed.
(* Special case of catmr_subm: split off the last column when
   cols m = l + 1. *)
lemma subm_colmx (m: matrix) l :
0 <= l => cols m = l + 1 =>
(subm m 0 (rows m) 0 l || colmx (col m l)) = m.
proof. by move => l_ge0 c_m; rewrite colmx_col_eq_subm -c_m catmr_subm /#. qed.
(* Splitting m at row n (0 <= n < rows m) and stacking the two parts
   downwards gives m back. *)
lemma catmc_subm m n: 0 <= n < rows m =>
subm m 0 n 0 (cols m) / subm m n (rows m) 0 (cols m) = m.
proof. move => n_bound; apply trmx_inj => /=. by apply catmr_subm. qed.
(* Updating one entry of a matrix: updm m r c p has the same size as m and
   holds p at position (r, c) and m.[i, j] at every other position (i, j). *)
op updm (m: matrix) (r c: int) (p: t) = offunm
(fun (i j: int) => if i = r /\ j = c then p else m.[i,j], rows m, cols m).
lemma rows_updm m r c p: rows (updm m r c p) = rows m by done.
lemma cols_updm m r c p: cols (updm m r c p) = cols m by done.
lemma size_updm m r c p: size (updm m r c p) = size m by done.
(* Entry formula, stated for an in-range update position (r, c) and an
   in-range query position (i, j). *)
lemma get_updm (m: matrix) (r c i j: int) (p: t): mrange m r c => mrange m i j =>
(updm m r c p).[i, j] = if (i, j) = (r, c) then p else m.[i, j].
proof.
rewrite /updm.
case (i = r) => /= [-> | neq_r bound1 bound2].
- case (j = c) => /= [-> bound _ | neq_col bound1 bound2].
+ by rewrite get_offunm //= get_offunm //= addrC subrK.
+ rewrite get_offunm /#.
- by rewrite get_offunm //= offunm0E.
qed.
(* Multiplication of a matrix by a scalar: s *** m is defined as the product
   diagc (rows m) s * m. By get_scalarm, it multiplies every entry of m by s
   (on the right). *)
op scalarm (s: t) (m: matrix) = diagc (rows m) s * m.
abbrev ( *** ) s m = scalarm s m.
lemma scalarmE s m: s *** m = diagc (rows m) s * m by done.
(* Scaling preserves the size of the matrix. *)
lemma rows_scalarm (m: matrix) (s: t): rows (s *** m) = rows m.
proof. by rewrite rows_mulmx. qed.
lemma cols_scalarm (m: matrix) (s: t): cols (s *** m) = cols m.
proof. by rewrite cols_mulmx. qed.
lemma scalarm_mrange m s i j: mrange (s *** m) i j = mrange m i j.
proof. by rewrite rows_mulmx // cols_mulmx. qed.
lemma size_scalarm (m: matrix) (s: t): size (s *** m) = size m.
proof. by rewrite rows_scalarm cols_scalarm. qed.
(* Entry formula, valid for all indices. In range, the dot product collapses
   to its single diagonal term (bigD1); out of range, both sides are zero. *)
lemma get_scalarm (m: matrix) (s: t) (i j: int):
(s *** m).[i, j] = m.[i, j] * s.
proof.
rewrite get_mulmx /dotp 2!size_offunv /=.
case (0 <= i < rows m) => bound.
- rewrite (bigD1 _ _ i) 1:mem_range 1:/# 1:range_uniq.
rewrite big1 /= 2:get_vectc 2:/# 2:addr0 2:mulrC 2:// => k.
rewrite /predC1 eq_sym => -> /=.
apply mul0r.
- rewrite big1 => [k _ /=|].
+ by rewrite getv0E // mul0r.
by rewrite getm0E 1:/# mul0r.
qed.
(* Scaling by zeror gives the zero matrix of the same size. *)
lemma scalar0m (m: matrix): zeror *** m = zerom (rows m) (cols m).
proof.
rewrite eq_matrixP.
split => [| i j bound]; 1: rewrite size_scalarm /#.
by rewrite get_zerom get_scalarm mulr0.
qed.
(* Scaling by oner is the identity. *)
lemma scalar1m (m: matrix): oner *** m = m.
proof.
rewrite eq_matrixP.
split => [| i j bound]; 1: by rewrite size_scalarm.
by rewrite get_scalarm mulr1.
qed.
(* Scaling by a negated scalar negates the scaled matrix. *)
lemma scalarNm (m: matrix) (s: t): (- s) *** m = - (s *** m).
proof.
rewrite eq_matrixP.
split => [| i j bound]; 1: rewrite !size_scalarm /=; 1: smt(size_scalarm).
by rewrite /= !get_scalarm /= mulrN.
qed.
(* Successive scalings combine into a scaling by the product of the scalars. *)
lemma scalarmAs s1 s2 m: s1 *** (s2 *** m) = (s1 * s2) *** m.
proof.
apply eq_matrixP.
split => [| i j _]; 1: by rewrite !size_scalarm.
by rewrite 3!get_scalarm -mulrA [s2 * s1]mulrC.
qed.
(* Scaling distributes over matrix addition ... *)
lemma scalarmD m1 m2 s: s *** (m1 + m2) = s *** m1 + s *** m2.
proof.
apply eq_matrixP.
rewrite rows_scalarm rows_addm cols_scalarm rows_addm.
rewrite 2!cols_addm /= => i j bound.
by rewrite get_addm get_scalarm get_addm !get_scalarm mulrDl.
qed.
(* ... and over addition of scalars. *)
lemma scalarDm m (s1 s2: t): (s1 + s2) *** m = s1 *** m + s2 *** m.
proof.
apply eq_matrixP.
rewrite rows_scalarm cols_scalarm.
split => [| i j bound]; 1: rewrite size_addm !size_scalarm //.
by rewrite get_addm 3!get_scalarm // mulrDr.
qed.
(* dmatrix helper function: constructs an r x c matrix from a list of
   column vectors. Entry (i, j) is entry i of the j-th vector of vs, with
   witness used when j is past the end of the list. *)
op ofcols r c (vs : vector list) =
offunm (fun (i j : int) => (nth witness vs j).[i], r, c).
(* If the list lengths match the column counts, concatenating the lists of
   columns concatenates the resulting matrices sideways. *)
lemma ofcols_cat k l1 l2 (vs1 vs2 : vector list) :
size vs1 = l1 => size vs2 = l2 =>
ofcols k (l1 + l2) (vs1 ++ vs2) = (ofcols k l1 vs1 || ofcols k l2 vs2).
proof.
move => s_vs1 s_vs2.
apply eq_matrixP; rewrite /ofcols size_catmr //=.
rewrite -s_vs1 -s_vs2.
split => [|i j [i_bound j_bound]]; 1:smt(List.size_ge0).
rewrite get_catmr //= ler_maxr; 1: smt(List.size_ge0).
rewrite get_offunm /= ?nth_cat; 1: smt(List.size_ge0).
case (j < size vs1) => j_vs1.
- rewrite get_offunm. smt(List.size_ge0). rewrite offunm0E. smt(List.size_ge0).
by rewrite addr0.
- rewrite offunm0E. smt(List.size_ge0). rewrite get_offunm. smt(List.size_ge0).
by rewrite /= s_vs1 add0r.
qed.
(* Distribution of matrices sampled using d element-wise: dmatrix d r c
   draws a list of max 0 c column vectors with dlist (dvector d r) and
   assembles them into a matrix with ofcols r c. *)
op dmatrix (d : R distr) (r c: int) =
dmap (dlist (dvector d r) (max 0 c)) (ofcols r c).
(* For non-negative r and c, every matrix in the support has size (r, c). *)
lemma size_dmatrix d r c m: m \in dmatrix d r c =>
0 <= r => 0 <= c => size m = (r, c).
proof. rewrite supp_dmap /ofcols => -[l [H0 -> /=]] /#. qed.
(* Consequently, a matrix of any other size has probability 0. *)
lemma dmatrix0E (d: R distr) (m: matrix) r c:
size m <> (r, c) => 0 <= r => 0 <= c => mu1 (dmatrix d r c) m = 0%r.
proof.
move => size_ineq r_ge0 c_ge0; apply supportPn.
case (m \in dmatrix d r c) => [cont | //].
apply size_ineq.
apply size_dmatrix in cont => /#.
qed.
(* Probability of m when sampled at its own size: the product, over all
   rows i and columns j, of mu1 d m.[i, j]. The proof uses in_dmap1E_can
   with g, which lists the columns of a matrix, as an inverse of ofcols. *)
lemma dmatrix1E d m : mu1 (dmatrix d (rows m) (cols m)) m =
BRM.bigi predT (fun i =>
BRM.bigi predT (fun j => mu1 d m.[i, j]) 0 (cols m)) 0 (rows m).
proof.
pose g (m: matrix) := mkseq (fun i => col m i) (cols m).
rewrite (in_dmap1E_can _ _ g) /ofcols.
- rewrite /g eq_matrixP /= => i j bound.
by rewrite get_offunm /= 1:/# nth_mkseq 1:/# get_col.
- rewrite /g => y y_in <- /=.
rewrite supp_dlist 1:/# in y_in.
apply (eq_from_nth witness) => [| i i_bound]; 1: rewrite size_mkseq /#.
rewrite nth_mkseq /= 1:/# eq_vectorP size_col rows_offunm.
rewrite -(all_nthP _ _ witness) in y_in.
elim y_in => size_y.
move => cont; have i_bound': 0 <= i && i < size y by done.
apply cont in i_bound'; clear cont; move: i_bound'.
move => nth_y_size; apply size_dvector in nth_y_size.
split => [/# | j j_bound].
rewrite get_col get_offunm /= /#.
- rewrite dlist1E 1:/# size_mkseq /= (BRM.big_nth witness) predTofV.
rewrite BRM.exchange_big /(\o) /= 2!BRM.big_seq /= size_mkseq max0cols.
apply BRM.eq_big => i // i_bound /=.
rewrite mem_range in i_bound.
have ->: rows m = size (nth witness (g m) i) by rewrite nth_mkseq //.
rewrite dvector1E.
congr; apply fun_ext => j.
by rewrite nth_mkseq.
qed.
(* Uniformity of d carries over to dmatrix d r c: the proof shows that
   ofcols is injective on the support of the list of sampled columns. *)
lemma dmatrix_uni d r c: is_uniform d => is_uniform (dmatrix d r c).
proof.
move=> uni_d; apply/dmap_uni_in_inj/dlist_uni/dvector_uni => //.
move=> xs ys xsin ysin /= eq_offunm.
apply/(eq_from_nth witness); 1: smt(supp_dlist).
move => i i_bound; rewrite eq_vectorP.
rewrite supp_dlist 1:/# -(all_nthP _ _ witness) in xsin.
rewrite supp_dlist 1:/# -(all_nthP _ _ witness) in ysin.
move: xsin ysin => [size_xs in_size_xs] [size_ys in_size_ys].
have ->: size (nth witness xs i) = max 0 r.
- by apply/(size_dvector d)/in_size_xs.
have -> /=: size (nth witness ys i) = max 0 r.
- by apply/(size_dvector d)/in_size_ys => /#.
move => j j_bound; apply (congr1 (fun m => m.[j, i])) in eq_offunm.
rewrite /= get_offunm /= 1:/# in eq_offunm.
rewrite get_offunm /= /# in eq_offunm.
qed.
(* Losslessness of d carries over to dmatrix d r c. *)
lemma dmatrix_ll d r c: is_lossless d => is_lossless (dmatrix d r c).
proof. by move=> ll_d; apply/dmap_ll/dlist_ll/dvector_ll. qed.
(* For funiform d, every entry has probability mu1 d witness, so a matrix
   sampled at its own size has probability
   (mu1 d witness)^(rows m * cols m). Both products of dmatrix1E are
   computed by induction on their (non-negative) lengths. *)
lemma mu1_dmatrix_fu (d: R distr) (m: matrix): is_funiform d =>
mu1 (dmatrix d (rows m) (cols m)) m = (mu1 d witness)^((rows m)*(cols m)).
proof.
move => d_funi.
rewrite dmatrix1E.
have ->: (fun (i : int) => (BRM.bigi predT
(fun (j : int) => mu1 d m.[i, j]) 0 (cols m))) =
(fun (_: int) => (mu1 d witness)^(cols m)).
- apply fun_ext => i.
have ->: (fun (j : int) => mu1 d m.[i, j]) = fun (_: int) => mu1 d witness.
+ apply fun_ext => j.
exact d_funi.
have: 0 <= cols m by exact cols_ge0.
move: (cols m).
elim/ge0ind => [/# | _ | n bound IH _].
+ by rewrite range_geq //= BRM.big_nil RField.expr0.
+ by rewrite BRM.big_int_recr //= RField.exprS // RField.mulrC IH.
- have: 0 <= rows m by exact rows_ge0.
move: (rows m).
elim/ge0ind => [/# | _ | n bound IH _].
+ by rewrite range_geq //= BRM.big_nil RField.expr0.
+ rewrite BRM.big_int_recr // RField.exprM RField.exprS // RField.mulrC.
by rewrite RField.exprMn 1:/# /= IH // RField.exprM.
qed.
(* With a single column, dmatrix is the image of dvector d k under colmx. *)
lemma dmatrix1r d k : 0 <= k =>
dmatrix d k 1 = dmap (dvector d k) colmx.
proof.
move => k_ge0; rewrite /dmatrix /ofcols ler_maxr // dlist1 dmap_comp /(\o).
apply eq_dmap_in => v /size_dvector; rewrite ler_maxr //= => s_v.
rewrite /colmx s_v; apply/eq_matrixP => /= i j [i_bound j_bound].
by rewrite /ofcols !get_offunm /#.
qed.
(* Support characterization: m is in the support of dmatrix d r c iff it
   has size (r, c) and each of its entries is in the support of d. Both
   directions go through the product formula dmatrix1E. *)
lemma supp_dmatrix d m r c :
0 <= r => 0 <= c =>
(m \in dmatrix d r c) <=>
size m = (r,c) /\ forall i j, mrange m i j => m.[i,j] \in d.
proof.
move => r_ge0 c_ge0; split => [m_supp|]; last first.
- case => -[r_m c_m] m_d; rewrite /support -r_m -c_m dmatrix1E.
apply prodr_gt0_seq => i i_row _ /=.
by apply prodr_gt0_seq => j j_col _ /=; apply m_d; smt(mem_iota).
have [r_m c_m] : size m = (r,c) by smt(size_dmatrix).
split => [//|i j range_ij]; move: m_supp.
rewrite -r_m -c_m /support dmatrix1E => gt0_big.
pose G i0 := (fun (j0 : int) => mu1 d m.[i0, j0]).
pose F := fun i0 : int => BRM.bigi predT (G i0) 0 (cols m).
have /(_ i _ _) := gt0_prodr_seq predT F (range 0 (rows m)) _ gt0_big => //.
- by move => a _ _; apply prodr_ge0.
- smt(mem_iota).
move => gt0_F.
have /(_ j _ _) := gt0_prodr_seq predT (G i) (range 0 (cols m)) _ gt0_F => //.
smt(mem_iota).
qed.
(* Sampling a matrix with l1 + l2 columns is the same as sampling an
   l1-column and an l2-column matrix from the product distribution and
   concatenating them sideways. The proof splits the column list with
   dlist_add and reassembles the two halves with ofcols_cat. *)
lemma dmatrix_add_r d k l1 l2 : 0 <= k => 0 <= l1 => 0 <= l2 =>
dmatrix d k (l1 + l2) =
dmap (dmatrix d k l1 `*` dmatrix d k l2)
(fun mv : matrix * matrix => mv.`1 || mv.`2).
proof.
move => k_ge0 l1_ge0 l2_ge0; rewrite {1}/dmatrix ler_maxr 1:/# dlist_add //.
pose catp (p : vector list * vector list) := p.`1 ++ p.`2.
pose catm (m : matrix * matrix) := m.`1 || m.`2.
pose stichp (p : vector list * vector list) :=
(ofcols k l1 p.`1,ofcols k l2 p.`2).
rewrite dmap_comp -/catp -/catm (eq_dmap_in _ _ (catm \o stichp)).
- case => vs1 vs2 /supp_dprod /= ?.
have [? ?] : size vs1 = l1 /\ size vs2 = l2 by smt(supp_dlist).
exact ofcols_cat.
rewrite -dmap_comp -(dmap_dprod _ _ (ofcols k l1) (ofcols k l2)).
have {1}-> : l1 = max 0 l1 by smt().
have {1}-> : l2 = max 0 l2 by smt().
done.
qed.
(* Special case l2 = 1: the extra last column is sampled as a vector from
   dvector d k and turned into a matrix with colmx. *)
lemma dmatrixSrr d k l : 0 <= k => 0 <= l =>
dmatrix d k (l + 1) =
dmap (dmatrix d k l `*` dvector d k)
(fun mv : matrix * vector => mv.`1 || colmx mv.`2).
proof.
move => k_ge0 l_ge0; rewrite dmatrix_add_r // dmatrix1r //.
by rewrite dmap_dprodR dmap_comp.
qed.
(* Pointwise version: for m of size (k, l + 1), its probability equals the
   probability of the pair (first l columns of m, last column of m) under
   the product distribution. *)
lemma dmatrixRSr1E d (m : matrix) k l :
0 <= k => 0 <= l => size m = (k,l+1) =>
mu1 (dmatrix d k (l + 1)) m =
mu1 (dmatrix d k l `*` dvector d k) (subm m 0 k 0 l, col m l).
proof.
move => k_ge0 l_ge0 [r_m c_m]; rewrite dmatrixSrr // dmapE /(\o) /pred1.
apply mu_eq_support => -[m2 v] /supp_dprod /=.
case => /size_dmatrix /(_ _ _) //= [r_m2 c_m2] /size_dvector s_v.
apply/eq_iff; split => [<-|[-> ->]]; last by rewrite -r_m subm_colmx.
rewrite -r_m2 -c_m2 subm_catmrCl /=.
by rewrite col_catmrR //= /#.
qed.
(* For a full d, the support of dmatrix d r c is exactly the set of
   matrices of size (r, c). *)
lemma supp_dmatrix_full m d r c :
0 <= r => 0 <= c =>
is_full d => m \in dmatrix d r c <=> size m = (r,c).
proof. smt(supp_dmatrix). qed.
(* For funiform d, two vectors of equal size get the same probability
   under dvector d l, for any l. When l differs from their size, both
   probabilities are handled by get_dvector0E. *)
lemma dvector_rnd_funi (d : R distr) (v1 v2 : vector) l :
is_funiform d => size v1 = size v2 =>
mu1 (dvector d l) v1 = mu1 (dvector d l) v2.
proof.
move=> d_funi s_v1_v2; case (l = size v1) => [->|?].
- by rewrite mu1_dvector_fu // s_v1_v2 mu1_dvector_fu.
smt(get_dvector0E emptyv_unique).
qed.
(* Sideways concatenation of matrices from the supports of dmatrix d r c1
   and dmatrix d r c2 lies in the support of dmatrix d r (c1 + c2). *)
lemma supp_dmatrix_catmr d (m1 m2 : matrix) r c1 c2 :
0 <= r => 0 <= c1 => 0 <= c2 =>
m1 \in dmatrix d r c1 => m2 \in dmatrix d r c2 =>
(m1 || m2) \in dmatrix d r (c1 + c2).
proof.
move => r_ge0 c1_ge0 c2_ge0 m1_dm m2_dm; rewrite dmatrix_add_r //.
rewrite (dmap_dprod_comp _ _ (fun x => x) (fun x => x) catmr).
rewrite /support !dmap_id dmapE /(\o) /pred1.
apply (StdOrder.RealOrder.ltr_le_trans
(mu1 (dmatrix d r c1 `*` dmatrix d r c2) (m1,m2))).
- rewrite witness_support.
exists (m1, m2).
by rewrite supp_dprod /= m1_dm m2_dm.
- by apply mu_le => /#.
qed.
end Matrices.
export Matrices.