(* Source: https://github.com/EasyCrypt/easycrypt (raw file view)             *)
(* Tip revision: 30bfa950afa3806948c073d3c9ec4468d33ea940 (30bfa95),          *)
(*   authored by Pierre-Yves Strub on 11 December 2023, 10:58:49 UTC          *)
(* Tip commit message: New tactic: "proc change"                              *)
(* File: DynMatrix.eca                                                        *)
(* This theory contains a formalization of vectors and matrices with          *)
(* arbitrary and dynamic size.                                                *)

(* Many operations conventionally have requirements on the sizes of their     *)
(* operands. Matrix addition, for example, only makes sense if the matrices   *)
(* have the same size. The user is responsible for making sure that the       *)
(* parameters are valid. Many operators return the empty matrix/vector on     *)
(* malformed input, but not all, so checking the output is not a reliable way *)
(* to check that the parameters are valid.                                    *)

require import AllCore List Distr DBool DList.
import StdBigop.Bigreal StdOrder.IntOrder.

(* -------------------------------------------------------------------------- *)
require (*--*) Quotient StdOrder Bigalg.

(* -------------------------------------------------------------------------- *)
clone import Ring.ComRing as ZR.

type R = t.

(* -------------------------------------------------------------------------- *)
clone import Bigalg.BigComRing as Big with theory CR <- ZR proof *.
import BAdd.

(* -------------------------------------------------------------------------- *)
(* TOTHINK: This lemma is very useful in mulmxA, other uses? *)
(* Lowering the upper bound of a big sum: if every term selected by P whose   *)
(* index lies between j' (included) and j (excluded) is zeror, then summing   *)
(* from i up to j is the same as summing from i up to j'.                     *)
lemma big_range0r (j' j i : int) P F : j' <= j =>
  (forall x, j' <= x < j => P x => F x = zeror) =>
  bigi P F i j = bigi P F i j'.
proof.
move => le_j'_j F0E; case (i <= j') => [le_i_j'|gt_i_j'].
(* Case i <= j': split the sum at j'; the upper part is zero by big1_seq. *)
- rewrite (big_cat_int j') 1,2:/# [bigi P F j' j]big1_seq ?addr0 //. 
  smt(mem_range).
(* Remaining case: both sums are rewritten to zeror with big1_seq. *)
rewrite !big1_seq; smt(mem_range).
qed.

theory Vectors.

(* Raw representation of a vector: an indexing function paired with a size. *)
type prevector = (int -> R) * int.

(* Normal form of a prevector: entries at indices outside [0, size) are       *)
(* replaced by zeror, and the size is replaced by max 0 size.                 *)
op vclamp (pv: prevector): prevector =
  ((fun i => if 0 <= i < pv.`2 then pv.`1 i else zeror), max 0 pv.`2).

(* Clamping is idempotent. *)
lemma nosmt vclamp_idemp pv: vclamp (vclamp pv) = vclamp pv.
proof. rewrite /vclamp /#. qed.

(* Two prevectors are equivalent when their clamped forms coincide. *)
op eqv (pv1 pv2: prevector) =
  vclamp pv1 = vclamp pv2.

(* Every prevector is equivalent to its own clamped form. *)
lemma nosmt eqv_vclamp pv: eqv pv (vclamp pv). 
proof. by rewrite /eqv vclamp_idemp. qed.

(* Vectors are obtained as the quotient of prevectors by eqv. The quotient    *)
(* type qT of the cloned theory is renamed to vector, and all proof           *)
(* obligations of the clone are discharged by smt.                            *)
clone import Quotient.EquivQuotient as QuotientVec with
  type T <- prevector,
  op eqv <- eqv
  rename [type] "qT" as "vector"
  proof * by smt().

(* Short name for the quotient type. *)
type vector = QuotientVec.vector.

(* Clamped prevector obtained from the representative of a vector. *)
op tofunv v = vclamp (repr v).

(* Builds a vector from a prevector: the class of its clamped form. *)
op offunv pv = pi (vclamp pv).

(* offunv undoes tofunv. *)
lemma tofunvK: cancel tofunv offunv.
proof. 
rewrite /tofunv /offunv /cancel => v; rewrite vclamp_idemp.
by rewrite -{2}[v]reprK -eqv_pi /eqv vclamp_idemp. 
qed.

(* tofunv after offunv returns the clamped prevector (not pv itself). *)
lemma offunvK pv: tofunv (offunv pv) = vclamp pv.
proof. by rewrite /tofunv /offunv eqv_repr vclamp_idemp. qed.

(* Elimination principle: to prove P for all vectors it suffices to prove it  *)
(* for every vector of the form offunv pv.                                    *)
lemma vectorW (P : vector -> bool): 
  (forall pv, P (offunv pv)) => forall v, P v by smt(tofunvK).

(* Dimension of the vector: the size component of its clamped prevector. *)
op size v = (tofunv v).`2.

(* Sizes are never negative. *)
lemma nosmt size_ge0 v: 0 <= size v by smt().

(* Clamping a vector size at 0 changes nothing. *)
lemma nosmt max0size v: max 0 (size v) = size v by smt().

(* max0size is registered as a simplification hint. *)
hint simplify max0size.

(* Clamping the requested size before calling offunv yields the same vector. *)
lemma offunv_max f n: offunv (f, max 0 n) = offunv (f, n).
proof. rewrite /offunv -eqv_pi /eqv /vclamp /#. qed.

(* A vector built with requested size n has size max 0 n. *)
lemma size_offunv f n: size (offunv (f, n)) = max 0 n.
proof. by rewrite /size offunvK /#. qed.

(* size_offunv is registered as a simplification hint. *)
hint simplify size_offunv.

(* Getting the i-th element of a vector: applies the function component of   *)
(* the clamped prevector, so the operator is total on all integers.          *)
op get (v: vector) (i: int) = (tofunv v).`1 i.

(* Indexing notation v.[i] for get v i. *)
abbrev "_.[_]" (v: vector) (i : int) = get v i.

(* In-range indexing of offunv (f, n) returns f i. *)
lemma get_offunv f n (i : int) : 0 <= i < n =>
  (offunv (f, n)).[i] = f i.
proof. by rewrite /get /= offunvK /vclamp /= => ->. qed.

(* Out-of-range indexing of any vector returns zeror. *)
lemma getv0E v i: !(0 <= i < size v) => v.[i] = zeror by smt().

(* Out-of-range indexing of offunv (f, n) returns zeror. *)
lemma offunv0E f n (i : int) : !(0 <= i < n) =>
  (offunv (f, n)).[i] = zeror.
proof. move => ioutn. rewrite getv0E /= /#. qed.

(* Extensionality: two vectors are equal iff they have the same size and     *)
(* agree on every in-range index.                                            *)
lemma eq_vectorP (v1 v2 : vector) : (v1 = v2) <=> 
  (size v1 = size v2 /\ forall i, 0 <= i < size v1 => v1.[i] = v2.[i]).
proof. 
split => [->//|[eq_size eq_vi]].
(* Equal clamped prevectors give equal vectors through tofunvK. *)
have: tofunv v1 = tofunv v2 by rewrite /tofunv /vclamp /#.
smt(tofunvK). 
qed.

(* Constant valued vector of dimension n (dimension max 0 n after clamping) *)
op vectc n c = offunv ((fun _ => c), n).

(* The size of vectc n c is max 0 n. *)
lemma size_vectc n c: size (vectc n c) = max 0 n by done.

(* size_vectc is registered as a simplification hint. *)
hint simplify size_vectc.

(* Every in-range entry of vectc n c is c. *)
lemma get_vectc n c i: 0 <= i < n => (vectc n c).[i] = c by smt(get_offunv).

(* Zero-vector of dimension n: the constant vector with value zeror. *)
op zerov n = vectc n zeror.

(* The size of zerov n is max 0 n. *)
lemma size_zerov n: size (zerov n) = max 0 n.
proof. by rewrite size_offunv. qed.

(* Every entry of zerov n is zeror, whether or not the index is in range. *)
lemma get_zerov n i: (zerov n).[i] = zeror.
proof.
(* In range: read the constant function. Out of range: use getv0E. *)
case (0 <= i < n) => i_bound; first by rewrite get_offunv. 
by rewrite getv0E ?size_zerov /#.
qed.

(* Both lemmas are registered as simplification hints. *)
hint simplify size_zerov, get_zerov.

(* The unique 0-size vector *)
op emptyv = zerov 0.

(* emptyv has size 0. *)
lemma size_emptyv: size emptyv = 0 by done.

(* Every index of emptyv reads zeror. *)
lemma get_emptyv i: emptyv.[i] = zeror by done.

(* Both lemmas are registered as simplification hints. *)
hint simplify size_emptyv, get_emptyv.

(* Uniqueness: any vector of size 0 is emptyv. *)
lemma emptyv_unique v: size v = 0 => v = emptyv.
proof. move => size_eq. apply eq_vectorP => /#. qed.

(* Using the size of vectc n c1 as the dimension of a constant vector is the *)
(* same as using n directly.                                                 *)
lemma size_vectc2 n c1 c2: vectc (size (vectc n c1)) c2 = vectc n c2.
proof. case (0 < n) => [/# | bd]. rewrite 2!(emptyv_unique (vectc _ _)) /#. qed.

(* Special case of size_vectc2 for zero vectors. *)
lemma size_zerov2 n: zerov (size (zerov n)) = zerov n.
proof. rewrite /zerov. apply size_vectc2. qed.


(* Lifting functions to vectors: mapv f v applies f to every entry of v and  *)
(* keeps the size of v.                                                      *)
op mapv (f : R -> R) (v : vector) : vector = 
  offunv (fun i => f v.[i], size v).

(* mapv preserves the size. *)
lemma size_mapv f v : size (mapv f v) = size v by [].

(* In-range entries of mapv f v are f applied to the entries of v. *)
lemma get_mapv f (v : vector) i : 
  (0 <= i < size v) => (mapv f v).[i] = f v.[i].
proof. by move => ?; rewrite get_offunv. qed.

(* Vector addition: entry-wise sum. The result has the larger of the two     *)
(* sizes; no check is made that the operands have the same size.             *)
op[opaque] (+) (v1 v2 : vector) = 
  offunv ((fun i => v1.[i] + v2.[i]), max (size v1) (size v2)).

(* The size of a sum is the maximum of the operand sizes. *)
lemma size_addv v1 v2: size (v1 + v2) = max (size v1) (size v2).
proof. rewrite /(+) size_offunv /#. qed.

(* Indexing a sum is the sum of the indexed operands, for every index. *)
lemma get_addv (v1 v2 : vector) i: (v1 + v2).[i] = v1.[i] + v2.[i].
proof.
case (0 <= i < max (size v1) (size v2)) => bound. 
(* In range of the result: read the defining function. *)
- by rewrite /(+) /= get_offunv.
(* Out of range: all three reads are zeror by getv0E. *)
- rewrite !getv0E ?size_addv ?size_eq 4:addr0 /#.
qed.

(* Additive inverse of vector: entry-wise opposite, same size. *)
op[opaque] [-] (v: vector) = offunv ((fun i => -v.[i]), size v).

(* Negation preserves the size. *)
lemma nosmt size_oppv v: size (-v) = size v by rewrite /([-]).

(* Indexing a negated vector negates the entry, for every index. *)
lemma getvN (v : vector) i : (-v).[i] = - v.[i].
proof.
case: (0 <= i < size v) => bound; 1: by rewrite /([-]) get_offunv. 
(* Out of range: both reads are zeror and -zeror = zeror. *)
by rewrite !getv0E /([-]) //= oppr0.
qed.

(* Both lemmas are registered as simplification hints. *)
hint simplify getvN, size_oppv.

(* Simplifications for some special vectors *)
(* The opposite of a zero vector is the same zero vector. *)
lemma oppv_zerov n: -(zerov n) = zerov n.
proof. apply eq_vectorP => /= i bound. by rewrite oppr0. qed.

(* The opposite of the empty vector is the empty vector. *)
lemma oppv_emptyv: -emptyv = emptyv by apply oppv_zerov.

(* Both lemmas are registered as simplification hints. *)
hint simplify oppv_zerov, oppv_emptyv.

(* Module-like properties of vectors *)
(* Vector addition is associative (no size hypothesis needed). *)
lemma addvA: associative Vectors.(+).
proof.
move => v1 v2 v3; apply eq_vectorP.
rewrite !size_addv.
split => [/# | i _].
by rewrite !get_addv addrA.
qed.

(* Vector addition is commutative (no size hypothesis needed). *)
lemma addvC: commutative Vectors.(+).
proof.
move => v1 v2; apply eq_vectorP.
rewrite !size_addv.
split => [/# | i _].
by rewrite !get_addv addrC.
qed.

(* The zero vector of the same size is a left identity. *)
lemma add0v v: zerov (size v) + v = v.
proof.
rewrite eq_vectorP size_addv /= => i bound. 
by rewrite get_addv //= add0r.
qed.

(* add0v is registered as a simplification hint. *)
hint simplify add0v.

(* Variant of add0v where the size is given through an equation. *)
lemma lin_add0v v n: n = size v => zerov n + v = v by done.

(* The zero vector of the same size is a right identity. *)
lemma addv0 v: v + zerov (size v) = v by rewrite addvC.

(* Variant of addv0 where the size is given through an equation. *)
lemma lin_addv0 v n: n = size v => v + zerov n = v by move => ->; rewrite addv0.

(* Subtraction is notation for adding the opposite. *)
abbrev (-) v1 v2 = v1 + (-v2).

(* Subtracting a vector from itself gives the zero vector of its size. *)
lemma addvN v: v - v = zerov (size v).
proof. 
rewrite eq_vectorP size_addv /= => i bound. 
by rewrite get_addv //= addrN. 
qed.

(* addv0 is registered as a simplification hint. *)
hint simplify addv0.

(* Negation distributes over addition. *)
lemma oppvD (v1 v2: vector): -(v1 + v2) = -v1 + -v2.
proof.
apply eq_vectorP.
split => [| i _].
- by rewrite /= 2!size_addv /=.
by rewrite /= 2!get_addv opprD.
qed.

(* Negation is an involution. *)
lemma oppvK (v: vector): - (- v) = v.
proof.
rewrite eq_vectorP. 
split => [|i bound]; 1: by rewrite 2!size_oppv.
by rewrite 2!getvN opprK.
qed.

(* Moving a term across an equation; requires x and y to have the size of z. *)
lemma sub_eqv (x y z : vector) : 
  size x = size z => size y = size z => 
  x - z = y <=> x = y + z.
proof. smt(addvA addvC addvN addv0). qed.

(* Inner product: sum of the entry-wise products over indices from 0 up to   *)
(* the larger of the two sizes. No check is made that the sizes agree.       *)
op[opaque] dotp (v1 v2 : vector) =
  bigi predT (fun i => v1.[i] * v2.[i]) 0 (max (size v1) (size v2)).

(* Unfolding lemma for the opaque operator dotp. *)
lemma dotpE v1 v2: dotp v1 v2 = 
  bigi predT (fun i => v1.[i] * v2.[i]) 0 (max (size v1) (size v2)).
proof. by rewrite /dotp. qed.

(* The inner product is commutative. *)
lemma dotpC : commutative dotp.
proof. 
move => v1 v2; rewrite 2!dotpE maxrC. 
apply eq_bigr => i _ /=.
by rewrite mulrC.
qed.

(* Negating the left operand negates the inner product. *)
lemma dotpNv v1 v2: dotp (-v1) v2 = - dotp v1 v2.
proof.
rewrite 2!dotpE /= sumrN.
apply eq_bigr => i _ /=.
by rewrite mulNr.
qed.

(* Negating the right operand negates the inner product. *)
lemma dotpvN v1 v2: dotp v1 (-v2) = - dotp v1 v2 by rewrite dotpC dotpNv dotpC.

(* Both lemmas are registered as simplification hints. *)
hint simplify dotpNv, dotpvN.

(* The inner product distributes over addition in its right operand. *)
lemma dotpDr v1 v2 v3: dotp v1 (v2 + v3) = dotp v1 v2 + dotp v1 v3.
proof.
rewrite !dotpE size_addv; pose m := max _ (max _ _).
(* FIXME: (max _ _) should be precise enough, but this matches m! *)
(* Extend both right-hand sums to the common upper bound m; the added terms  *)
(* are zero because one factor is an out-of-range read.                      *)
rewrite -[bigi _ _ 0 (max _ (size v2))](big_range0r _ m); 1,2: smt(getv0E mul0r).
rewrite -[bigi _ _ 0 (max _ (size v3))](big_range0r _ m); 1,2: smt(getv0E mul0r).
by rewrite sumrD /= &(eq_bigr) /= => i _; rewrite get_addv mulrDr.
qed.

(* The inner product distributes over addition in its left operand. *)
lemma dotpDl v1 v2 v3: dotp (v1 + v2) v3 = dotp v1 v3 + dotp v2 v3.
proof. by rewrite dotpC (dotpC v1) (dotpC v2) dotpDr. qed.

(* The inner product with a zero vector on the left is zeror. *)
lemma dotp0v v n: dotp (zerov n) v = zeror.
proof.
rewrite dotpE big1 // => i _ /=.
by rewrite mul0r.
qed.

(* dotp0v is registered as a simplification hint. *)
hint simplify dotp0v.

(* The inner product with a zero vector on the right is zeror. *)
lemma dotpv0 v n: dotp v (zerov n) = zeror by rewrite dotpC.

(* dotpv0 is registered as a simplification hint. *)
hint simplify dotpv0.

(* Vector concatenation. The entry at i is written as the sum of v1.[i] and  *)
(* v2.[i - size v1]; for an in-range index one of the two reads is out of    *)
(* range and hence zeror (see the proof of get_catv).                        *)
op catv (v1 v2: vector) = 
  offunv ((fun i => v1.[i] + v2.[i-size v1]), size v1 + size v2).

(* Infix notation for concatenation. *)
abbrev ( || ) v1 v2 = catv v1 v2.

(* The size of a concatenation is the sum of the sizes. *)
lemma size_catv (v1 v2: vector): size (v1 || v2) = (size v1 + size v2).
proof. rewrite /catv /= /#. qed.

(* Indexing a concatenation, in conditional form; holds for every index. *)
lemma get_catv (v1 v2: vector) i : 
  (v1 || v2).[i] = if i < size v1 then v1.[i] else v2.[i - size v1].
proof.
case (0 <= i < size v1 + size v2) => range; last first.
(* Out of range of the concatenation: every read involved is zeror. *)
- rewrite !getv0E ?size_catv //; smt(size_ge0).
rewrite get_offunv //=; case (i < size v1) => ?.
(* Index in v1: the v2 read is out of range, so the sum collapses. *)
- by rewrite [v2.[_]]getv0E //= 1:/# addr0.
(* Index in v2: the v1 read is out of range, so the sum collapses. *)
- by rewrite [v1.[_]]getv0E //= 1:/# add0r.
qed.

(* Indexing a concatenation, in sum form; holds for every index. *)
lemma get_catv' (v1 v2: vector) i: (v1 || v2).[i] = v1.[i] + v2.[i-size v1].
proof.
rewrite /catv.
case (0 <= i < size v1 + size v2) => range.
- rewrite get_offunv //=.
- rewrite !getv0E 4:addr0 /#.
qed.

(* Indices below size v1 read from the left operand. *)
lemma get_catv_l (v1 v2: vector) i : 
  i <  size v1 => (v1 || v2).[i] = v1.[i].
proof. smt(get_catv). qed. 

(* Indices from size v1 on read from the right operand, shifted by size v1. *)
lemma get_catv_r (v1 v2: vector) i : 
  size v1 <= i => (v1 || v2).[i] = v2.[i - size v1].
proof. smt(get_catv). qed.

(* Concatenation is associative. *)
lemma catvA : associative Vectors.(||).
proof.
move => v1 v2 v3. rewrite eq_vectorP !size_catv addrA /= => i bound. 
by rewrite !get_catv' size_catv opprD !addrA.
qed.

(* The inner product of two concatenations splits into the inner products of *)
(* the parts, provided the left parts have the same size.                    *)
lemma dotp_catv v1 v2 v3 v4: size v1 = size v3 =>
  dotp (v1 || v2) (v3 || v4) = (dotp v1 v3) + (dotp v2 v4).
proof.
move => size_eq; rewrite !dotpE !size_catv size_eq /=.
(* Split the index range at size v3. *)
rewrite (range_cat (size v3)) 1:/# 1:/# big_cat; congr.
(* Lower part: both concatenations read from their left operand. *)
- by apply eq_big_seq => i /mem_range ? /=; smt(get_catv_l).
(* Upper part: re-index the sum so that it starts at 0. *)
have ->:max (size v3 + size v2) (size v3 + size v4) =
  size v3 + max (size v2) (size v4) by smt().
have {1}->: size v3 = 0 + size v3 by smt().
rewrite big_addn.
rewrite addrC addrA [_ + size v3]Ring.IntID.addrC subrr.
apply eq_big_seq => i /mem_range [l_bound u_bound] /=.
by rewrite !get_catv size_eq gtr_addr ler_gtF //= -addrA subrr.
qed.

(* Addition of concatenations is the concatenation of the sums, provided the *)
(* left parts have the same size.                                            *)
lemma addv_catv (v1 v2 v3 v4: vector): 
  size v1 = size v3 =>
  (v1 || v2) + (v3 || v4) = ((v1 + v3) || (v2 + v4)).
proof.
move => size_eq; apply eq_vectorP.
rewrite size_addv 3!size_catv 2!size_addv. 
split => [/# | i bound].
rewrite get_addv 3!get_catv' 2!get_addv size_addv size_eq /max /=.
(* Reorder the four summands with associativity and commutativity. *)
rewrite 2!addrA. 
congr.
rewrite -2!addrA. 
congr.
by rewrite addrC.
qed.

(* Negation distributes over concatenation. *)
lemma oppv_catv (v1 v2: vector): - (v1 || v2) = (-v1 || -v2).
proof. by apply eq_vectorP => i bound /=; rewrite 2!get_catv' opprD. qed.

(* Splitting apart vectors: subv v n m is the vector of size max 0 (m - n)   *)
(* whose entry at i is v.[i+n].                                              *)
op subv (v: vector) (n m: int) = offunv ((fun i => v.[i+n]), m - n).

(* The size of subv v n m is max 0 (m - n). *)
lemma size_subv (v: vector) (n m: int): size (subv v n m) = max 0 (m - n).
proof. by done. qed.

(* size_subv is registered as a simplification hint. *)
hint simplify size_subv.

(* In-range entries of subv v n m are the entries of v shifted by n. *)
lemma get_subv (v: vector) (n m i: int): 0 <= i < m - n => 
  (subv v n m).[i] = v.[i+n].
proof. move => i_range; by rewrite /subv get_offunv. qed.

(* Interactions between subv and concatenation *)
(* Taking the first size v1 entries of a concatenation recovers v1. *)
lemma subv_catvCl v1 v2: subv (v1 || v2) 0 (size v1) = v1.
proof.
rewrite eq_vectorP size_subv /= => i bound.
by rewrite get_subv // get_catv' /= (getv0E v2) 1:/# addr0.
qed.

(* Taking the entries from size v1 on of a concatenation recovers v2. *)
lemma subv_catvCr v1 v2: subv (v1 || v2) (size v1) (size v1 + size v2) = v2.
proof.
rewrite eq_vectorP size_subv. 
split => [/# | i bound].
rewrite get_subv 1:/# get_catv' (getv0E v1) 1:/# add0r /#.
qed.

(* Splitting a vector at an index within its bounds and concatenating the    *)
(* two parts gives the vector back.                                          *)
lemma catv_subvC v i: 0 <= i <= size v => 
  (subv v 0 i || subv v i (size v)) = v.
proof.
move => i_bound; rewrite eq_vectorP size_catv !size_subv /= -andaE. 
split => [/#|-> k k_bound]; case (k < i) => lt_i_k.
- rewrite get_catv_l; smt(get_subv).
- rewrite get_catv_r; smt(get_subv).
qed.
 
(* subv distributes over addition of vectors of the same size. *)
lemma subv_addv v1 v2 a b: size v1 = size v2 => 
    subv (v1 + v2) a b  = subv v1 a b + subv v2 a b.
proof.
move => size_eq; rewrite eq_vectorP ?size_addv /= => /= i bound.
rewrite get_subv 1:/# !get_addv // !get_subv /#.
qed.

(* Updating one entry in a vector: the entry at index n becomes m, all other *)
(* entries and the size are those of v.                                      *)
op updv (v: vector) (n: int) (m: t) =
  offunv (fun i => if i = n then m else v.[i], size v).

(* Updating an entry preserves the size. *)
lemma size_updv (v: vector) (n: int) (m: t): size (updv v n m) = size v by done.

(* In-range indexing of an updated vector. *)
lemma get_updv (v: vector) (n i: int) (m: t): (0 <= i < size v) =>
    (updv v n m).[i] = if i = n then m else v.[i].
proof. rewrite /updv => i_bound. by rewrite get_offunv. qed.

(* Multiplying vector by a scalar: every entry is multiplied on the left by  *)
(* s; the size is unchanged.                                                 *)
op scalarv (s: t) (v: vector) = offunv(fun i => s * v.[i], size v).

(* Infix notation for scalar multiplication. *)
abbrev ( ** ) s v = scalarv s v.

(* Scalar multiplication preserves the size. *)
lemma size_scalarv (v: vector) (s: t): size (s ** v) = size v by done.

(* Indexing a scaled vector scales the entry, for every index. *)
lemma get_scalarv (v: vector) (s: t) (i: int): (s ** v).[i] = s * v.[i].
proof.
case (0 <= i < size v) => bound; 1: by rewrite get_offunv.
(* Out of range: both reads are zeror and s * zeror = zeror. *)
by rewrite offunv0E // getv0E // mulr0.
qed.

(* Scaling by zeror gives the zero vector of the same size. *)
lemma scalar0v (v: vector): zeror ** v = zerov (size v).
proof.
rewrite eq_vectorP. 
split => [| i bound]; 1: rewrite size_scalarv /#.
by rewrite get_zerov get_scalarv mul0r.
qed.

(* Scaling by oner is the identity. *)
lemma scalar1v (v: vector): oner ** v = v.
proof.
rewrite eq_vectorP.
split => [| i bound]; 1: by rewrite size_scalarv.    
by rewrite get_scalarv mul1r.
qed.

(* Scaling commutes with negation of the vector. *)
lemma scalarvN a (v : vector): a ** -v = - (a ** v).
proof. 
rewrite eq_vectorP size_oppv 2!size_scalarv size_oppv /= => i bound.
by rewrite 2!get_scalarv getvN mulrN.
qed.

(* Scaling by a negated scalar negates the scaled vector. *)
lemma scalarNv (a : R) (v : vector): (-a) ** v = - (a ** v).
proof.
rewrite eq_vectorP size_oppv 2!size_scalarv /= => i bound.
by rewrite !get_scalarv mulNr.
qed.

(* Scalars can be pulled out of both operands of an inner product. *)
lemma dotp_scalarv (v1 v2: vector) (s1 s2: t):
  dotp (s1 ** v1) (s2 ** v2) = (s1 * s2) * (dotp v1 v2).
proof.
rewrite 2!dotpE 2!size_scalarv !mulr_sumr &(eq_bigr) /= => i _.
rewrite 2!get_scalarv; smt(mulrA mulrC).
qed.

(* A scalar can be pulled out of the left operand of an inner product. *)
lemma dotp_scalarv_l (v1 v2: vector) (s : t):
  dotp (s ** v1) v2 = s * dotp v1 v2.
proof. by rewrite -[v2]scalar1v dotp_scalarv scalar1v mulr1. qed.

(* A scalar can be pulled out of the right operand of an inner product. *)
lemma dotp_scalarv_r (v1 v2: vector) (s: t):
  dotp v1 (s ** v2) = s * dotp v1 v2.
proof. by rewrite -[v1]scalar1v dotp_scalarv scalar1v mul1r. qed.

(* Scalar multiplication distributes over vector addition. *)
lemma scalarvDr s (v1 v2 : vector) :
  s ** (v1 + v2) = s ** v1 + s ** v2.
proof.
apply eq_vectorP; rewrite 5!(size_addv,size_scalarv) /= => i Hi.
by rewrite !(get_scalarv,get_addv) mulrDr.
qed.

(* List-vector isomorphism *)
(* Vector whose entries are the elements of the list s, in order. *)
op oflist (s : R list): vector = offunv (nth witness s, size s).

(* The vector built from a list has the length of the list as its size. *)
lemma size_oflist l: size (oflist l) = size l.
proof. rewrite size_offunv; 1: smt(List.size_ge0). qed.

(* In-range entries of oflist l are the list elements, for any default w. *)
lemma get_oflist w (l: R list) i: 0 <= i < size l => (oflist l).[i] = nth w l i.
proof. by move => bound; rewrite get_offunv // (nth_change_dfl w). qed.

(* List of the entries of v at indices 0 up to size v (excluded). *)
op tolist (v : vector): R list = map (fun i => v.[i]) (range 0 (size v)).

(* The list of a vector has the size of the vector as its length. *)
lemma size_tolist v: size (tolist v) = size v.
proof. rewrite size_map size_range /#. qed.

(* In-range elements of tolist v are the entries of v, for any default w. *)
lemma nth_tolist w v i: 0 <= i < size v => nth w (tolist v) i = v.[i].
proof.
move => bound; rewrite (nth_map witness w) 1:size_range; 1: smt(List.size_ge0).
by rewrite /= nth_range.
qed.

(* Membership in tolist v means being an in-range entry of v. *)
lemma mem_tolist x v : x \in tolist v <=> exists i, 0 <= i < size v /\ x = v.[i].
proof. by rewrite mapP; smt(mem_range). qed.

(* tolist undoes oflist. *)
lemma oflistK: cancel oflist tolist.
proof.
move => l; apply (eq_from_nth witness); rewrite size_tolist size_oflist //.
by move => i bound; rewrite nth_tolist 1:size_oflist // (get_oflist witness).
qed.

(* oflist undoes tolist. *)
lemma tolistK: cancel tolist oflist.
proof.
move => l; apply eq_vectorP.
rewrite size_oflist size_tolist /= => i bound.
rewrite (get_oflist witness) 1:size_tolist // nth_tolist //.
qed.

(* Both conversions are injective, as each has a left inverse. *)
lemma oflist_inj: injective oflist by smt(oflistK).

lemma tolist_inj: injective tolist by smt(tolistK).

(* tolist turns vector concatenation into list concatenation. *)
lemma tolist_catv (v1 v2 : vector) : 
  tolist (v1 || v2) = (tolist v1) ++ (tolist v2).
proof.
(* Split the index range at size v1 and treat each half separately. *)
rewrite /tolist size_catv (range_cat (size v1)); 1,2: smt(size_ge0).
rewrite map_cat; congr.
- apply/eq_in_map => i /mem_range /=; smt(get_catv_l).
rewrite addzC range_addr /= -map_comp /(\o).
apply/eq_in_map => i /mem_range /=; smt(get_catv_r).
qed.

(* The list of a constant vector is the constant list nseq n c. *)
lemma tolist_vectc n c : tolist (vectc n c) = nseq n c. 
proof. 
apply oflist_inj; rewrite tolistK; apply eq_vectorP.
rewrite size_vectc size_oflist size_nseq /= => i i_bound. 
by rewrite get_vectc 1:/# (get_oflist c) ?size_nseq // ?nth_nseq /#.
qed.

(* Distribution of length n vectors sampled using d element-wise: the image  *)
(* under oflist of the list distribution dlist d (max 0 n).                  *)
op dvector (d: R distr) (n: int) = dmap (dlist d (max 0 n)) oflist.

(* Support of dvector d k for non-negative k: vectors of size k whose        *)
(* in-range entries are all in the support of d.                             *)
lemma supp_dvector d v k : 
  0 <= k =>
  (v \in dvector d k) <=> 
  size v = k /\ forall i, 0 <= i < k => v.[i] \in d.
proof.
move=> k_ge0; rewrite supp_dmap; split => [[s []]|[s_v] supp_v].
(* Left to right: v is oflist s for a list s in the support of dlist. *)
- rewrite ler_maxr // supp_dlist 1:/# => -[s_s /allP supp_s ->]. 
  rewrite size_oflist s_s /= -s_s => i bound. 
  rewrite (get_oflist witness) // supp_s mem_nth //.
(* Right to left: tolist v is a preimage of v in the support of dlist. *)
exists (tolist v); rewrite ler_maxr // tolistK /=.
apply/supp_dlist => //; rewrite size_tolist s_v /=; apply/allP => x x_v.
rewrite -(nth_index x _ _ x_v) nth_tolist.
- by rewrite index_ge0 -size_tolist index_mem x_v.
by rewrite supp_v index_ge0 -s_v -size_tolist index_mem.
qed.

(* Every vector in the support of dvector d n has size max 0 n. *)
lemma size_dvector d n v: v \in dvector d n => size v = max 0 n. 
proof.
move => /supp_dmap[a [a_in ->]]; rewrite size_oflist.
rewrite supp_dlist /# in a_in.
qed.

(* Vectors of the wrong size have probability 0. *)
lemma get_dvector0E (d: R distr) (v: vector) n: size v <> max 0 n => 
  mu1 (dvector d n) v = 0%r.
proof.
move => size_ineq; apply supportPn. 
rewrite -implybF => /size_dvector v_in.
by apply size_ineq.
qed.

(* The probability of v under dvector d (size v) is the product of the       *)
(* probabilities of its entries under d.                                     *)
lemma dvector1E (d : R distr) (v : vector) : mu1 (dvector d (size v)) v =
  BRM.bigi predT (fun i => mu1 d v.[i]) 0 (size v).
proof.
rewrite -{2}[v]tolistK dmapE /(\o) /pred1. 
(* By injectivity of oflist, the event reduces to sampling tolist v. *)
rewrite (@mu_eq _ _ (pred1 (tolist v))); 1: smt(oflist_inj).
rewrite dlist1E 1:/# size_tolist max0size /=.
by rewrite BRM.big_mapT /(\o) &BRM.eq_big.
qed.

(* The probability of v is the product of the probabilities of its prefix of *)
(* size i and of the remaining suffix.                                       *)
lemma mu1_dvector_split d i (v: vector): 0 <= i <= size v => 
  mu1 (dvector d i) (subv v 0 i) * 
  mu1 (dvector d (size v - i)) (subv v i (size v)) =
  mu1 (dvector d (size v)) v. 
proof.
(* Express both dimensions as sizes of the sub-vectors to apply dvector1E. *)
move => i_bound; have last_size: size v - i = size (subv v i (size v)).
- rewrite size_subv /#.
have first_size: i = size (subv v 0 i) by rewrite size_subv /#.
rewrite last_size {1}first_size !dvector1E.
(* Split the product for v at index i and match the two factors. *)
rewrite (BRM.big_cat_int i 0 (size v)) 3:2!size_subv; first 2 smt().
have ->: max 0 (size v - i) = size v - i by smt().
have ->: max 0 i = i by smt().
congr.
- apply BRM.eq_big_int => k cont /=. 
  by rewrite get_subv.
- have ->: i = 0 + i by algebra.
  rewrite (BRM.big_addn 0) /=. 
  apply BRM.eq_big_int => k cont /=.
  by rewrite get_subv. 
qed.

(* dvector preserves uniformity of d. *)
lemma dvector_uni d n: is_uniform d => is_uniform (dvector d n).
proof.
move => uni_d; apply dmap_uni; 1: apply/(can_inj _ tolist)/oflistK.
by apply dlist_uni.
qed.

(* dvector preserves losslessness of d. *)
lemma dvector_ll d n: is_lossless d => is_lossless (dvector d n).
proof. by move => ll_d; apply/dmap_ll/dlist_ll. qed.

(* If d is full, every vector is in the support of dvector at its own size. *)
lemma dvector_fu d (v: vector): is_full d => v \in dvector d (size v).
proof.
move=> full_d; rewrite supp_dmap.
exists (tolist v).
rewrite tolistK /= -size_tolist dlist_fu => x /#.
qed.

(* If d is funiform, the probability of v is the probability of a single     *)
(* element under d raised to the power size v.                               *)
lemma mu1_dvector_fu (d: R distr) (v: vector): is_funiform d =>
    mu1 (dvector d (size v)) v = (mu1 d witness)^(size v).
proof.
rewrite dvector1E => d_funi.
(* All factors of the product are equal to mu1 d witness. *)
have ->: (fun (i : int) => mu1 d v.[i]) = fun (_: int) => mu1 d witness.
- rewrite fun_ext => i. 
  apply d_funi.
(* Induction on the (non-negative) size. *)
have: 0 <= size v by exact size_ge0. 
move: (size v); elim/ge0ind => [/# | _ | n bound IH _].
- by rewrite range_geq // BRM.big_nil RField.expr0.
- by rewrite BRM.big_int_recr // RField.exprS // RField.mulrC IH.
qed.

end Vectors.

export Vectors.

(* -------------------------------------------------------------------------- *)
theory Matrices.

(* Raw representation of a matrix: an indexing function on (row, column)     *)
(* together with a number of rows and a number of columns.                   *)
type prematrix = (int -> int -> R) * int * int.

(* Normal form of a prematrix: entries outside the index rectangle are       *)
(* replaced by zeror, and both dimensions are replaced by their max with 0.  *)
op mclamp (pm: prematrix): prematrix =
  ((fun i j => if 0 <= i < pm.`2 /\ 0 <= j < pm.`3 then pm.`1 i j else zeror),
   max 0 pm.`2, max 0 pm.`3).

(* Clamping is idempotent. *)
lemma nosmt mclamp_idemp pm: mclamp (mclamp pm) = mclamp pm.
proof. by rewrite /mclamp /#. qed.

(* mclamp_idemp is registered as a simplification hint. *)
hint simplify mclamp_idemp.

(* Two prematrices are equivalent when their clamped forms coincide. *)
op eqv (pm1 pm2: prematrix) = mclamp pm1 = mclamp pm2.

(* Matrices are obtained as the quotient of prematrices by eqv. The quotient *)
(* type qT of the cloned theory is renamed to matrix, and all proof          *)
(* obligations of the clone are discharged by smt.                           *)
clone import Quotient.EquivQuotient as QuotientMat with
  type T <- prematrix,
  op eqv <- eqv
  rename [type] "qT" as "matrix"
  proof * by smt().

(* Short name for the quotient type. *)
type matrix = QuotientMat.matrix.

(* Clamped prematrix obtained from the representative of a matrix. *)
op tofunm m = mclamp (repr m).

(* Builds a matrix from a prematrix: the class of its clamped form. *)
op offunm pm = pi (mclamp pm).

(* offunm undoes tofunm. *)
lemma tofunmK : cancel tofunm offunm.
proof. 
rewrite /tofunm /offunm /cancel => m /=. 
(* A representative and its clamped form are in the same class. *)
have ->: pi (mclamp (repr m)) = pi (repr m) by rewrite -eqv_pi /eqv.
apply reprK.
qed.

(* tofunm after offunm returns the clamped prematrix (not pm itself). *)
lemma offunmK pm: tofunm (offunm pm) = mclamp pm.
proof. by rewrite /tofunm /offunm eqv_repr. qed.

(* offunmK is registered as a simplification hint. *)
hint simplify offunmK.

(* Elimination principle: to prove P for all matrices it suffices to prove   *)
(* it for every matrix of the form offunm pm.                                *)
lemma matrixW (P : matrix -> bool) : (forall pm, P (offunm pm)) =>
  forall m, P m by smt(tofunmK).

(* Number of rows and columns of matrices *)
(* Second component of the clamped prematrix. *)
op rows m = (tofunm m).`2.

(* Third component of the clamped prematrix. *)
op cols m = (tofunm m).`3.

(* The size of a matrix is the pair (rows, columns). *)
abbrev size m = (rows m, cols m).

(* Both dimensions are never negative. *)
lemma nosmt rows_ge0 m: 0 <= rows m by smt().

lemma nosmt cols_ge0 m: 0 <= cols m by smt().

(* Dimensions of a matrix built by offunm: the requested ones clamped at 0. *)
lemma rows_offunm f r c: rows (offunm (f, r, c)) = max 0 r by done. 

lemma cols_offunm f r c: cols (offunm (f, r, c)) = max 0 c by done.

lemma size_offunm f r c: size (offunm (f, r, c)) = (max 0 r, max 0 c) by done.

(* rows_offunm and cols_offunm are registered as simplification hints. *)
hint simplify rows_offunm, cols_offunm.

(* Clamping a matrix dimension at 0 changes nothing. *)
lemma nosmt max0rows m: max 0 (rows m) = rows m by smt().

lemma nosmt max0cols m: max 0 (cols m) = cols m by smt(). 

(* Both lemmas are registered as simplification hints. *)
hint simplify max0rows, max0cols.

(* Getting the element at position i, j: applies the function component of   *)
(* the clamped prematrix to the two components of the index pair.            *)
op get (m : matrix) (ij : int * int) = (tofunm m).`1 ij.`1 ij.`2.

(* Indexing notation m.[i, j] for get m (i, j). *)
abbrev "_.[_]" m ij = get m ij.

(* mrange m i j states that (i, j) is a valid position of m. *)
abbrev mrange m (i j : int) = 0 <= i < rows m /\ 0 <= j < cols m.

(* In-range indexing of offunm (f, r, c) returns f i j. *)
lemma get_offunm f r c (i j : int) : mrange (offunm (f, r, c)) i j => 
  (offunm (f, r, c)).[i, j] = f i j.
proof. rewrite /get /= /mclamp /= /#. qed.

(* Out-of-range indexing of any matrix returns zeror. *)
lemma nosmt getm0E (m : matrix) (i j : int) : !mrange m i j => m.[i, j] = zeror.
proof. by smt(). qed.

(* Out-of-range indexing of offunm (f, r, c) returns zeror. *)
lemma offunm0E f r c (i j: int) : !(0 <= i < r /\ 0 <= j < c) =>
    (offunm (f, r, c)).[i, j] = zeror.
proof. move => idx_out. rewrite getm0E /#. qed.
    
(* Extensionality: two matrices are equal iff they have the same size and    *)
(* agree on every in-range position.                                         *)
lemma eq_matrixP (m1 m2 : matrix) : (m1 = m2) <=> 
  size m1 = size m2 /\ (forall i j, mrange m1 i j => m1.[i, j] = m2.[i, j]).
proof. 
split=> [-> // | @/get /= eq_mi]. 
(* Equal clamped prematrices give equal matrices through tofunmK. *)
have: tofunm m1 = tofunm m2 by rewrite /tofunm /mclamp /#. 
smt(tofunmK).
qed.

(* Special matrices *)

(* Constant valued matrix with r rows and c columns. Note the argument       *)
(* order: rows, columns, then the constant value.                            *)
op matrixc (rows cols: int) (c : R) = offunm ((fun _ _ => c), rows, cols).

(* Dimensions of a constant matrix: the requested ones clamped at 0. *)
lemma nosmt rows_matrixc cst r c: rows (matrixc r c cst) = max 0 r by done.

lemma nosmt cols_matrixc cst r c: cols (matrixc r c cst) = max 0 c by done.

(* Both lemmas are registered as simplification hints. *)
hint simplify rows_matrixc, cols_matrixc.

lemma nosmt size_matrixc cst r c: size (matrixc r c cst) = (max 0 r, max 0 c).
proof. by done. qed.

(* Every in-range entry of a constant matrix is the constant. *)
lemma get_matrixc cst c r i j: mrange (matrixc r c cst) i j => 
  (matrixc r c cst).[i, j] = cst. 
proof. move => bound. by apply get_offunm. qed.

(* Matrix with the values of v on the diagonal and zeror off the diagonal;   *)
(* it is square with both dimensions equal to size v.                        *)
op diagmx (v : vector) =
  offunm ((fun i j => if i = j then v.[i] else zeror), size v, size v).

(* Both dimensions of diagmx v are size v. *)
lemma rows_diagmx v: rows (diagmx v) = size v by rewrite /diagmx /#.

lemma cols_diagmx v: cols (diagmx v) = size v by rewrite /diagmx /#.

(* Both lemmas are registered as simplification hints. *)
hint simplify rows_diagmx, cols_diagmx.

lemma size_diagmx v: size (diagmx v) = (size v, size v) by done.

(* Indexing a diagonal matrix; holds for every position, in range or not. *)
lemma get_diagmx v i j: (diagmx v).[i, j] = if i = j then v.[i] else zeror.
proof. 
case (mrange (diagmx v) i j) => /= [[i_bound j_bound] | not_bound].
- by rewrite get_offunm.
(* Out of range: the matrix read is zeror, and so is v.[i] when i = j. *)
rewrite getm0E //.
case (i = j) => [idx_eq | //].
rewrite getv0E /#.
qed.

(* get_diagmx is registered as a simplification hint. *)
hint simplify get_diagmx.

(* Matrix with constant values on the diagonal: the diagonal matrix of the   *)
(* constant vector vectc n c.                                                *)
abbrev diagc n (c : R) = diagmx (vectc n c).

(* Indexing a constant-diagonal matrix; holds for every position. *)
lemma get_diagc n c i j: 
  (diagc n c).[i, j] = if i = j /\ 0 <= i < n then c else zeror.
proof. 
case (i = j) => /= -> //=. 
(* On the diagonal: the value is c in range and zeror out of range. *)
case (0 <= j < n) => j_bound; 1: by rewrite get_vectc.
apply getv0E => /= /#.
qed.

(* get_diagc is registered as a simplification hint. *)
hint simplify get_diagc.

(* n by n identity matrix: constant-diagonal matrix with value oner. *)
abbrev onem n = diagc n oner.

(* In-range entries of the identity: oner on the diagonal, zeror elsewhere. *)
lemma get_onem i j n: mrange (onem n) i j => (onem n).[i, j] =
  if i = j then oner else zeror.
proof. 
rewrite get_diagc => bound.
(* The range hypothesis gives 0 <= i < n, which removes the extra conjunct. *)
suff: (0 <= i < n) by move => ->.
rewrite rows_diagmx cols_diagmx /= in bound => /#.
qed.

(* Off-diagonal entries of the identity are zeror, in range or not. *)
lemma onem0E i j n: i <> j => (onem n).[i, j] = zeror.
proof. move=> ne_ij. by rewrite get_diagmx ne_ij. qed.

(* r by c zero matrix: the constant matrix with value zeror. *)
abbrev zerom r c  = matrixc r c zeror.

(* Every entry of a zero matrix is zeror, in range or not. *)
lemma get_zerom r c i j: (zerom r c).[i, j] = zeror. 
proof.
case (mrange (zerom r c) i j) => range; last by rewrite getm0E.
by rewrite get_matrixc.
qed.

(* get_zerom is registered as a simplification hint. *)
hint simplify get_zerom.

(* 0 by 0 matrix also used as error state when there is a size mismatch *)
op emptym = zerom 0 0.

(* emptym has no rows and no columns. *)
lemma rows_emptym: rows emptym = 0 by done.

lemma cols_emptym: cols emptym = 0 by done.

lemma size_emptym: size emptym = (0, 0) by done.

(* Every position of emptym reads zeror. *)
lemma get_emptym i j: emptym.[i,j] = zeror by done.

(* These lemmas are registered as simplification hints. *)
hint simplify rows_emptym, cols_emptym, get_emptym.

(* Uniqueness: any matrix of size (0, 0) is emptym. *)
lemma emptym_unique m: size m = (0, 0) => m = emptym.
proof. 
move => [rows_m cols_m]. 
apply eq_matrixP => /=. 
(* With zero rows and columns there is no in-range position to compare. *)
rewrite rows_m cols_m /= => i j /#.
qed.

(* Lifting functions to matrices: mapm f m applies f to every entry of m and *)
(* keeps the numbers of rows and columns of m.                               *)
op mapm (f : R -> R) (m : matrix) : matrix = 
  offunm (fun i j => f m.[i, j], rows m, cols m).

(* mapm preserves the size (numbers of rows and columns).                    *)
(* NOTE: this lemma and get_mapm used to be stated about mapv on vectors,    *)
(* which merely duplicated Vectors.size_mapv and Vectors.get_mapv and left   *)
(* mapm without any size or indexing lemma. They now talk about mapm.        *)
lemma size_mapm f m : size (mapm f m) = size m by rewrite /mapm.

(* In-range entries of mapm f m are f applied to the entries of m. *)
lemma get_mapm f (m : matrix) i j : 
  mrange m i j => (mapm f m).[i, j] = f m.[i, j].
proof. by move => bound; rewrite /mapm get_offunm. qed.

(* Matrix addition: entry-wise sum. The result takes the larger number of    *)
(* rows and the larger number of columns of the operands; no check is made   *)
(* that the operands have the same size.                                     *)
op (+) (m1 m2 : matrix) =
  offunm (fun i j => m1.[i, j] + m2.[i, j], 
         max (rows m1) (rows m2), 
         max (cols m1) (cols m2)).

(* Dimensions of a sum: the maximum of the operand dimensions. *)
lemma rows_addm (m1 m2: matrix): rows (m1 + m2) = max (rows m1) (rows m2).
proof. rewrite /(+) rows_offunm /#. qed.

lemma cols_addm (m1 m2: matrix): cols (m1 + m2) = max (cols m1) (cols m2).
proof. rewrite /(+) cols_offunm /#. qed.

(* When the operands have the same size, so does their sum. *)
lemma size_addm (m1 m2: matrix): size m1 = size m2 => size (m1 + m2) = size m1.
proof. move => size_eq; rewrite rows_addm cols_addm /#. qed.

(* Indexing a sum is the sum of the indexed operands, for every position. *)
lemma get_addm (m1 m2 : matrix) i j: (m1 + m2).[i, j] = m1.[i, j] + m2.[i, j].
proof.
case: (mrange (m1 + m2) i j) => rg_i. 
(* In range of the result: read the defining function. *)
- rewrite /(+) /= get_offunm 2:/#; 1:smt(rows_addm cols_addm).
(* Out of range: all three reads are zeror by getm0E. *)
- rewrite !getm0E; first 3 smt(rows_addm cols_addm). 
  by rewrite addr0.
qed.

(* Matrix additive inverse: entry-wise opposite, same dimensions. *)
op[opaque] [-] (m: matrix) = offunm (fun i j => -m.[i, j], rows m, cols m).

(* Negation preserves both dimensions. *)
lemma rows_neg (m: matrix): rows (-m) = rows m by rewrite /([-]).

lemma cols_neg (m: matrix): cols (-m) = cols m by rewrite /([-]).

(* Both lemmas are registered as simplification hints. *)
hint simplify rows_neg, cols_neg.

lemma size_neg (m: matrix): size (-m) = size m by rewrite /([-]).

(* Indexing a negated matrix negates the entry, for every position. *)
lemma getmN (m : matrix) i j : (-m).[i, j] = - m.[i, j].
proof.
case: (mrange m i j) => rg_i; 1: by rewrite /([-]) get_offunm.
(* Out of range: both reads are zeror and -zeror = zeror. *)
by rewrite getm0E // getm0E // oppr0.
qed.

(* getmN is registered as a simplification hint. *)
hint simplify getmN.

(* The opposite of the empty matrix is the empty matrix. *)
lemma emptymN: -emptym = emptym by apply emptym_unique.

(* emptymN is registered as a simplification hint. *)
hint simplify emptymN.

(* Module like properties of matrices *)
(* Matrix addition is associative (no size hypothesis needed). *)
lemma addmA: associative Matrices.( + ).
proof. 
move => m1 m2 m3; apply eq_matrixP.
split => [| i j bound].
- rewrite !(rows_addm, cols_addm) 1:/#.
by rewrite 4!get_addm addrA.
qed.

(* Matrix addition is commutative (no size hypothesis needed). *)
lemma addmC: commutative Matrices.( + ).
proof.
move => m1 m2.
apply eq_matrixP. 
rewrite !rows_addm !cols_addm.
split => [/# | i j _].
by rewrite 2!get_addm addrC.
qed.

(* The zero matrix of the same size is a left identity. *)
lemma add0m m: zerom (rows m) (cols m) + m = m.
proof.
apply eq_matrixP.
rewrite size_addm /= 1://.
move => i j bound; rewrite get_addm //=.
apply add0r.
qed.

(* Variant of add0m where the dimensions are given through equations. *)
lemma lin_add0m m r c: 
  r = rows m => c = cols m => zerom r c + m = m. 
proof. move => -> ->; apply add0m. qed.

(* The zero matrix of the same size is a right identity. *)
lemma addm0 m: m + zerom (rows m) (cols m) = m by rewrite addmC add0m.

(* Variant of addm0 where the dimensions are given through equations. *)
lemma lin_addm0 m r c: 
  r = rows m => c = cols m => m + zerom r c = m. 
proof. move => -> ->; apply addm0. qed.

(* Subtraction is notation for adding the opposite. *)
abbrev (-) (m1 m2: matrix) = m1 + (-m2).

(* Subtracting a matrix from itself gives the zero matrix of its size. *)
lemma addmN (m: matrix) : m - m = zerom (rows m) (cols m).
proof.
apply eq_matrixP. 
split => [| i j bound]; 1: by rewrite size_addm.
by rewrite get_addm //= addrN.
qed.

(* Same statement with the opposite on the left. *)
lemma addNm (m: matrix): (-m) + m = zerom (rows m) (cols m).
proof. by rewrite addmC addmN. qed.

(* Negation distributes over addition. *)
lemma oppmD (m1 m2: matrix): -(m1 + m2) = -m1 + -m2.
proof.
rewrite eq_matrixP; split => [| i j bound].
- by rewrite /= 2!rows_addm 2!cols_addm.
- by rewrite /= 2!get_addm opprD.
qed.

(* Negation is an involution. *)
lemma oppmK (m: matrix): - (- m) = m.
proof.
rewrite eq_matrixP. 
split => [| i j bound]; 1: by rewrite !size_neg.
by rewrite !getmN opprK.
qed.

(* Moving a term across an equation; requires x and y to have the size of z. *)
lemma sub_eqm (x y z : matrix): 
   size x = size z => size y = size z => 
   x - z = y <=> x = y + z.
proof. smt(addmA addmC addNm add0m). qed.

(* Matrix transposition: entry (i, j) of the result is entry (j, i) of m,    *)
(* and the two dimensions are swapped.                                       *)
op trmx (m : matrix) = offunm (fun i j => m.[j, i], cols m, rows m).

(* Transposition swaps the numbers of rows and columns. *)
lemma rows_tr m: rows (trmx m) = cols m by done.

lemma cols_tr m: cols (trmx m) = rows m by done.

(* Both lemmas are registered as simplification hints. *)
hint simplify rows_tr, cols_tr.

lemma size_tr m: size (trmx m) = (cols m, rows m) by done.

(* Indexing a transpose swaps the indices; holds for every position. *)
lemma trmxE (m : matrix) i j : (trmx m).[i, j] = m.[j, i].
proof. 
case: (mrange m j i) => bound. 
- rewrite get_offunm /#.
- rewrite getm0E /#.
qed.

(* trmxE is registered as a simplification hint. *)
hint simplify trmxE.

(* Transposition is an involution. *)
lemma trmxK: involutive trmx.
proof.
move => m. 
by apply eq_matrixP.
qed.

(* Transposition is injective, as a consequence of being involutive. *)
lemma trmx_inj: injective trmx by apply inv_inj; apply trmxK.

(* Transposition distributes over addition. *)
lemma trmxD (m1 m2 : matrix) : trmx (m1 + m2) = trmx m1 + trmx m2.
proof.
apply eq_matrixP. 
split => [/= | i j bound]; 1: smt(rows_addm cols_addm cols_tr rows_tr).
rewrite trmxE 2!get_addm /#. 
qed.

(* trmxD is registered as a simplification hint. *)
hint simplify trmxD.

(* Transposition commutes with negation. *)
lemma trmxN (m: matrix): trmx (-m) = - trmx m by apply eq_matrixP.

(* trmxN is registered as a simplification hint. *)
hint simplify trmxN.

(* The transpose of the empty matrix is the empty matrix. *)
lemma trmx_empty: trmx emptym = emptym by apply emptym_unique.

(* The identity matrix is its own transpose. *)
lemma trmx1 n: trmx (onem n) = (onem n).
proof.
apply eq_matrixP => i j bound /= /#. 
qed.

(* Both lemmas are registered as simplification hints. *)
hint simplify trmx1, trmx_empty.

(* Transposing a constant matrix swaps its two dimension arguments. *)
lemma trmx_matrixc c n m: trmx (matrixc n m c) = matrixc m n c.
proof.
rewrite eq_matrixP /= => i j bound. 
rewrite !get_matrixc /#.
qed.

(* trmx_matrixc is registered as a simplification hint. *)
hint simplify trmx_matrixc.

(* Gets the n-th row of m as a vector of size cols m. *)
op row m n = offunv (fun i => m.[n, i], cols m).

(* A row has as many entries as the matrix has columns. *)
lemma size_row m i: size (row m i) = cols m by done.

(* Entry j of row i is entry (i, j) of the matrix, for all i and j. *)
lemma get_row m i j: (row m i).[j] = m.[i, j].
proof.
case (0 <= j < cols m) => bound; 1: by apply get_offunv.
(* Column index out of range: both reads are zeror. *)
rewrite getv0E // getm0E /#.
qed.

(* Both lemmas are registered as simplification hints. *)
hint simplify size_row, get_row.

(* A row taken at an out-of-range index is the zero vector of size cols m. *)
lemma row0E m i: !(0 <= i < rows m) => row m i = zerov (cols m). 
proof. move => n_bound; rewrite eq_vectorP => j bd /=. rewrite getm0E /#. qed.

(* Taking a row commutes with negation. *)
lemma rowN m i: - (row m i) = row (-m) i by rewrite eq_vectorP.

(* Taking a row distributes over matrix addition. *)
lemma rowD m1 m2 n: (row (m1 + m2) n) = row m1 n + row m2 n.
proof. 
apply eq_vectorP.
rewrite size_row cols_addm size_addv 2!size_row /= => i _.
by rewrite get_addv get_addm /=.
qed.

(* An in-range row of a constant matrix is a constant vector. *)
lemma row_matrixc m n c i: 0 <= i < m => row (matrixc m n c) i = vectc n c.
proof.
move => i_bound; apply eq_vectorP => /= j j_bound.
rewrite get_vectc 1:/# get_matrixc /#.
qed.

(* Gets the n-th column of m as a vector *)
op col m n = offunv (fun i => m.[i, n], rows m).

lemma size_col m n: size (col m n) = rows m by done.

lemma get_col m i j: (col m j).[i] = m.[i, j].
proof.
case (0 <= i < rows m) => bound; 1: by apply get_offunv.
rewrite getv0E // getm0E /#.
qed.

hint simplify size_col, get_col.

lemma row_trmx m n: row (trmx m) n = col m n by rewrite eq_vectorP.

lemma col_trmx m n: col (trmx m) n = row m n by rewrite eq_vectorP.

hint simplify row_trmx, col_trmx.

lemma col0E m n: !(0 <= n < cols m) => col m n = zerov (rows m).
proof.
rewrite -trmxK cols_tr col_trmx => bound. 
rewrite rows_tr. 
by apply row0E.
qed.

(* Taking a column distributes over matrix addition; discharged by SMT from
   the row version (rowD) and the transpose lemmas. *)
lemma colD m1 m2 i: (col (m1 + m2) i) = col m1 i + col m2 i.
proof. smt(trmxK col_trmx trmxD rowD). qed.

(* Negation commutes with taking a column. *)
lemma colN (m : matrix) i : - (col m i) = col (-m) i by rewrite eq_vectorP.

(* Every in-range column of the constant matrix matrixc m n c is the constant
   vector vectc m c. Proved by transposing and reusing row_matrixc. *)
lemma col_matrixc m n c i: 0 <= i < n => col (matrixc m n c) i = vectc m c.
proof.
move => bound; rewrite -(trmxK (matrixc _ _ _)) col_trmx /=. 
by apply row_matrixc.
qed.

(* Matrix multiplication: entry (i, j) is the dot product of row i of m1 with
   column j of m2, and the requested size is rows m1 by cols m2. The
   definition itself does not check that cols m1 and rows m2 agree. The
   operator is declared opaque, so the lemmas below unfold it explicitly. *)
op[opaque] ( * ) (m1 m2 : matrix) =
  offunm (fun i j => dotp (row m1 i) (col m2 j), rows m1, cols m2).

(* The product has the row count of its left factor. *)
lemma rows_mulmx m1 m2: rows (m1 * m2) = rows m1 by rewrite /( * ).

(* The product has the column count of its right factor. *)
lemma cols_mulmx m1 m2: cols (m1 * m2) = cols m2 by rewrite /( * ).

hint simplify rows_mulmx, cols_mulmx.

(* Both dimensions of the product at once. *)
lemma size_mulmx m1 m2: size (m1 * m2) = (rows m1, cols m2) by done.

(* Entry formula for the product, valid for all indices i and j (no range
   hypothesis). *)
lemma get_mulmx m1 m2 i j: (m1 * m2).[i,j] = dotp (row m1 i) (col m2 j).
proof.
case (mrange (m1 * m2) i j) => /= [bound | /negb_and bound].
(* In range: unfold the opaque product and read the entry off offunm. *)
- by rewrite /( * ) /= get_offunm.
(* Out of range: the left-hand side is zero by getm0E; it remains to show
   that the dot product is zero as well. *)
rewrite getm0E 1:/#.
elim bound => boundN.
(* Row index out of range: the row is a zero vector (row0E). *)
- suff: row m1 i = zerov (cols m1) by move => ->. 
  rewrite row0E /#.
(* Column index out of range: the column is a zero vector (col0E). *)
- suff: col m2 j = zerov (rows m2) by move => ->. 
  rewrite col0E /#.
qed.

(* The product of the empty matrix with itself is the empty matrix. *)
lemma mulmx_emptym: emptym * emptym = emptym by apply/emptym_unique/size_mulmx.

hint simplify mulmx_emptym.

(* Transposing a product reverses the order of the factors. The entrywise
   step relies on commutativity of the dot product (dotpC). *)
lemma trmxM (m1 m2 : matrix): trmx (m1 * m2) = trmx m2 * trmx m1.
proof.
apply eq_matrixP.
rewrite size_mulmx /= => i j bound.
by rewrite !get_mulmx dotpC.
qed.

hint simplify trmxM.

(* Multiplication distributes over addition in the left factor. The statement
   carries no size hypotheses. *)
lemma mulmxDl (m1 m2 m : matrix): (m1 + m2) * m = (m1 * m) + (m2 * m).
proof.
apply eq_matrixP. 
(* First goal: the sizes agree; second goal: the entries agree. *)
split => [| i j bound]; 1: smt(cols_addm rows_addm rows_mulmx cols_mulmx).
by rewrite get_addm 3!get_mulmx rowD dotpDl.
qed.

(* Multiplication distributes over addition in the right factor; obtained
   from mulmxDl by transposition (trmx_inj). *)
lemma mulmxDr (m1 m2 m : matrix): m * (m1 + m2) = (m * m1) + (m * m2).
proof. apply trmx_inj => /=. apply mulmxDl => /#. qed.

(* Associativity of matrix multiplication, stated without size hypotheses.
   The proof expands both sides into double sums. It uses big_range0r
   repeatedly to change the upper bound of a sum when the extra terms are
   zero, then swaps the two sums with exchange_big. *)
lemma mulmxA m1 m2 m3: m1 * (m2 * m3) = m1 * m2 * m3.
proof.
apply eq_matrixP. 
split => [| i j bound]; 1: smt(rows_mulmx cols_mulmx).
(* Left-hand side: restrict the outer sum to the range 0 .. rows m2. *)
rewrite !get_mulmx dotpE /= (big_range0r (rows m2)) 1:/#. 
- smt(getm0E mulr0 mul0r rows_mulmx cols_mulmx row0E).
(* Expand each entry of m2 * m3 into an inner sum over 0 .. rows m3. *)
rewrite (eq_bigr _ _ (fun k =>
    bigi predT (fun (l : int) => m1.[i,k] * (m2.[k,l] * m3.[l,j])) 0 (rows m3))).
- move => k /= _.
  rewrite get_mulmx dotpE mulr_sumr /= (big_range0r (rows m3)) // 1:/#.
  smt(getm0E mulr0 mul0r).
(* Swap the order of summation, then treat the right-hand side the same way. *)
rewrite exchange_big /= dotpE /=.
rewrite [bigi _ _ _ (max _ _)](big_range0r (rows m3)) // 1:/#. 
- smt(getm0E mulr0 mul0r).
apply eq_bigr => k _ /=; rewrite get_mulmx dotpE mulr_suml /=.
rewrite [bigi _ _ _ (max _ _)](big_range0r (rows m2)) 1:/#.
- smt(getm0E mulr0 mul0r).
(* The summands now differ only by associativity of ring multiplication. *)
by apply eq_bigr => l _ /=; rewrite mulrA.
qed.

(* Multiplying m on the right by a zero matrix with cols m rows gives a zero
   matrix with rows m rows and the same number n of columns. *)
lemma mulmxm0 n m: m * zerom (cols m) n = zerom (rows m) n.
proof.
rewrite eq_matrixP /= => i j /= bound.
by rewrite get_mulmx /= col_matrixc 1:/# dotpv0.
qed.

(* Left-handed version of mulmxm0, obtained by transposition. *)
lemma mul0m n m: zerom n (rows m) * m = zerom n (cols m).
proof. apply trmx_inj => /=. apply mulmxm0. qed.

(* onem (cols m) is a right identity for m. *)
lemma mulmx1 m: m * onem (cols m) = m.
proof.
apply eq_matrixP. 
rewrite size_mulmx 1:// /= => i j bound.
(* Split the term with index j off the dot-product sum (bigD1). *)
rewrite get_mulmx dotpE (bigD1 _ _ j) 1:mem_range 1:/# 1:range_uniq.
(* The remaining terms are shown to be zero (big1); the last branch
   simplifies the isolated term with mulr1 / addr0. *)
rewrite big1 /= => [k |]; last by rewrite get_vectc 1:/# mulr1 addr0.
rewrite /predC1 => -> /=. 
by rewrite mulr0.
qed.

(* onem (rows m) is a left identity for m; obtained from mulmx1 by
   transposition. *)
lemma mul1mx m: onem (rows m) * m = m by apply trmx_inj => /=; apply mulmx1.

(* Turns a vector into a single-row matrix: entry (_, i) is v.[i], with
   requested size 1 by size v. *)
op rowmx (v: vector) = offunm ((fun _ i => v.[i]), 1, size v).

(* rowmx v has exactly one row. *)
lemma rows_rowmx v: rows (rowmx v) = 1 by done.

(* rowmx v has as many columns as v has components. *)
lemma cols_rowmx v: cols (rowmx v) = size v by done.

hint simplify rows_rowmx, cols_rowmx.

(* Both dimensions of rowmx v at once. *)
lemma size_rowmx v: size (rowmx v) = (1, size v) by done.

(* Row 0 of rowmx v reads back the components of v, for every index i
   (in range or not). *)
lemma get_rowmx v i: (rowmx v).[0,i] = v.[i].
proof.
case (0 <= i < size v) => bound.
- rewrite get_offunm //= /#. 
- by rewrite getm0E //= getv0E.
qed.

hint simplify get_rowmx.

(* row is a left inverse of rowmx at row index 0. *)
lemma rowK v: row (rowmx v) 0 = v by rewrite eq_vectorP.

hint simplify rowK.

(* For a matrix with exactly one row, rebuilding it from its row 0 gives the
   matrix back. *)
lemma rowmx_row m: rows m = 1 => rowmx (row m 0) = m.
proof.
move => rws; rewrite eq_matrixP /= rws => i j bound.
case (i = 0) => [->//| i_neq0].
rewrite !getm0E /#.
qed.

(* rowmx is additive. Both sides are single-row matrices, so the proof
   rewrites each as rowmx of its row 0 (rowmx_row) and concludes with rowD. *)
lemma rowmxD (v1 v2: vector): rowmx (v1 + v2) = rowmx v1 + rowmx v2.
proof.
rewrite -(rowmx_row (rowmx (v1 + v2))) 1://. 
by rewrite -(rowmx_row (rowmx v1 + rowmx v2)) 1:rows_addm //= rowD.
qed.

(* rowmx commutes with negation. *)
lemma rowmxN (v : vector) : - (rowmx v) = rowmx (-v).
proof. rewrite eq_matrixP /= => i j; case (i = 0) => [->// | /#]. qed.

(* The row matrix of a constant vector is the constant matrix with one row. *)
lemma rowmxc n c: rowmx (vectc n c) = matrixc 1 n c.
proof.
apply eq_matrixP => /= i j bound. 
case (i = 0) => [->/=|i_neq0 /#].
rewrite get_matrixc /= 1:/# get_vectc /#.
qed.

hint simplify rowmxc.

(* Turns a vector into a single-column matrix: entry (i, _) is v.[i], with
   requested size size v by 1. *)
op colmx (v: vector) = offunm ((fun i _ => v.[i]), size v, 1).

(* colmx v has as many rows as v has components. *)
lemma rows_colmx v: rows (colmx v) = size v by done.

(* colmx v has exactly one column. *)
lemma cols_colmx v: cols (colmx v) = 1 by done.

hint simplify rows_colmx, cols_colmx.

(* Both dimensions of colmx v at once. *)
lemma size_colmx v: size (colmx v) = (size v, 1) by done.

(* Column 0 of colmx v reads back the components of v, for every index i
   (in range or not). *)
lemma get_colmx v i: (colmx v).[i,0] = v.[i].
proof.
case (0 <= i < size v) => bound; first rewrite get_offunm //= /#.
by rewrite getm0E //= getv0E.
qed.

hint simplify get_colmx.

(* col is a left inverse of colmx at column index 0. *)
lemma colK v: col (colmx v) 0 = v by rewrite eq_vectorP.

hint simplify colK.

(* For a matrix with exactly one column, rebuilding it from its column 0
   gives the matrix back. *)
lemma colmx_col m: cols m = 1 => colmx (col m 0) = m.
proof.
move => cls; rewrite eq_matrixP /= cls => i j bound.
case (j = 0) => [->// | /#].
qed.

(* The transpose of a row matrix is the corresponding column matrix. *)
lemma trmx_rowmx v: trmx (rowmx v) = colmx v.
proof.
rewrite eq_matrixP /= => i j bound.
case (j = 0) => [->// | /#].
qed.

hint simplify trmx_rowmx.

(* The transpose of a column matrix is the corresponding row matrix. *)
lemma trmx_colmx v: trmx (colmx v) = rowmx v.
proof. by apply trmx_inj => /=; rewrite trmxK. qed.

hint simplify trmx_colmx.

(* For a matrix with exactly one column, laying its column 0 out as a row
   matrix gives the transpose of the matrix. *)
lemma rowmx_col m: cols m = 1 => rowmx (col m 0) = trmx m.
proof.
(* cls: the hypothesis cols m = 1, consumed by colmx_col below. *)
move => cls.
apply trmx_inj => /=.
by rewrite trmxK colmx_col.
qed.

(* For a matrix with exactly one row, laying its row 0 out as a column matrix
   gives the transpose of the matrix. *)
lemma colmx_row m: rows m = 1 => colmx (row m 0) = trmx m.
proof. move => rws; apply trmx_inj => /=. by rewrite trmxK rowmx_row. qed.

(* colmx is additive; obtained from rowmxD by transposition. *)
lemma colmxD (v1 v2: vector):
    colmx (v1 + v2) = colmx v1 + colmx v2.
proof. apply trmx_inj => /=. exact rowmxD. qed.

(* colmx commutes with negation. *)
lemma colmxN (v : vector) : - (colmx v) = colmx (-v).
proof. rewrite eq_matrixP /= => i j; case (j = 0) => [->// | /#]. qed.

(* The column matrix of a constant vector is the constant matrix with one
   column. *)
lemma colmxc n c: colmx (vectc n c) = matrixc n 1 c by apply trmx_inj.

hint simplify colmxc.

(* Matrix and vector multiplication: v is turned into a column matrix,
   multiplied on the left by m, and column 0 of the product is returned.
   Declared opaque; mulmxvE below exposes the definition. *)
op[opaque] mulmxv m v = col (m * colmx v) 0.

(* Infix notation for mulmxv. *)
abbrev ( *^ ) (m : matrix) (v : vector) : vector = mulmxv m v.

(* Unfolding equation for the opaque operator. *)
lemma mulmxvE m v: m *^ v = col (m * colmx v) 0 by rewrite /mulmxv.

(* The result has as many components as m has rows. *)
lemma size_mulmxv m (v: vector): size (m *^ v) = rows m by rewrite /mulmxv.

(* Component i of the result is the dot product of row i of m with v. *)
lemma get_mulmxv m v i: (m *^ v).[i] = dotp (row m i) v.
proof. by rewrite mulmxvE /= get_mulmx. qed.

(* The column matrix of a matrix-vector product is the matrix product with
   the column matrix of the vector. *)
lemma colmx_mulmxv (m : matrix) (v : vector) : 
  colmx (m *^ v) = m * colmx v.
proof. by rewrite mulmxvE colmx_col. qed.

(* Matrix-vector multiplication distributes over matrix addition. *)
lemma mulmxvDl (m1 m2 : matrix) (v : vector) :
  (m1 + m2) *^ v = m1 *^ v + m2 *^ v.
proof. by rewrite !mulmxvE -colD mulmxDl. qed.

(* Matrix-vector multiplication distributes over vector addition. *)
lemma mulmxvDr (m : matrix) (v1 v2 : vector) :
  m *^ (v1 + v2) = m *^ v1 + m *^ v2.
proof. by rewrite !mulmxvE -colD colmxD 1:// mulmxDr. qed.

(* Compatibility of matrix-vector and matrix-matrix multiplication. *)
lemma mulmxvA (m1 m2 : matrix) (v : vector) :
  m1 *^ (m2 *^ v) = (m1 * m2) *^ v.
proof. by rewrite !mulmxvE colmx_col // mulmxA. qed.

(* Multiplying m by the zero vector of size cols m gives the zero vector of
   size rows m. *)
lemma mulmxv0 (m : matrix) : m *^ (zerov (cols m)) = zerov (rows m).
proof. by rewrite mulmxvE colmxc /= mulmxm0 col_matrixc. qed.

(* onem (size v) acts as the identity on v. *)
lemma mulmx1v (v : vector): onem (size v) *^ v = v.
proof.
(* Rewrite the right-hand v as col (colmx v) 0, then compare the matrices. *)
rewrite -{3}colK mulmxvE.
congr.
have ->: size v = rows (colmx v) by done.
by rewrite mul1mx.
qed.

(* Multiplying the column matrix of v by the one-component vector holding c
   is the same as scaling v by c. *)
lemma mul_colmxc (v:vector) c: (colmx v) *^ vectc 1 c = c ** v.
proof.
rewrite eq_vectorP size_scalarv size_mulmxv /= /max 1:// => i bound.
rewrite get_scalarv get_mulmxv dotpE /= /max /=.
(* The dot-product sum has a single term (big_int1). *)
by rewrite Big.BAdd.big_int1 /= get_vectc // mulrC.
qed.

(* Scalar multiplication of the vector can be pulled out of a matrix-vector
   product. *)
lemma mulmx_scalarv (m : matrix) (s : t) (v : vector) :
  m *^ (s ** v) = s ** (m *^ v).
proof.
apply eq_vectorP; rewrite size_mulmxv size_scalarv size_mulmxv => /= i Hi.
by rewrite get_scalarv !get_mulmxv dotp_scalarv_r.
qed.

(* Vector and matrix multiplication: v is turned into a row matrix,
   multiplied on the right by m, and row 0 of the product is returned.
   Declared opaque; mulvmxE below exposes the definition. *)
op[opaque] mulvmx v m = row (rowmx v * m) 0.

(* Infix notation for mulvmx. *)
abbrev ( ^* ) (v : vector) (m : matrix) : vector = mulvmx v m.

(* Unfolding equation for the opaque operator. *)
lemma mulvmxE v m: v ^* m = row (rowmx v * m) 0 by rewrite /mulvmx.

(* The result has as many components as m has columns. *)
lemma size_mulvmx (v: vector) m: size (v ^* m) = cols m by rewrite /mulvmx.

(* Component i of the result is the dot product of v with column i of m. *)
lemma get_mulvmx v m i: (v ^* m).[i] = dotp v (col m i).
proof. by rewrite /mulvmx /= get_mulmx. qed.

(* Multiplying by the transpose on the left equals multiplying by the matrix
   on the right. *)
lemma mulmxTv (m : matrix) (v : vector) : (trmx m) *^ v = v ^* m.
proof. by rewrite mulvmxE -col_trmx trmxM mulmxvE. qed.

hint simplify mulmxTv.

(* Multiplying by the transpose on the right equals multiplying by the matrix
   on the left. *)
lemma mulvmxT (v : vector) (m : matrix) : v ^* (trmx m) = m *^ v.
proof. by rewrite mulvmxE-{2}(trmxK m) /#. qed.

hint simplify mulvmxT.

(* Vector-matrix multiplication distributes over matrix addition. *)
lemma mulvmxDr (v : vector) (m1 m2 : matrix) :
  v ^* (m1 + m2) = v ^* m1 + v ^* m2.
proof. rewrite -mulmxTv trmxD mulmxvDl /#. qed.

(* Vector-matrix multiplication distributes over vector addition. *)
lemma mulvmxDl (v1 v2 : vector) (m : matrix) :
  (v1 + v2) ^* m = v1 ^* m + v2 ^* m.
proof. by rewrite -mulmxTv mulmxvDr. qed.

(* Compatibility of vector-matrix and matrix-matrix multiplication. *)
lemma mulvmxA (v : vector) (m1 m2 : matrix) :
  v ^* (m1 * m2) = (v ^* m1) ^* m2.
proof. rewrite -(trmxK (m1 * m2)) trmxM mulvmxT -mulmxvA /#. qed.

(* Multiplying the zero vector of size rows m by m gives the zero vector of
   size cols m. *)
lemma mulv0mx (m : matrix): zerov (rows m) ^* m = zerov (cols m).
proof. by rewrite -mulmxTv -{2}trmxK rows_tr mulmxv0 -cols_tr trmxK. qed.

(* The row matrix of a vector-matrix product is the matrix product with the
   row matrix of the vector. *)
lemma rowmx_mulvmx (v : vector) (m : matrix) : 
  rowmx (v ^* m) = rowmx v * m.
proof. by rewrite -trmx_colmx -mulmxTv colmx_mulmxv 1:// trmxM trmxK. qed.

(* onem (size v) acts as the identity on v from the right. *)
lemma mulvmx1 (v : vector) : v ^* onem (size v) = v.
proof. by rewrite -mulmxTv trmx1 mulmx1v. qed.

(* The dot product is entry (0, 0) of the row matrix times the column
   matrix. *)
lemma dotp_eqv_mul v1 v2 : dotp v1 v2 = (rowmx v1 * colmx v2).[0,0].
proof. by rewrite get_mulmx. qed.

(* A matrix can be moved from one argument of the dot product to the other,
   switching between right and left multiplication. *)
lemma dotp_mulmxv m (v1 v2: vector): dotp v1 (m *^ v2) = dotp (v1 ^* m) v2.
proof. 
by rewrite 2!dotp_eqv_mul mulmxvE colmx_col 1:// mulvmxE rowmx_row 1:// mulmxA. 
qed.

(* Multiplying the one-component vector holding c by the row matrix of v is
   the same as scaling v by c. *)
lemma mul_rowmxc (v:vector) c: vectc 1 c ^* (rowmx v) = c ** v.
proof. by rewrite -(trmxK (rowmx v)) mulvmxT /= mul_colmxc. qed.

(* Sideways matrix concatenation - aka row block matrices.
   Entry (i, j) is m1.[i, j] + m2.[i, j - cols m1]. The requested size is
   max (rows m1) (rows m2) by cols m1 + cols m2, so the two operands are not
   required to have the same number of rows. *)
op catmr (m1 m2: matrix) = 
  offunm ((fun i j => m1.[i, j] + m2.[i, j-cols m1]), 
          max (rows m1) (rows m2), cols m1 + cols m2).

(* Infix notation for catmr. *)
abbrev ( || ) m1 m2 = catmr m1 m2.

(* The row count is the larger of the two row counts. *)
lemma rows_catmr (m1 m2: matrix): rows (m1 || m2) = max (rows m1) (rows m2).
proof. rewrite rows_offunm /#. qed.

(* The column count is the sum of the two column counts. *)
lemma cols_catmr (m1 m2: matrix): cols (m1 || m2) = cols m1 + cols m2.
proof. rewrite cols_offunm /#. qed.

(* Both dimensions at once. *)
lemma size_catmr (m1 m2: matrix): 
  size (m1 || m2) = (max (rows m1) (rows m2), cols m1 + cols m2).
proof. rewrite rows_offunm cols_offunm /#. qed.

(* Entry formula for sideways concatenation, valid for all indices i and j
   (no range hypothesis). *)
lemma get_catmr (m1 m2: matrix) i j: 
  (m1 || m2).[i, j] = m1.[i, j] + m2.[i, j-cols m1].
proof.
rewrite /catmr /=. 
case (mrange (m1 || m2) i j) => range.
- rewrite get_offunm //.
(* Out of range: all three entries are zero by getm0E. *)
- rewrite !getm0E /=; first 3 smt(size_catmr).
  by rewrite addr0.
qed.

(* When both operands have the same number of rows, a column whose index is
   below cols m1 comes from the left operand. *)
lemma col_catmrL m1 m2 i: rows m1 = rows m2 => i < cols m1 =>
  col (m1 || m2) i = col m1 i.
proof.
move => rows_eq bound1; rewrite eq_vectorP /=. 
split => [| j bound2]; 1: by rewrite rows_catmr /#.
rewrite get_catmr // (getm0E m2) 2:addr0; smt(rows_catmr).
qed.

(* When both operands have the same number of rows, a column whose index is
   at least cols m1 comes from the right operand, shifted by cols m1. *)
lemma col_catmrR m1 m2 i: rows m1 = rows m2 => cols m1 <= i =>
  col (m1 || m2) i = col m2 (i - cols m1).
proof.
move => rows_eq bound; rewrite eq_vectorP /=. 
split => [| j bound2]; 1: by rewrite rows_catmr /#.
rewrite get_catmr // (getm0E m1) 2:add0r; smt(rows_catmr).
qed.

(* A row of a sideways concatenation is the concatenation of the two rows.
   On the right-hand side the infix operator is applied to vectors. *)
lemma row_catmr m1 m2 i: row (m1 || m2) i = (row m1 i || row m2 i).
proof.
rewrite eq_vectorP /=. 
rewrite cols_catmr // size_catv /= => j bound.
by rewrite get_catmr // get_catv'.
qed.

(* Sideways concatenation is associative. *)
lemma catmrA (m1 m2 m3: matrix): ((m1 || m2) || m3) = (m1 || (m2 || m3)).
proof. 
rewrite eq_matrixP. 
split => [| i j bound]; 1: smt(size_catmr).
rewrite 4!get_catmr cols_catmr // addrA. 
algebra.
qed.

(* Left multiplication distributes over sideways concatenation. *)
lemma catmrDr (m1 m2 m3: matrix): m1 * (m2 || m3) = ((m1 * m2) || (m1 * m3)).
proof.
rewrite eq_matrixP.
rewrite rows_mulmx cols_mulmx cols_catmr.
split => [| i j bound]; 1: smt(size_mulmx size_catmr).
rewrite get_catmr 3!get_mulmx. 
(* Case split on which operand column j falls into. *)
case (j < cols m2) => bound2.
(* Column j lies in the m2 part: the m3 column is out of range, hence zero. *)
- rewrite [col m3 _]col0E 1:/# dotpv0 addr0 !dotpE 2!size_col rows_catmr.
  (* Split the sum at max (cols m1) (rows m2); the upper part is zero. *)
  rewrite (big_cat_int (max (cols m1) (rows m2))) 1:/# size_row 1:/#.
  rewrite [bigi _ _ (max _ _) _]big1_seq => [k [_ /mem_range bound3] /=|].
  + by rewrite get_catmr [_ m2 _]getm0E 1:/# add0r [_ m3 _]getm0E 1:/# mulr0.
  rewrite addr0.
  apply eq_bigr => k _ /=.
  by rewrite get_catmr [_ m3 _]getm0E 1:/# addr0.
(* Column j lies in the m3 part: symmetric argument with m2 and m3 swapped. *)
- rewrite [col m2 _]col0E 1:/# dotpv0 add0r !dotpE 2!size_col rows_catmr.
  rewrite (big_cat_int (max (cols m1) (rows m3))) 1:/# size_row 1:/#.
  rewrite [bigi _ _ (max _ _) _]big1_seq => [k [_ /mem_range bound3] /=|].
  + by rewrite get_catmr [_ m3 _]getm0E 1:/# addr0 [_ m2 _]getm0E 1:/# mulr0.
  rewrite addr0.
  apply eq_bigr => k _ /=.
  by rewrite get_catmr [_ m2 _]getm0E 1:/# add0r.
qed.

(* The row matrix of a vector concatenation is the sideways concatenation of
   the two row matrices. *)
lemma rowmx_catmr (v1 v2: vector): rowmx (v1 || v2) = (rowmx v1 || rowmx v2).
proof.
apply eq_matrixP.
rewrite size_catmr 1:rows_rowmx 1://.
rewrite !rows_rowmx !cols_rowmx size_catv /= => i j bound.
rewrite get_catmr //. 
(* Only one row exists, so the row index must be 0. *)
have ->: (i = 0) by smt().
by rewrite !get_rowmx get_catv' cols_rowmx. 
qed.

(* Addition of two sideways concatenations can be done blockwise, provided
   the two left blocks have the same number of columns. *)
lemma addm_catmr (m1 m2 m3 m4: matrix): cols m1 = cols m3 =>
  (m1 || m2) + (m3 || m4) = ((m1 + m3) || (m2 + m4)).
proof.
move => cols_eq.
rewrite eq_matrixP. rewrite size_catmr 3!rows_addm 3!cols_addm 2!rows_catmr.
rewrite 2!cols_catmr.
split => [/# | i j bound].
rewrite get_addm 3!get_catmr 2!get_addm cols_addm /max cols_eq /=.
smt(addrA addrC).
qed.

(* Negation distributes over sideways concatenation. *)
lemma oppm_catmr (m1 m2: matrix): - (m1 || m2) = (-m1 || -m2).
proof.
apply eq_matrixP => /=.
split => [// | i j bound].
by rewrite 2!get_catmr // /= opprD. 
qed.

(* Block matrix-vector product: if the left block has as many columns as v1
   has components, multiplying the concatenated matrix by the concatenated
   vector is the sum of the two blockwise products. *)
lemma mulmxv_cat (m1 m2 : matrix) (v1 v2 : vector): cols m1 = size v1 => 
  (m1 || m2) *^ (v1 || v2) = (m1 *^ v1) + (m2 *^ v2).
proof. 
move => cols_eq_size; apply eq_vectorP.
rewrite size_addv 3!size_mulmxv rows_catmr /= => i bound.
by rewrite get_addv 3!get_mulmxv row_catmr dotp_catv.
qed.

(* Downwards matrix concatenation - aka column block matrices.
   Entry (i, j) is m1.[i, j] + m2.[i - rows m1, j]. The requested size is
   rows m1 + rows m2 by max (cols m1) (cols m2), so the two operands are not
   required to have the same number of columns. *)
op catmc (m1 m2: matrix) =
  offunm ((fun i j => m1.[i, j] + m2.[i-rows m1, j]), 
          rows m1 + rows m2, max (cols m1) (cols m2)).

(* Infix notation for catmc. *)
abbrev ( / ) m1 m2 = catmc m1 m2.

(* The column count is the larger of the two column counts. *)
lemma cols_catmc (m1 m2: matrix): cols (m1 / m2) = max (cols m1) (cols m2).
proof. rewrite cols_offunm /#. qed.

(* The row count is the sum of the two row counts. *)
lemma rows_catmc (m1 m2: matrix): rows (m1 / m2) = rows m1 + rows m2.
proof. rewrite rows_offunm /#. qed.

(* Both dimensions at once. *)
lemma size_catmc (m1 m2: matrix):
  size (m1 / m2) = (rows m1 + rows m2, max (cols m1) (cols m2)).
proof. rewrite cols_offunm rows_offunm /#. qed.

(* Entry formula for downwards concatenation, valid for all indices i and j
   (no range hypothesis). *)
lemma get_catmc (m1 m2: matrix) i j:
  (m1 / m2).[i, j] = m1.[i, j] + m2.[i-rows m1, j].
proof.
case (mrange (m1 / m2) i j) => range.
- rewrite get_offunm /=; smt(size_catmc).
(* Out of range: all three entries are zero by getm0E. *)
- rewrite !getm0E /= 4:addr0; smt(size_catmc).
qed.

(* Transposing a downwards concatenation gives the sideways concatenation of
   the transposes. *)
lemma catmcT (m1 m2: matrix): trmx (m1 / m2) = (trmx m1 || trmx m2).
proof.
rewrite eq_matrixP /= cols_catmc rows_catmc rows_catmr cols_catmr /= => i j bnd.
by rewrite get_catmr get_catmc.
qed.

hint simplify catmcT.

(* Transposing a sideways concatenation gives the downwards concatenation of
   the transposes. *)
lemma catmrT (m1 m2: matrix): trmx (m1 || m2) = trmx m1 / trmx m2.
proof. apply trmx_inj. by rewrite /= 3!trmxK. qed.

hint simplify catmrT.

(* When both operands have the same number of columns, a row whose index is
   below rows m1 comes from the upper operand. *)
lemma row_catmcL m1 m2 i: 
  cols m1 = cols m2 => i < rows m1 => row (m1 / m2) i = row m1 i.
proof. move => cols_eq bound; by rewrite -col_trmx /= col_catmrL. qed.

(* When both operands have the same number of columns, a row whose index is
   at least rows m1 comes from the lower operand, shifted by rows m1. *)
lemma row_catmcR m1 m2 i: cols m1 = cols m2 => rows m1 <= i =>
  row (m1 / m2) i = row m2 (i - rows m1).
proof. move => cols_eq bound; by rewrite -col_trmx /= col_catmrR. qed.

(* A column of a downwards concatenation is the concatenation of the two
   columns. *)
lemma col_catmc m1 m2 i:
  col (m1 / m2) i = (col m1 i || col m2 i).
proof. by rewrite -row_trmx /= row_catmr. qed.

(* Downwards concatenation is associative; obtained from catmrA by
   transposition. *)
lemma catmcA (m1 m2 m3: matrix): 
  ((m1 / m2) / m3) = (m1 / (m2 / m3)).
proof. apply trmx_inj => /=. exact catmrA. qed.

(* Right multiplication distributes over downwards concatenation. *)
lemma catmcDl m1 m2 m3:
  (m1 / m2) * m3 = (m1 * m3) / (m2 * m3).
proof. apply trmx_inj => /=. exact catmrDr. qed.

(* The column matrix of a vector concatenation is the downwards concatenation
   of the two column matrices. Despite the lemma name, the matrix operator on
   the right-hand side is catmc. *)
lemma colmx_catmr (v1 v2: vector): colmx (v1 || v2) = (colmx v1 / colmx v2).
proof. apply trmx_inj => /=. exact rowmx_catmr. qed.

(* Addition of two downwards concatenations can be done blockwise, provided
   the two upper blocks have the same number of rows. Obtained from
   addm_catmr by transposition. *)
lemma addm_catmc (m1 m2 m3 m4: matrix): rows m1 = rows m3  =>
  (m1 / m2) + (m3 / m4) = ((m1 + m3) / (m2 + m4)).
proof.
(* rows_eq: the hypothesis rows m1 = rows m3, picked up by the SMT call. *)
move => rows_eq.
apply trmx_inj => /=.
rewrite addm_catmr /#.
qed.

(* Negation distributes over downwards concatenation; obtained from
   oppm_catmr by transposition. *)
lemma oppm_catmc (m1 m2: matrix): - (m1 / m2) = ((-m1) / -m2).
proof. apply trmx_inj => /=. by apply oppm_catmr. qed.

(* Taking a submatrix from row r1 (inclusive) to r2 (exclusive) and
   column c1 (inclusive) to c2 (exclusive).
   That is, m = subm m 0 (rows m) 0 (cols m) (see lemma subm_id). *)

(* Entry (i, j) of the result is m.[i+r1, j+c1]; the requested size is
   r2-r1 by c2-c1. The bounds are not checked against the size of m. *)
op subm (m: matrix) (r1 r2 c1 c2: int) = 
  offunm ((fun i j => m.[i+r1,j+c1]), r2-r1, c2-c1).

(* The row count is r2 - r1, clamped below at 0. *)
lemma rows_subm m r1 r2 c1 c2: 
  rows (subm m r1 r2 c1 c2) = max 0 (r2 - r1) by done.

(* The column count is c2 - c1, clamped below at 0. *)
lemma cols_subm m r1 r2 c1 c2: 
  cols (subm m r1 r2 c1 c2) = max 0 (c2 - c1) by done.

hint simplify rows_subm, cols_subm.

(* Both dimensions at once. *)
lemma size_subm m r1 r2 c1 c2: 
   size (subm m r1 r2 c1 c2) = (max 0 (r2 - r1), max 0 (c2 - c1)) by done.

(* Entry formula for subm. Unlike most get_ lemmas in this part of the file,
   it requires both indices to be in range of the submatrix. *)
lemma get_subm (m: matrix) (r1 r2 c1 c2 i j: int): 
  0 <= i < r2 - r1 => 0 <= j < c2 -c1 => 
  (subm m r1 r2 c1 c2).[i,j] = m.[i+r1,j+c1].
proof. by move => i_range j_range; rewrite /subm get_offunm /#. qed.

(* Taking the full range of rows and columns returns the matrix itself. *)
lemma subm_id m: subm m 0 (rows m) 0 (cols m) = m.
proof. by rewrite eq_matrixP /= => i j bound; rewrite get_subm /#. qed.

(* Transposing a submatrix is the same as taking the submatrix of the
   transpose with the row and column bounds swapped. *)
lemma submT m r1 r2 c1 c2: trmx (subm m r1 r2 c1 c2) = subm (trmx m) c1 c2 r1 r2.
proof.
apply eq_matrixP => i j bound /=. 
rewrite !get_subm; smt(size_tr size_subm).
qed.

hint simplify submT.

(* Cutting the top-left rows m1 by cols m1 block out of a sideways
   concatenation recovers the left operand. *)
lemma subm_catmrCl m1 m2: subm (m1 || m2) 0 (rows m1) 0 (cols m1) = m1.
proof.
rewrite eq_matrixP => i j bound.
rewrite get_subm; first 2 smt(size_subm).
rewrite get_catmr // (getm0E m2) /= 2:addr0; smt(cols_subm). 
qed.

(* Same statement for downwards concatenation; obtained by transposition. *)
lemma subm_catmcCl m1 m2: subm (m1 / m2) 0 (rows m1) 0 (cols m1) = m1.
proof. apply trmx_inj => /=. exact subm_catmrCl. qed.

(* Cutting rows 0 .. rows m2 and columns cols m1 .. cols m1 + cols m2 out of
   a sideways concatenation recovers the right operand. *)
lemma subm_catmrCr m1 m2:
  subm (m1 || m2) 0 (rows m2) (cols m1) (cols m1 + cols m2) = m2.
proof.
rewrite eq_matrixP.
split => [/# | i j bound].
rewrite get_subm; first 2 smt(size_subm).
rewrite get_catmr //= (getm0E m1) /= 2:add0r; smt(cols_subm).
qed.

(* Same statement for downwards concatenation: the lower operand is
   recovered; obtained by transposition. *)
lemma subm_catmcCr m1 m2:
  subm (m1 / m2) (rows m1) (rows m1 + rows m2) 0 (cols m2) = m2.
proof. apply trmx_inj => /=. exact subm_catmrCr. qed.

(* The row matrix of row r is the one-row submatrix starting at row r and
   spanning all columns. *)
lemma rowmx_row_eq_subm r m: rowmx (row m r) = subm m r (r+1) 0 (cols m).
proof.
rewrite eq_matrixP /=. 
split => [/# | i j bound]. 
have ->: i = 0 by smt().
rewrite get_rowmx get_subm 3:get_row /#.
qed.

(* The column matrix of column c is the one-column submatrix starting at
   column c and spanning all rows; obtained by transposition. *)
lemma colmx_col_eq_subm c m: colmx (col m c) = subm m 0 (rows m) c (c+1).
proof. apply trmx_inj => /=. rewrite -row_trmx. apply rowmx_row_eq_subm. qed.

(* Splitting m at column n and concatenating the two pieces sideways gives m
   back. The statement requires 0 <= n < cols m. *)
lemma catmr_subm m n: 0 <= n < cols m => 
  (subm m 0 (rows m) 0 n || subm m 0 (rows m) n (cols m)) = m.
proof.
move => n_bound; rewrite eq_matrixP /=.
split => [| i j bound]. 
- smt(rows_catmr cols_catmr size_subm).
rewrite get_catmr // cols_subm /=.
(* Case split on which piece column j falls into; the entry of the other
   piece is zero by getm0E. *)
case (j < n) => j_bound.
- rewrite get_subm /=; first 2 smt(size_catmr size_subm).
  rewrite (getm0E (subm _ _ _ _ _)).
  + smt(size_catmr size_subm).
  by rewrite addr0.
- rewrite getm0E; 1: smt(size_catmr size_subm).
  rewrite add0r get_subm; smt(size_catmr size_subm).
qed.

(* A matrix with l + 1 columns is its first l columns concatenated sideways
   with its last column. *)
lemma subm_colmx (m: matrix) l : 
  0 <= l => cols m = l + 1 => 
  (subm m 0 (rows m) 0 l || colmx (col m l)) = m.
proof. by move => l_ge0 c_m; rewrite colmx_col_eq_subm -c_m catmr_subm /#. qed.

(* Splitting m at row n and concatenating the two pieces downwards gives m
   back; obtained from catmr_subm by transposition. *)
lemma catmc_subm m n: 0 <= n < rows m => 
  subm m 0 n 0 (cols m) / subm m n (rows m) 0 (cols m) = m.
proof. move => n_bound; apply trmx_inj => /=. by apply catmr_subm. qed.

(* Updating one entry of a matrix: the result has the same requested size as
   m, holds p at position (r, c) and agrees with m elsewhere. *)
op updm (m: matrix) (r c: int) (p: t) = offunm 
  (fun (i j: int) => if i = r /\ j = c then p else m.[i,j], rows m, cols m).

(* Updating an entry does not change the row count. *)
lemma rows_updm m r c p: rows (updm m r c p) = rows m by done.

(* Updating an entry does not change the column count. *)
lemma cols_updm m r c p: cols (updm m r c p) = cols m by done.

(* Updating an entry does not change the size. *)
lemma size_updm m r c p: size (updm m r c p) = size m by done.

(* Entry formula for updm. Both the updated position (r, c) and the queried
   position (i, j) are required to be in range of m. *)
lemma get_updm (m: matrix) (r c i j: int) (p: t): mrange m r c => mrange m i j =>
  (updm m r c p).[i, j] = if (i, j) = (r, c) then p else m.[i, j].
proof.
rewrite /updm. 
(* Case analysis on whether the row index, then the column index, hits the
   updated position. *)
case (i = r) => /= [-> | neq_r bound1 bound2]. 
- case (j = c) => /= [-> bound _ | neq_col bound1 bound2]. 
  + by rewrite get_offunm //= get_offunm //= addrC subrK.
  + rewrite get_offunm /#.
- by rewrite get_offunm //= offunm0E. 
qed.

(* Multiplication with scalar for matrices and vectors *)
(* Scalar multiplication of a matrix, defined as left multiplication by
   diagc (rows m) s. NOTE(review): diagc is defined outside this chunk;
   presumably the square diagonal matrix with s on the diagonal - confirm. *)
op scalarm (s: t) (m: matrix) = diagc (rows m) s * m.

(* Infix notation for scalarm. *)
abbrev ( *** ) s m = scalarm s m.

(* Unfolding equation. *)
lemma scalarmE s m: s *** m = diagc (rows m) s * m by done.

(* Scaling does not change the row count. *)
lemma rows_scalarm (m: matrix) (s: t): rows (s *** m) = rows m.
proof. by rewrite rows_mulmx. qed.

(* Scaling does not change the column count. *)
lemma cols_scalarm (m: matrix) (s: t): cols (s *** m) = cols m.
proof. by rewrite cols_mulmx. qed.

(* Scaling does not change which index pairs are in range. *)
lemma scalarm_mrange m s i j: mrange (s *** m) i j = mrange m i j.
proof. by rewrite rows_mulmx // cols_mulmx. qed.

(* Scaling does not change the size. *)
lemma size_scalarm (m: matrix) (s: t): size (s *** m) = size m.
proof. by rewrite rows_scalarm cols_scalarm. qed.
    
(* Entry formula for scalar multiplication, valid for all indices. Note that
   the scalar appears on the right of the product in the statement. *)
lemma get_scalarm (m: matrix) (s: t) (i j: int):
    (s *** m).[i, j] = m.[i, j] * s.
proof. 
rewrite get_mulmx /dotp 2!size_offunv /=.
case (0 <= i < rows m) => bound.
(* Row index in range: isolate the term with index i (bigD1) and show that
   the remaining terms vanish (big1). *)
- rewrite (bigD1 _ _ i) 1:mem_range 1:/# 1:range_uniq.
  rewrite big1 /= 2:get_vectc 2:/# 2:addr0 2:mulrC 2:// => k.
  rewrite /predC1 eq_sym => -> /=.
  apply mul0r.
(* Row index out of range: every term of the sum is zero, and so is the
   entry of m on the right-hand side (getm0E). *)
- rewrite big1 => [k _ /=|].
  + by rewrite getv0E // mul0r.
  by rewrite getm0E 1:/# mul0r.
qed.

(* Scaling by zero gives the zero matrix of the same size. *)
lemma scalar0m (m: matrix): zeror *** m = zerom (rows m) (cols m).
proof.
rewrite eq_matrixP.
split => [| i j bound]; 1: rewrite size_scalarm /#.
by rewrite get_zerom get_scalarm mulr0.
qed.

(* Scaling by one leaves the matrix unchanged. *)
lemma scalar1m (m: matrix): oner *** m = m.
proof.
rewrite eq_matrixP.
split => [| i j bound]; 1: by rewrite size_scalarm.    
by rewrite get_scalarm mulr1.
qed.

(* Scaling by a negated scalar is the negation of the scaled matrix. *)
lemma scalarNm (m: matrix) (s: t): (- s) *** m = - (s *** m).
proof.
rewrite eq_matrixP.
split => [| i j bound]; 1: rewrite !size_scalarm /=; 1: smt(size_scalarm).
by rewrite /= !get_scalarm /= mulrN.
qed.

(* Two successive scalings combine into one scaling by the product. *)
lemma scalarmAs s1 s2 m: s1 *** (s2 *** m) = (s1 * s2) *** m.
proof.
apply eq_matrixP. 
split => [| i j _]; 1: by rewrite !size_scalarm. 
by rewrite 3!get_scalarm -mulrA [s2 * s1]mulrC.
qed.

(* Scaling distributes over matrix addition. *)
lemma scalarmD m1 m2 s: s *** (m1 + m2) = s *** m1 + s *** m2.
proof.
apply eq_matrixP.
rewrite rows_scalarm rows_addm cols_scalarm rows_addm. 
rewrite 2!cols_addm /= => i j bound.
by rewrite get_addm get_scalarm get_addm !get_scalarm mulrDl.
qed.

(* Scaling distributes over addition of scalars. *)
lemma scalarDm m (s1 s2: t): (s1 + s2) *** m = s1 *** m + s2 *** m.
proof.
apply eq_matrixP.
rewrite rows_scalarm cols_scalarm.
split => [| i j bound]; 1: rewrite size_addm !size_scalarm //.
by rewrite get_addm 3!get_scalarm // mulrDr.
qed.

(* dmatrix helper function: construct matrix from list of (column) vectors.
   Entry (i, j) is component i of the j-th vector of vs (nth witness vs j);
   the requested size is r by c. Nothing here relates r and c to the sizes of
   the vectors or to the length of the list. *)
op ofcols r c (vs : vector list) =
  offunm (fun (i j : int) => (nth witness vs j).[i], r, c).

(* Building a matrix from a concatenated list of columns is the sideways
   concatenation of the matrices built from each part, provided l1 and l2
   are the lengths of the two lists. *)
lemma ofcols_cat k l1 l2 (vs1 vs2 : vector list) :
  size vs1 = l1 => size vs2 = l2 => 
  ofcols k (l1 + l2) (vs1 ++ vs2) = (ofcols k l1 vs1 || ofcols k l2 vs2).
proof.
move => s_vs1 s_vs2.
apply eq_matrixP; rewrite /ofcols size_catmr //=.
rewrite -s_vs1 -s_vs2.
split => [|i j [i_bound j_bound]]; 1:smt(List.size_ge0).
rewrite get_catmr //= ler_maxr; 1: smt(List.size_ge0).
rewrite get_offunm /= ?nth_cat; 1: smt(List.size_ge0).
(* Case split on which list column j comes from; the entry of the other
   matrix is zero by offunm0E. *)
case (j < size vs1) => j_vs1. 
- rewrite get_offunm. smt(List.size_ge0). rewrite offunm0E. smt(List.size_ge0). 
  by rewrite addr0.
- rewrite offunm0E. smt(List.size_ge0). rewrite get_offunm. smt(List.size_ge0). 
  by rewrite /= s_vs1 add0r.
qed.

(* Distribution of matrices sampled using d element-wise.
   Defined by sampling a list of max 0 c vectors from dvector d r and
   assembling them as the columns of an r by c matrix with ofcols. *)
op dmatrix (d : R distr) (r c: int) =
  dmap (dlist (dvector d r) (max 0 c)) (ofcols r c).

(* For non-negative r and c, every matrix in the support has size (r, c). *)
lemma size_dmatrix d r c m: m \in dmatrix d r c => 
  0 <= r => 0 <= c => size m = (r, c). 
proof. rewrite supp_dmap /ofcols => -[l [H0 -> /=]] /#. qed.

(* A matrix of the wrong size has probability 0. *)
lemma dmatrix0E (d: R distr) (m: matrix) r c: 
  size m <> (r, c) => 0 <= r => 0 <= c => mu1 (dmatrix d r c) m = 0%r.
proof.
move => size_ineq r_ge0 c_ge0; apply supportPn.
(* If m were in the support, size_dmatrix would contradict size_ineq. *)
case (m \in dmatrix d r c) => [cont | //].
apply size_ineq.
apply size_dmatrix in cont => /#.
qed.

(* Point probability of m under dmatrix with the dimensions of m: the
   iterated BRM big operator (a product over the reals), over all rows i and
   columns j, of the probability of the entry m.[i, j] under d. *)
lemma dmatrix1E d m : mu1 (dmatrix d (rows m) (cols m)) m =
  BRM.bigi predT (fun i => 
    BRM.bigi predT (fun j => mu1 d m.[i, j]) 0 (cols m)) 0 (rows m).
proof.
(* g lists the columns of a matrix; it is the candidate inverse of ofcols
   used with in_dmap1E_can. The three bullets below discharge the goals that
   this rewrite leaves. *)
pose g (m: matrix) := mkseq (fun i => col m i) (cols m).
rewrite (in_dmap1E_can _ _ g) /ofcols.
(* Rebuilding m from its list of columns gives m back. *)
- rewrite /g eq_matrixP /= => i j bound.
  by rewrite get_offunm /= 1:/# nth_mkseq 1:/# get_col.
(* On the support of the list distribution, g undoes the matrix
   construction. *)
- rewrite /g => y y_in <- /=.
  rewrite supp_dlist 1:/# in y_in.
  apply (eq_from_nth witness) => [| i i_bound]; 1: rewrite size_mkseq /#. 
  rewrite nth_mkseq /= 1:/# eq_vectorP size_col rows_offunm.
  rewrite -(all_nthP _ _ witness) in y_in.
  elim y_in => size_y.
  move => cont; have i_bound': 0 <= i && i < size y by done.
  apply cont in i_bound'; clear cont; move: i_bound'.
  move => nth_y_size; apply size_dvector in nth_y_size.
  split => [/# | j j_bound].
  rewrite get_col get_offunm /= /#.  
(* Probability of the list of columns, rearranged into the row-major
   double big operator of the statement (BRM.exchange_big). *)
- rewrite dlist1E 1:/# size_mkseq /= (BRM.big_nth witness) predTofV.
  rewrite BRM.exchange_big /(\o) /= 2!BRM.big_seq /= size_mkseq max0cols. 
  apply BRM.eq_big => i // i_bound /=.
  rewrite mem_range in i_bound.
  have ->: rows m = size (nth witness (g m) i) by rewrite nth_mkseq //.
  rewrite dvector1E.
  congr; apply fun_ext => j.
  by rewrite nth_mkseq.
qed.

(* dmatrix preserves uniformity of the entry distribution. *)
lemma dmatrix_uni d r c: is_uniform d => is_uniform (dmatrix d r c).
proof.
(* Reduce to uniformity of the list distribution plus injectivity of the
   matrix construction on its support (dmap_uni_in_inj). *)
move=> uni_d; apply/dmap_uni_in_inj/dlist_uni/dvector_uni => //.
move=> xs ys xsin ysin /= eq_offunm.
apply/(eq_from_nth witness); 1: smt(supp_dlist). 
move => i i_bound; rewrite eq_vectorP. 
rewrite supp_dlist 1:/# -(all_nthP _ _ witness) in xsin.
rewrite supp_dlist 1:/# -(all_nthP _ _ witness) in ysin.
move: xsin ysin => [size_xs in_size_xs] [size_ys in_size_ys].
(* Both i-th vectors have size max 0 r since they lie in the support of
   dvector d r. *)
have ->: size (nth witness xs i) = max 0 r.
- by apply/(size_dvector d)/in_size_xs.
have -> /=: size (nth witness ys i) = max 0 r.
- by apply/(size_dvector d)/in_size_ys => /#.
(* Read entry (j, i) on both sides of the matrix equality. *)
move => j j_bound; apply (congr1 (fun m => m.[j, i])) in eq_offunm.
rewrite /= get_offunm /= 1:/# in eq_offunm.
rewrite get_offunm /= /# in eq_offunm.
qed.

(* dmatrix preserves losslessness of the entry distribution. *)
lemma dmatrix_ll d r c: is_lossless d => is_lossless (dmatrix d r c).
proof. by move=> ll_d; apply/dmap_ll/dlist_ll/dvector_ll. qed.

(* For a full-uniform d, the probability of m under dmatrix with the
   dimensions of m is the point probability of d raised to the number of
   entries rows m * cols m. *)
lemma mu1_dmatrix_fu (d: R distr) (m: matrix): is_funiform d =>
    mu1 (dmatrix d (rows m) (cols m)) m = (mu1 d witness)^((rows m)*(cols m)).
proof.
move => d_funi.
rewrite dmatrix1E.
(* Step 1: each inner big operator over the columns equals
   (mu1 d witness)^(cols m), by induction on cols m. *)
have ->: (fun (i : int) => (BRM.bigi predT 
           (fun (j : int) => mu1 d m.[i, j]) 0 (cols m))) = 
         (fun (_: int) => (mu1 d witness)^(cols m)).
- apply fun_ext => i.
  have ->: (fun (j : int) => mu1 d m.[i, j]) = fun (_: int) => mu1 d witness.
  + apply fun_ext => j. 
    exact d_funi.
  have: 0 <= cols m by exact cols_ge0. 
  move: (cols m).
  elim/ge0ind => [/# | _ | n bound IH _].
  + by rewrite range_geq //= BRM.big_nil RField.expr0.
  + by rewrite BRM.big_int_recr //= RField.exprS // RField.mulrC IH.
(* Step 2: the outer big operator over the rows, by induction on rows m. *)
- have: 0 <= rows m by exact rows_ge0. 
  move: (rows m).
  elim/ge0ind => [/# | _ | n bound IH _].
  + by rewrite range_geq //= BRM.big_nil RField.expr0.
  + rewrite BRM.big_int_recr // RField.exprM RField.exprS // RField.mulrC.
    by rewrite RField.exprMn 1:/# /= IH // RField.exprM.
qed.

(* Sampling a matrix with a single column is the same as sampling a vector
   and turning it into a column matrix. *)
lemma dmatrix1r d k : 0 <= k => 
  dmatrix d k 1 = dmap (dvector d k) colmx. 
proof.
move => k_ge0; rewrite /dmatrix /ofcols ler_maxr // dlist1 dmap_comp /(\o).
apply eq_dmap_in => v /size_dvector; rewrite ler_maxr //= => s_v.
rewrite /colmx s_v; apply/eq_matrixP => /= i j [i_bound j_bound]. 
by rewrite /ofcols !get_offunm /#.
qed.

(* Characterisation of the support of dmatrix for non-negative dimensions:
   m is in the support iff it has size (r, c) and every in-range entry is in
   the support of d. *)
lemma supp_dmatrix d m r c : 
  0 <= r => 0 <= c => 
  (m \in dmatrix d r c) <=> 
  size m = (r,c) /\ forall i j, mrange m i j => m.[i,j] \in d.
proof.
move => r_ge0 c_ge0; split => [m_supp|]; last first.
(* Right to left: by dmatrix1E the probability is a big operator whose
   factors are all positive. *)
- case => -[r_m c_m] m_d; rewrite /support -r_m -c_m dmatrix1E.
  apply prodr_gt0_seq => i i_row _ /=.
  by apply prodr_gt0_seq => j j_col _ /=; apply m_d; smt(mem_iota).
(* Left to right: the size comes from size_dmatrix. *)
have [r_m c_m] : size m = (r,c) by smt(size_dmatrix). 
split => [//|i j range_ij]; move: m_supp. 
rewrite -r_m -c_m /support dmatrix1E => gt0_big.
(* G i0 is the entry probability in row i0; F i0 is the big operator over
   that row. *)
pose G i0 := (fun (j0 : int) => mu1 d m.[i0, j0]).
pose F := fun i0 : int => BRM.bigi predT (G i0) 0 (cols m).
(* Positivity of the whole expression gives positivity of the factor for
   row i, and then of the factor for entry (i, j). *)
have /(_ i _ _) := gt0_prodr_seq predT F (range 0 (rows m)) _ gt0_big => //.
- by move => a _ _; apply prodr_ge0.
- smt(mem_iota).
move => gt0_F.
have /(_ j _ _) := gt0_prodr_seq predT (G i) (range 0 (cols m)) _ gt0_F => //.
smt(mem_iota). 
qed.

(* Sampling a k by (l1 + l2) matrix is the same as sampling a k by l1 matrix
   and a k by l2 matrix independently and concatenating them sideways. *)
lemma dmatrix_add_r d k l1 l2 : 0 <= k => 0 <= l1 => 0 <= l2 =>
  dmatrix d k (l1 + l2) = 
  dmap (dmatrix d k l1 `*` dmatrix d k l2) 
       (fun mv : matrix * matrix => mv.`1 || mv.`2).
proof.
move => k_ge0 l1_ge0 l2_ge0; rewrite {1}/dmatrix ler_maxr 1:/# dlist_add //.
(* catp concatenates a pair of column lists, catm concatenates a pair of
   matrices sideways, and stichp turns a pair of column lists into a pair of
   matrices. *)
pose catp (p : vector list * vector list) := p.`1 ++ p.`2.
pose catm (m : matrix * matrix) := m.`1 || m.`2.
pose stichp (p : vector list * vector list) := 
  (ofcols k l1 p.`1,ofcols k l2 p.`2).
(* On the support, building one matrix from the concatenated list equals
   building two matrices and concatenating them (ofcols_cat). *)
rewrite dmap_comp -/catp -/catm (eq_dmap_in _ _ (catm \o stichp)).
- case => vs1 vs2 /supp_dprod /= ?.
  have [? ?] : size vs1 = l1 /\ size vs2 = l2 by smt(supp_dlist).
  exact ofcols_cat.
rewrite -dmap_comp -(dmap_dprod _ _ (ofcols k l1) (ofcols k l2)). 
(* Reintroduce max 0 so that the goal matches the definition of dmatrix. *)
have {1}-> : l1 = max 0 l1 by smt().
have {1}-> : l2 = max 0 l2 by smt().
done.
qed.

(* Sampling a k by (l + 1) matrix is the same as sampling a k by l matrix and
   a vector of size k independently, and appending the vector as the last
   column. *)
lemma dmatrixSrr d k l : 0 <= k => 0 <= l => 
  dmatrix d k (l + 1) = 
  dmap (dmatrix d k l `*` dvector d k) 
       (fun mv : matrix * vector => mv.`1 || colmx mv.`2).
proof.
move => k_ge0 l_ge0; rewrite dmatrix_add_r // dmatrix1r //.
by rewrite dmap_dprodR dmap_comp.
qed.

(* Point probability of a k by (l + 1) matrix m, expressed as the
   probability, under the product distribution, of the pair made of the
   first l columns of m and its last column. *)
lemma dmatrixRSr1E d (m : matrix) k l : 
  0 <= k => 0 <= l => size m = (k,l+1) =>
  mu1 (dmatrix d k (l + 1)) m =
  mu1 (dmatrix d k l `*` dvector d k) (subm m 0 k 0 l, col m l).
proof.
move => k_ge0 l_ge0 [r_m c_m]; rewrite dmatrixSrr // dmapE /(\o) /pred1. 
(* On the support, the two events coincide: (m2 || colmx v) = m iff
   (m2, v) is the pair of the statement. *)
apply mu_eq_support => -[m2 v] /supp_dprod /=.
case => /size_dmatrix /(_ _ _) //= [r_m2 c_m2] /size_dvector s_v.
apply/eq_iff; split => [<-|[-> ->]]; last by rewrite -r_m subm_colmx.
rewrite -r_m2 -c_m2 subm_catmrCl /=. 
by rewrite col_catmrR //= /#.
qed.

(* For a full d and non-negative dimensions, the support of dmatrix is
   exactly the set of matrices of size (r, c). *)
lemma supp_dmatrix_full m d r c : 
  0 <= r => 0 <= c =>
  is_full d => m \in dmatrix d r c <=> size m = (r,c).
proof. smt(supp_dmatrix). qed.

(* For a full-uniform d, two vectors of the same size have the same
   probability under dvector d l. *)
lemma dvector_rnd_funi (d : R distr) (v1 v2 : vector) l :
  is_funiform d => size v1 = size v2 => 
  mu1 (dvector d l) v1 = mu1 (dvector d l) v2.
proof. 
(* Case split on whether l is the common size of the two vectors. *)
move=> d_funi s_v1_v2; case (l = size v1) => [->|?]. 
- by rewrite mu1_dvector_fu // s_v1_v2 mu1_dvector_fu.
smt(get_dvector0E emptyv_unique).
qed.

(* The sideways concatenation of two matrices from the supports of
   dmatrix d r c1 and dmatrix d r c2 is in the support of
   dmatrix d r (c1 + c2). *)
lemma supp_dmatrix_catmr d (m1 m2 : matrix) r c1 c2 : 
  0 <= r => 0 <= c1 => 0 <= c2 =>
  m1 \in dmatrix d r c1 => m2 \in dmatrix d r c2 => 
  (m1 || m2) \in dmatrix d r (c1 + c2).
proof.
move => r_ge0 c1_ge0 c2_ge0 m1_dm m2_dm; rewrite dmatrix_add_r //.
rewrite (dmap_dprod_comp _ _ (fun x => x) (fun x => x) catmr).
rewrite /support !dmap_id dmapE /(\o) /pred1. 
(* Bound the probability from below by that of the single pair (m1, m2),
   which is strictly positive. *)
apply (StdOrder.RealOrder.ltr_le_trans 
      (mu1 (dmatrix d r c1 `*` dmatrix d r c2) (m1,m2))).
- rewrite witness_support.
  exists (m1, m2).
  by rewrite supp_dprod /= m1_dm m2_dm.
- by apply mu_le => /#.
qed.

end Matrices.

export Matrices.
(* back to top -- web-page navigation residue, not part of the theory *)