(*
https://github.com/project-everest/hacl-star
Raw File
Tip revision: 979490eae06f54ec3a4768cdefdef741b6aa97b4 authored by Ann Weine on 23 September 2021, 10:43:48 UTC
Radix/Ml separation
Tip revision: 979490e
Hacl.Impl.Matrix.fst
*)
module Hacl.Impl.Matrix

open FStar.HyperStack.ST
open FStar.Mul

open LowStar.BufferOps
open LowStar.Buffer

open Lib.IntTypes
open Lib.Buffer
open Lib.ByteBuffer

module HS  = FStar.HyperStack
module ST  = FStar.HyperStack.ST
module LSeq = Lib.Sequence
module BSeq = Lib.ByteSequence
module Loops = Lib.LoopCombinators

module M   = Spec.Matrix
module Lemmas = Spec.Frodo.Lemmas

#set-options "--z3rlimit 50 --fuel 0 --ifuel 0"

/// Type of a single matrix entry: an alias for [uint16].
unfold
let elem = uint16

/// [lbytes len] abbreviates [lbuffer uint8 len], a buffer of [len] bytes.
unfold
let lbytes len = lbuffer uint8 len

/// An [n1]-by-[n2] matrix is represented as one flat buffer of [n1 *! n2]
/// entries of type [elem]. The refinement on [n2] keeps the product
/// [v n1 * v n2] within [max_size_t].
unfold
let matrix_t (n1:size_t) (n2:size_t{v n1 * v n2 <= max_size_t}) =
  lbuffer elem (n1 *! n2)

/// Ghost view of the buffer [m] in memory [h], typed as a [Spec.Matrix]
/// matrix of dimensions [v n1] by [v n2]. It is exactly [as_seq h m].
unfold
let as_matrix #n1 #n2 h (m:matrix_t n1 n2) : GTot (M.matrix (v n1) (v n2)) =
  as_seq h m


/// Allocates an [n1]-by-[n2] matrix whose entries are all [u16 0].
///
/// The product of the dimensions must be strictly positive and at most
/// [max_size_t]. The post-condition states that the result is
/// [stack_allocated] and equal to [M.create (v n1) (v n2)].
inline_for_extraction noextract
val matrix_create:
    n1:size_t
  -> n2:size_t{0 < v n1 * v n2 /\ v n1 * v n2 <= max_size_t}
  -> StackInline (matrix_t n1 n2)
    (requires fun h0 -> True)
    (ensures  fun h0 a h1 ->
      stack_allocated a h0 h1 (as_matrix h1 a) /\
      as_matrix h1 a == M.create (v n1) (v n2))

let matrix_create n1 n2 =
  (* The length is computed on the mathematical values and passed through
     [normalize_term] before being turned back into a [size_t]; presumably
     this lets it reduce to a literal at extraction - TODO confirm. *)
  [@inline_let]
  let len = size (normalize_term (v n1 * v n2)) in
  create len (u16 0)


/// Reads entry [(i, j)] of the matrix [a].
///
/// The entry is fetched from flat offset [i * n2 + j] of the underlying
/// buffer. Requires [a] to be live, modifies nothing, and returns
/// [M.mget (as_matrix h0 a) (v i) (v j)].
inline_for_extraction noextract
val mget:
    #n1:size_t
  -> #n2:size_t{v n1 * v n2 <= max_size_t}
  -> a:matrix_t n1 n2
  -> i:size_t{v i < v n1}
  -> j:size_t{v j < v n2}
  -> Stack elem
    (requires fun h0 -> live h0 a)
    (ensures  fun h0 x h1 -> modifies loc_none h0 h1 /\
      x == M.mget (as_matrix h0 a) (v i) (v j))

let mget #n1 #n2 a i j =
  (* Lemma call made before the offset is computed; presumably it provides
     the bound that makes the offset arithmetic and the access well-typed. *)
  M.index_lt (v n1) (v n2) (v i) (v j);
  [@inline_let]
  let offset = i *! n2 +! j in
  a.(offset)


/// Writes [x] into entry [(i, j)] of the matrix [a].
///
/// The entry is stored at flat offset [i * n2 + j] of the underlying buffer.
/// Requires [a] to be live; only [a] is modified, and afterwards it equals
/// [M.mset (as_matrix h0 a) (v i) (v j) x].
inline_for_extraction noextract
val mset:
    #n1:size_t
  -> #n2:size_t{v n1 * v n2 <= max_size_t}
  -> a:matrix_t n1 n2
  -> i:size_t{v i < v n1}
  -> j:size_t{v j < v n2}
  -> x:elem
  -> Stack unit
    (requires fun h0 -> live h0 a)
    (ensures  fun h0 _ h1 -> modifies1 a h0 h1 /\
      as_matrix h1 a == M.mset (as_matrix h0 a) (v i) (v j) x)

let mset #n1 #n2 a i j x =
  (* Lemma call made before the offset is computed; presumably it provides
     the bound that makes the offset arithmetic and the update well-typed. *)
  M.index_lt (v n1) (v n2) (v i) (v j);
  [@inline_let]
  let offset = i *! n2 +! j in
  a.(offset) <- x


/// Operator form of [mget]: [m.[i,j]] reads entry [(i, j)] of [m].
/// The index is passed as a pair whose components are refined to be in range.
noextract unfold
val op_String_Access (#n1:size_t) (#n2:size_t{v n1 * v n2 <= max_size_t}) (m:matrix_t n1 n2) (ij:(size_t & size_t){let i, j = ij in v i < v n1 /\ v j < v n2})
  : Stack elem
    (requires fun h0 -> live h0 m)
    (ensures  fun h0 x h1 -> let i, j = ij in modifies loc_none h0 h1 /\ x == M.mget (as_matrix h0 m) (v i) (v j))
let op_String_Access #n1 #n2 m ij =
  let i, j = ij in
  mget m i j


/// Operator form of [mset]: [m.[i,j] <- x] writes [x] into entry [(i, j)] of [m].
/// The index is passed as a pair whose components are refined to be in range.
noextract unfold
val op_String_Assignment (#n1:size_t) (#n2:size_t{v n1 * v n2 <= max_size_t}) (m:matrix_t n1 n2) (ij:(size_t & size_t){let i, j = ij in v i < v n1 /\ v j < v n2}) (x:elem)
  : Stack unit
    (requires fun h0 -> live h0 m)
    (ensures  fun h0 _ h1 -> let i, j = ij in modifies1 m h0 h1 /\ live h1 m /\ as_matrix h1 m == M.mset (as_matrix h0 m) (v i) (v j) x)
let op_String_Assignment #n1 #n2 m ij x =
  let i, j = ij in
  mset m i j x


/// Specification-level accessor: entry [(i, j)] of the matrix held by [m] in
/// memory [h], obtained by applying [M.mget] to [as_matrix h m].
unfold
let get #n1 #n2 h (m:matrix_t n1 n2) i j = M.mget (as_matrix h m) i j

/// Invariant of the inner (column) loop of [map], for row [i] after [j]
/// columns have been processed.
///
/// [h0] is the memory before the whole [map], [h1] the memory at the start of
/// row [i], and [h2] the current memory. The invariant states that:
///  - [a] and [c] are live in [h2], and only [c] changed between [h1] and [h2];
///  - rows before [i] of [c] are as they were in [h1];
///  - in row [i], columns before [j] hold [f] applied to the entry of [a] in [h0];
///  - in row [i], columns from [j] on, and all later rows, are as in [h0].
///
/// The destination matrix is named [c] in the signature, matching the
/// definition below and the callers.
private unfold
val map_inner_inv:
    #n1:size_t
  -> #n2:size_t{v n1 * v n2 <= max_size_t}
  -> h0:HS.mem
  -> h1:HS.mem
  -> h2:HS.mem
  -> f:(elem -> elem)
  -> a:matrix_t n1 n2
  -> c:matrix_t n1 n2
  -> i:size_t{v i < v n1}
  -> j:size_nat
  -> Type0

let map_inner_inv #n1 #n2 h0 h1 h2 f a c i j =
  live h2 a /\ live h2 c /\
  modifies1 c h1 h2 /\
  j <= v n2 /\
  (forall (i0:nat{i0 < v i}) (j:nat{j < v n2}). get h2 c i0 j == get h1 c i0 j) /\
  (forall (j0:nat{j0 < j}). get h2 c (v i) j0 == f (get h0 a (v i) j0)) /\
  (forall (j0:nat{j <= j0 /\ j0 < v n2}). get h2 c (v i) j0 == get h0 c (v i) j0) /\
  (forall (i0:nat{v i < i0 /\ i0 < v n1}) (j:nat{j < v n2}). get h2 c i0 j == get h0 c i0 j)


/// One step of the inner loop of [map]: processes entry [(i, j)].
///
/// Takes [map_inner_inv] from column [v j] to column [v j + 1].
/// [c] is required to be the same buffer as [a] (in-place update).
inline_for_extraction noextract private
val map_inner:
    #n1:size_t
  -> #n2:size_t{v n1 * v n2 <= max_size_t}
  -> h0:HS.mem
  -> h1:HS.mem
  -> f:(elem -> elem)
  -> a:matrix_t n1 n2
  -> c:matrix_t n1 n2{a == c}
  -> i:size_t{v i < v n1}
  -> j:size_t{v j < v n2}
  -> Stack unit
    (requires fun h2 -> map_inner_inv h0 h1 h2 f a c i (v j))
    (ensures  fun _ _ h2 -> map_inner_inv h0 h1 h2 f a c i (v j + 1))

let map_inner #n1 #n2 h0 h1 f a c i j =
  (* Read the source entry first, then store its image under [f]. *)
  let aij = a.[i,j] in
  c.[i,j] <- f aij


/// In-place entry-wise map: replaces every entry of the matrix by its image
/// under [f].
///
/// [a] and [c] must be the same live buffer ([a == c]). Afterwards only [c]
/// has been modified and it equals [M.map f] of the initial contents of [a].
inline_for_extraction
val map:
    #n1:size_t
  -> #n2:size_t{v n1 * v n2 <= max_size_t}
  -> f:(uint16 -> uint16)
  -> a:matrix_t n1 n2
  -> c:matrix_t n1 n2
  -> Stack unit
    (requires fun h0 ->
      live h0 a /\ live h0 c /\ a == c)
    (ensures  fun h0 _ h1 -> modifies1 c h0 h1 /\
      as_matrix h1 c == M.map #(v n1) #(v n2) f (as_matrix h0 a))

let map #n1 #n2 f a c =
  let h0 = ST.get () in
  (* Outer loop over rows. Invariant: rows before [i] already hold the mapped
     values, rows from [i] on are unchanged since [h0]. *)
  Lib.Loops.for (size 0) n1
    (fun h1 i -> live h1 a /\ live h1 c /\
      modifies1 c h0 h1 /\ i <= v n1 /\
      (forall (i0:nat{i0 < i}) (j:nat{j < v n2}).
        get h1 c i0 j == f (get h0 a i0 j)) /\
      (forall (i0:nat{i <= i0 /\ i0 < v n1}) (j:nat{j < v n2}).
        get h1 c i0 j == get h0 c i0 j) )
    (fun i ->
      (* [h1] is the memory at the start of row [i]; the inner loop walks the
         columns of that row under [map_inner_inv]. *)
      let h1 = ST.get() in
      Lib.Loops.for (size 0) n2
        (fun h2 j -> map_inner_inv h0 h1 h2 f a c i j)
        (fun j -> map_inner h0 h1 f a c i j)
    );
    (* Turn the entry-wise facts of the final invariant into an equality of
       whole matrices. *)
    let h2 = ST.get () in
    M.extensionality (as_matrix h2 c) (M.map f (as_matrix h0 a))


/// Keeps the low [logq] bits of [a] by masking it with [(1 << logq) - 1].
///
/// [logq] must be strictly between 0 and 16. The result is specified to equal
/// [M.mod_pow2_felem (v logq) a].
inline_for_extraction noextract
val mod_pow2_felem:
    logq:size_t{0 < v logq /\ v logq < 16}
  -> a:uint16
  -> Pure uint16
    (requires True)
    (ensures  fun r -> r == M.mod_pow2_felem (v logq) a)

let mod_pow2_felem logq a =
  [@inline_let]
  let mask = (u16 1 <<. logq) -. u16 1 in
  a &. mask


/// Applies the [logq]-bit reduction to every entry of [a], in place.
///
/// [logq] ranges over 1..16. Only [a] is modified, and afterwards it equals
/// [M.mod_pow2 (v logq)] of its initial contents.
val mod_pow2:
    #n1:size_t
  -> #n2:size_t{v n1 * v n2 <= max_size_t}
  -> logq:size_t{0 < v logq /\ v logq <= 16}
  -> a:matrix_t n1 n2
  -> Stack unit
    (requires fun h -> live h a)
    (ensures  fun h0 _ h1 -> modifies1 a h0 h1 /\
      as_matrix h1 a == M.mod_pow2 (v logq) (as_matrix h0 a))

[@"c_inline"]
let mod_pow2 #n1 #n2 logq a =
  (* When [logq] is 16 the branch is skipped and [a] is left untouched;
     [mod_pow2_felem] is only defined for [logq] below 16. *)
  if logq <. 16ul then begin
    let h0 = ST.get () in
    map (mod_pow2_felem logq) a a;
    (* Relate the map over the implementation-level function to the map over
       the specification-level one. *)
    M.extensionality
      (M.map #(v n1) #(v n2) (mod_pow2_felem logq) (as_matrix h0 a))
      (M.map #(v n1) #(v n2) (M.mod_pow2_felem (v logq)) (as_matrix h0 a)) end


/// Invariant of the inner (column) loop of [map2], for row [i] after [j]
/// columns have been processed.
///
/// [h0] is the memory before the whole [map2], [h1] the memory at the start
/// of row [i], and [h2] the current memory. The invariant states that:
///  - [a], [b] and [c] are live in [h2], and only [c] changed between [h1] and [h2];
///  - rows before [i] of [c] are as they were in [h1];
///  - in row [i], columns before [j] hold [f] applied to the entry of [a] in
///    [h0] and the entry of [b] in [h2];
///  - in row [i], columns from [j] on, and all later rows, are as in [h0].
private unfold
val map2_inner_inv:
    #n1:size_t
  -> #n2:size_t{v n1 * v n2 <= max_size_t}
  -> h0:HS.mem
  -> h1:HS.mem
  -> h2:HS.mem
  -> f:(elem -> elem -> elem)
  -> a:matrix_t n1 n2
  -> b:matrix_t n1 n2
  -> c:matrix_t n1 n2
  -> i:size_t{v i < v n1}
  -> j:size_nat
  -> Type0

let map2_inner_inv #n1 #n2 h0 h1 h2 f a b c i j =
  live h2 a /\ live h2 b /\ live h2 c /\
  modifies1 c h1 h2 /\
  j <= v n2 /\
  (forall (i0:nat{i0 < v i}) (j:nat{j < v n2}). get h2 c i0 j == get h1 c i0 j) /\
  (forall (j0:nat{j0 < j}). get h2 c (v i) j0 == f (get h0 a (v i) j0) (get h2 b (v i) j0)) /\
  (forall (j0:nat{j <= j0 /\ j0 < v n2}). get h2 c (v i) j0 == get h0 c (v i) j0) /\
  (forall (i0:nat{v i < i0 /\ i0 < v n1}) (j:nat{j < v n2}). get h2 c i0 j == get h0 c i0 j)


/// One step of the inner loop of [map2]: processes entry [(i, j)].
///
/// Takes [map2_inner_inv] from column [v j] to column [v j + 1].
/// [c] is required to be the same buffer as [a] and to be disjoint from [b].
inline_for_extraction noextract private
val map2_inner:
    #n1:size_t
  -> #n2:size_t{v n1 * v n2 <= max_size_t}
  -> h0:HS.mem
  -> h1:HS.mem
  -> f:(elem -> elem -> elem)
  -> a:matrix_t n1 n2
  -> b:matrix_t n1 n2
  -> c:matrix_t n1 n2{a == c /\ disjoint b c}
  -> i:size_t{v i < v n1}
  -> j:size_t{v j < v n2}
  -> Stack unit
    (requires fun h2 -> map2_inner_inv h0 h1 h2 f a b c i (v j))
    (ensures  fun _ _ h2 -> map2_inner_inv h0 h1 h2 f a b c i (v j + 1))

let map2_inner #n1 #n2 h0 h1 f a b c i j =
  (* Read both operands (left one first), then store the combined value. *)
  let aij = a.[i,j] in
  let bij = b.[i,j] in
  c.[i,j] <- f aij bij


/// In-place [map2]: after the call, [a] holds [map2 f a b].
///
/// [c] must be the same buffer as [a] and must be disjoint from [b]; all three
/// must be live. Only [c] is modified, and afterwards it equals
/// [M.map2 f] of the initial contents of [a] and [b].
///
/// A non-in-place variant can be obtained by replacing the pre-condition
/// [a == c] with [disjoint a c], or the two variants can be merged by
/// requiring [(a == c \/ disjoint a c)] instead of [a == c].
/// See commit 91916b8372fa3522061eff5a42d0ebd1d19a8a49
inline_for_extraction
val map2:
    #n1:size_t
  -> #n2:size_t{v n1 * v n2 <= max_size_t}
  -> f:(uint16 -> uint16 -> uint16)
  -> a:matrix_t n1 n2
  -> b:matrix_t n1 n2
  -> c:matrix_t n1 n2
  -> Stack unit
    (requires fun h0 ->
      live h0 a /\ live h0 b /\ live h0 c /\
      a == c /\ disjoint b c)
    (ensures  fun h0 _ h1 -> modifies1 c h0 h1 /\
      as_matrix h1 c == M.map2 #(v n1) #(v n2) f (as_matrix h0 a) (as_matrix h0 b))

let map2 #n1 #n2 f a b c =
  let h0 = ST.get () in
  (* Outer loop over rows. Invariant: rows before [i] already hold the
     combined values, rows from [i] on are unchanged since [h0]. *)
  Lib.Loops.for (size 0) n1
    (fun h1 i -> live h1 a /\ live h1 b /\ live h1 c /\
      modifies1 c h0 h1 /\ i <= v n1 /\
      (forall (i0:nat{i0 < i}) (j:nat{j < v n2}).
        get h1 c i0 j == f (get h0 a i0 j) (get h0 b i0 j)) /\
      (forall (i0:nat{i <= i0 /\ i0 < v n1}) (j:nat{j < v n2}).
        get h1 c i0 j == get h0 c i0 j) )
    (fun i ->
      (* [h1] is the memory at the start of row [i]; the inner loop walks the
         columns of that row under [map2_inner_inv]. *)
      let h1 = ST.get() in
      Lib.Loops.for (size 0) n2
        (fun h2 j -> map2_inner_inv h0 h1 h2 f a b c i j)
        (fun j -> map2_inner h0 h1 f a b c i j)
    );
    (* Turn the entry-wise facts of the final invariant into an equality of
       whole matrices. *)
    let h2 = ST.get() in
    M.extensionality (as_matrix h2 c) (M.map2 f (as_matrix h0 a) (as_matrix h0 b))


/// In-place matrix addition: [a] receives [a + b].
///
/// [a] and [b] must be live and disjoint. Only [a] is modified, and afterwards
/// it equals [M.add] of the initial contents of [a] and [b].
val matrix_add:
    #n1:size_t
  -> #n2:size_t{v n1 * v n2 <= max_size_t}
  -> a:matrix_t n1 n2
  -> b:matrix_t n1 n2
  -> Stack unit
    (requires fun h -> live h a /\ live h b /\ disjoint a b)
    (ensures  fun h0 r h1 -> modifies1 a h0 h1 /\
      as_matrix h1 a == M.add (as_matrix h0 a) (as_matrix h0 b))

[@"c_inline"]
let matrix_add #n1 #n2 a b =
  (* In-place [map2] with [add_mod]; [a] is both first operand and destination. *)
  map2 add_mod a b a


/// In-place matrix subtraction writing into the second operand: [b] receives
/// [a - b].
///
/// [a] and [b] must be live and disjoint. Only [b] is modified, and afterwards
/// it equals [M.sub] of the initial contents of [a] and [b].
val matrix_sub:
    #n1:size_t
  -> #n2:size_t{v n1 * v n2 <= max_size_t}
  -> a:matrix_t n1 n2
  -> b:matrix_t n1 n2
  -> Stack unit
    (requires fun h -> live h a /\ live h b /\ disjoint a b)
    (ensures  fun h0 r h1 -> modifies1 b h0 h1 /\
      as_matrix h1 b == M.sub (as_matrix h0 a) (as_matrix h0 b))

[@"c_inline"]
let matrix_sub #n1 #n2 a b =
  (* Use the in-place variant above by flipping the arguments of [sub_mod]: *)
  (* [map2] overwrites its first operand, so [b] is passed first and the *)
  (* flipped function restores the order [a - b]. *)
  (* Requires applying extensionality. *)
  let h0 = ST.get() in
  [@ inline_let ]
  let sub_mod_flipped x y = sub_mod y x in
  map2 sub_mod_flipped b a b;
  let h1 = ST.get() in
  M.extensionality (as_matrix h1 b) (M.sub (as_matrix h0 a) (as_matrix h0 b))


/// Computes one entry of a matrix product: the sum over [j] of
/// [a.[i,j] *. b.[j,k]].
///
/// [a] and [b] must be live. From the caller's point of view nothing is
/// modified, and the result equals
/// [M.mul_inner (as_matrix h0 a) (as_matrix h0 b) (v i) (v k)].
/// This definition is checked with fuel 1 (see the surrounding options).
#push-options "--fuel 1"
inline_for_extraction noextract private
val mul_inner:
    #n1:size_t
  -> #n2:size_t{v n1 * v n2 <= max_size_t}
  -> #n3:size_t{v n2 * v n3 <= max_size_t /\ v n1 * v n3 <= max_size_t}
  -> a:matrix_t n1 n2
  -> b:matrix_t n2 n3
  -> i:size_t{v i < v n1}
  -> k:size_t{v k < v n3}
  -> Stack uint16
    (requires fun h -> live h a /\ live h b)
    (ensures  fun h0 r h1 -> modifies loc_none h0 h1 /\
      r == M.mul_inner (as_matrix h0 a) (as_matrix h0 b) (v i) (v k))

let mul_inner #n1 #n2 #n3 a b i k =
  push_frame();
  let h0 = ST.get() in
  (* Specification-level summand: the product of the two entries as read in [h0]. *)
  [@ inline_let ]
  let f l = get h0 a (v i) l *. get h0 b l (v k) in
  (* One-element buffer used as the running sum, allocated in the new frame. *)
  let res = create #uint16 (size 1) (u16 0) in

  let h1 = ST.get() in
  (* Invariant: after [j] iterations the accumulator holds [M.sum_ f j]. *)
  Lib.Loops.for (size 0) n2
    (fun h2 j -> live h1 res /\ live h2 res /\
      modifies1 res h1 h2 /\
      bget h2 res 0 == M.sum_ #(v n2) f j)
    (fun j ->
      let aij = a.[i,j] in
      let bjk = b.[j,k] in
      let res0 = res.(size 0) in
      res.(size 0) <- res0 +. aij *. bjk
    );
  (* Read the result out before the frame holding the accumulator is popped. *)
  let res = res.(size 0) in
  M.sum_extensionality (v n2) f (fun l -> get h0 a (v i) l *. get h0 b l (v k)) (v n2);
  assert (res == M.mul_inner (as_matrix h0 a) (as_matrix h0 b) (v i) (v k));
  pop_frame();
  res
#pop-options


/// Invariant of the inner (column) loop of matrix multiplication, for row [i]
/// of the result after [k] columns have been computed.
///
/// [h0] is the memory before the whole multiplication, [h1] the memory at the
/// start of row [i], and [h2] the current memory. [f] gives the expected value
/// of each entry of row [i]. The invariant states that:
///  - [a], [b] and [c] are live in [h2], and only [c] changed between [h1] and [h2];
///  - rows before [i] of [c] are as they were in [h1];
///  - in row [i], columns before [k] hold [f] of their column index;
///  - in row [i], columns from [k] on, and all later rows, are as in [h0];
///  - the contents of [a] and [b] in [h2] are the same as in [h0].
private unfold
val mul_inner_inv:
    #n1:size_t
  -> #n2:size_t{v n1 * v n2 <= max_size_t}
  -> #n3:size_t{v n2 * v n3 <= max_size_t /\ v n1 * v n3 <= max_size_t}
  -> h0:HS.mem
  -> h1:HS.mem
  -> h2:HS.mem
  -> a:matrix_t n1 n2
  -> b:matrix_t n2 n3
  -> c:matrix_t n1 n3
  -> f:(k:nat{k < v n3} -> GTot uint16)
  -> i:size_t{v i < v n1}
  -> k:size_nat
  -> Type0

let mul_inner_inv #n1 #n2 #n3 h0 h1 h2 a b c f i k =
  live h2 a /\ live h2 b /\ live h2 c /\
  modifies1 c h1 h2 /\
  k <= v n3 /\
  (forall (i1:nat{i1 < v i}) (k:nat{k < v n3}). get h2 c i1 k == get h1 c i1 k) /\
  (forall (k1:nat{k1 < k}). get h2 c (v i) k1 == f k1) /\
  (forall (k1:nat{k <= k1 /\ k1 < v n3}). get h2 c (v i) k1 == get h0 c (v i) k1) /\
  (forall (i1:nat{v i < i1 /\ i1 < v n1}) (k:nat{k < v n3}). get h2 c i1 k == get h0 c i1 k) /\
  as_matrix h0 a == as_matrix h2 a /\
  as_matrix h0 b == as_matrix h2 b


/// One step of the inner loop of [matrix_mul]: computes entry [(i, k)] of the
/// product with [mul_inner] and stores it in [c].
///
/// Takes [mul_inner_inv] from column [v k] to column [v k + 1]. [c] must be
/// disjoint from [a] and [b]. [f] is the ghost function giving, for each
/// column, the expected sum of products computed on the contents in [h0].
inline_for_extraction noextract private
val mul_inner1:
    #n1:size_t
  -> #n2:size_t{v n1 * v n2 <= max_size_t}
  -> #n3:size_t{v n2 * v n3 <= max_size_t /\ v n1 * v n3 <= max_size_t}
  -> h0:HS.mem
  -> h1:HS.mem
  -> a:matrix_t n1 n2
  -> b:matrix_t n2 n3
  -> c:matrix_t n1 n3{disjoint a c /\ disjoint b c}
  -> i:size_t{v i < v n1}
  -> k:size_t{v k < v n3}
  -> f:(k:nat{k < v n3}
       -> GTot (res:uint16{res == M.sum #(v n2) (fun l -> get h0 a (v i) l *. get h0 b l k)}))
  -> Stack unit
    (requires fun h2 -> mul_inner_inv h0 h1 h2 a b c f i (v k))
    (ensures  fun _ _ h2 -> mul_inner_inv h0 h1 h2 a b c f i (v k + 1))

let mul_inner1 #n1 #n2 #n3 h0 h1 a b c i k f =
  (* Proof hint: state [M.mul_inner] as the explicit sum used in the type of [f]. *)
  assert (M.mul_inner (as_matrix h0 a) (as_matrix h0 b) (v i) (v k) ==
          M.sum #(v n2) (fun l -> get h0 a (v i) l *. get h0 b l (v k)));
  c.[i,k] <- mul_inner a b i k;
  let h2 = ST.get () in
  (* Proof hint: the entry just written is the value expected by [f]. *)
  assert (get h2 c (v i) (v k) == f (v k))


/// Extends a bounded universal fact by one index: if [q] holds for every
/// index below [b] satisfying [p], and also for [b] itself, then it holds for
/// every index up to and including [b] satisfying [p].
/// The proof is discharged automatically.
private
val onemore: p:(nat -> Type0) -> q:(i:nat{p i} -> Type0) -> b:nat{p b} -> Lemma
  (requires (forall (i:nat{p i /\ i < b}). q i) /\ q b)
  (ensures  forall (i:nat{p i /\ i <= b}). q i)
let onemore p q b = ()


/// Row-wise counterpart of [onemore], stated directly on a specification
/// matrix: if every entry of rows before [i] of [c] agrees with [f], and so
/// does every entry of row [i], then every entry of rows up to and including
/// [i] agrees with [f]. The proof is discharged automatically.
val onemore1:
    #n1:size_nat
  -> #n3:size_nat{n1 * n3 <= max_size_t}
  -> c:M.matrix n1 n3
  -> f:(i:nat{i < n1} -> k:nat{k < n3} -> GTot uint16)
  -> i:size_nat{i < n1} -> Lemma
  (requires ((forall (i1:nat{i1 < i}) (k:nat{k < n3}). M.mget c i1 k == f i1 k) /\ (forall (k:nat{k < n3}). M.mget c i k == f i k)))
  (ensures (forall (i1:nat{i1 <= i}) (k:nat{k < n3}). M.mget c i1 k == f i1 k))

let onemore1 #n1 #n3 c f i = ()


/// Matrix multiplication: [c] receives the product of [a] ([n1] by [n2]) and
/// [b] ([n2] by [n3]).
///
/// All three buffers must be live and [c] must be disjoint from [a] and [b].
/// Only [c] is modified, and afterwards it equals [M.mul] of the initial
/// contents of [a] and [b].
val matrix_mul:
    #n1:size_t
  -> #n2:size_t{v n1 * v n2 <= max_size_t}
  -> #n3:size_t{v n2 * v n3 <= max_size_t /\ v n1 * v n3 <= max_size_t}
  -> a:matrix_t n1 n2
  -> b:matrix_t n2 n3
  -> c:matrix_t n1 n3
  -> Stack unit
    (requires fun h ->
      live h a /\ live h b /\ live h c /\
      disjoint a c /\ disjoint b c)
    (ensures  fun h0 _ h1 -> modifies1 c h0 h1 /\
      as_matrix h1 c == M.mul (as_matrix h0 a) (as_matrix h0 b))

[@"c_inline"]
let matrix_mul #n1 #n2 #n3 a b c =
  let h0 = ST.get () in
  (* Ghost function giving the expected value of entry (i, k) of the result,
     as a sum of products over the contents of [a] and [b] in [h0]. *)
  let f (i:nat{i < v n1}) (k:nat{k < v n3}) :
    GTot (res:uint16{res == M.sum #(v n2) (fun l -> get h0 a i l *. get h0 b l k)})
  = M.sum #(v n2) (fun l -> get h0 a i l *. get h0 b l k)
  in
  (* Outer loop over rows of [c]. Invariant: rows before [i] agree with [f],
     rows from [i] on are unchanged since [h0]. *)
  Lib.Loops.for (size 0) n1
    (fun h1 i ->
      live h1 a /\ live h1 b /\ live h1 c /\
      modifies1 c h0 h1 /\ i <= v n1 /\
      (forall (i1:nat{i1 < i}) (k:nat{k < v n3}). get h1 c i1 k == f i1 k) /\
      (forall (i1:nat{i <= i1 /\ i1 < v n1}) (k:nat{k < v n3}). get h1 c i1 k == get h0 c i1 k))
    (fun i ->
      let h1 = ST.get() in
      (* Inner loop over the columns of row [i]. *)
      Lib.Loops.for (size 0) n3
        (fun h2 k -> mul_inner_inv h0 h1 h2 a b c (f (v i)) i k)
        (fun k -> mul_inner1 h0 h1 a b c i k (f (v i)));
      (* Combine "rows before [i] agree with [f]" with "row [i] agrees with [f]"
         to re-establish the outer invariant at [i + 1]. *)
      let h1 = ST.get() in
      let q i1 = forall (k:nat{k < v n3}). get h1 c i1 k == f i1 k in
      onemore (fun i1 -> i1 < v n1) q (v i);
      assert (forall (i1:nat{i1 < v n1 /\ i1 <= v i}) (k:nat{k < v n3}). get h1 c i1 k == f i1 k)
    );
  (* Turn the entry-wise facts into an equality of whole matrices. *)
  let h2 = ST.get() in
  M.extensionality (as_matrix h2 c) (M.mul (as_matrix h0 a) (as_matrix h0 b))

(* Special case of matrix multiplication, *)
(* used when the entries of the matrix S are accessed with a different indexing scheme *)

/// Reads entry [(i, j)] of [a] using the alternative indexing scheme.
///
/// Unlike [mget], the entry is fetched from flat offset [j * n1 + i] of the
/// underlying buffer. Requires [a] to be live, modifies nothing, and returns
/// [M.mget_s (as_matrix h0 a) (v i) (v j)].
inline_for_extraction noextract
val mget_s:
    #n1:size_t
  -> #n2:size_t{v n1 * v n2 <= max_size_t}
  -> a:matrix_t n1 n2
  -> i:size_t{v i < v n1}
  -> j:size_t{v j < v n2}
  -> Stack elem
    (requires fun h0 -> live h0 a)
    (ensures  fun h0 x h1 -> modifies0 h0 h1 /\
      x == M.mget_s (as_matrix h0 a) (v i) (v j))

let mget_s #n1 #n2 a i j =
  (* Same lemma as in [mget], with the roles of the two dimensions and of the
     two indices swapped; presumably it bounds the swapped offset. *)
  M.index_lt (v n2) (v n1) (v j) (v i);
  [@inline_let]
  let offset = j *! n1 +! i in
  a.(offset)

/// Specification-level accessor for the alternative indexing scheme: entry
/// [(i, j)] of the matrix held by [m] in memory [h], obtained by applying
/// [M.mget_s] to [as_matrix h m].
unfold
let get_s #n1 #n2 h (m:matrix_t n1 n2) i j = M.mget_s (as_matrix h m) i j


/// Variant of [mul_inner] in which the entries of [b] are read with [mget_s]:
/// computes the sum over [j] of [mget a i j *. mget_s b j k].
///
/// [a] and [b] must be live. From the caller's point of view nothing is
/// modified, and the result equals
/// [M.mul_inner_s (as_matrix h0 a) (as_matrix h0 b) (v i) (v k)].
/// This definition is checked with fuel 1 (see the surrounding options).
#push-options "--fuel 1"
inline_for_extraction noextract private
val mul_inner_s:
    #n1:size_t
  -> #n2:size_t{v n1 * v n2 <= max_size_t}
  -> #n3:size_t{v n2 * v n3 <= max_size_t /\ v n1 * v n3 <= max_size_t}
  -> a:matrix_t n1 n2
  -> b:matrix_t n2 n3
  -> i:size_t{v i < v n1}
  -> k:size_t{v k < v n3}
  -> Stack uint16
    (requires fun h -> live h a /\ live h b)
    (ensures  fun h0 r h1 -> modifies0 h0 h1 /\
      r == M.mul_inner_s (as_matrix h0 a) (as_matrix h0 b) (v i) (v k))

let mul_inner_s #n1 #n2 #n3 a b i k =
  push_frame();
  let h0 = ST.get() in
  (* Specification-level summand: the product of the two entries as read in [h0]. *)
  [@ inline_let ]
  let f l = get h0 a (v i) l *. get_s h0 b l (v k) in
  (* One-element buffer used as the running sum, allocated in the new frame. *)
  let res = create #uint16 (size 1) (u16 0) in
  let h1 = ST.get() in
  (* Invariant: after [j] iterations the accumulator holds [M.sum_ f j]. *)
  Lib.Loops.for (size 0) n2
    (fun h2 j -> live h1 res /\ live h2 res /\
      modifies1 res h1 h2 /\
      bget h2 res 0 == M.sum_ #(v n2) f j)
    (fun j ->
      let aij = mget a i j in
      let bjk = mget_s b j k in
      let res0 = res.(size 0) in
      res.(size 0) <- res0 +. aij *. bjk
    );
  (* Read the result out before the frame holding the accumulator is popped. *)
  let res = res.(size 0) in
  M.sum_extensionality (v n2) f (fun l -> get h0 a (v i) l *. get_s h0 b l (v k)) (v n2);
  assert (res == M.mul_inner_s (as_matrix h0 a) (as_matrix h0 b) (v i) (v k));
  pop_frame();
  res
#pop-options


/// One step of the inner loop of [matrix_mul_s]: computes entry [(i, k)] of
/// the product with [mul_inner_s] and stores it in [c] with [mset].
///
/// Takes [mul_inner_inv] from column [v k] to column [v k + 1]. [c] must be
/// disjoint from [a] and [b]. [f] is the ghost function giving, for each
/// column, the expected sum of products, with [b] read through [get_s].
inline_for_extraction noextract private
val mul_inner1_s:
    #n1:size_t
  -> #n2:size_t{v n1 * v n2 <= max_size_t}
  -> #n3:size_t{v n2 * v n3 <= max_size_t /\ v n1 * v n3 <= max_size_t}
  -> h0:HS.mem
  -> h1:HS.mem
  -> a:matrix_t n1 n2
  -> b:matrix_t n2 n3
  -> c:matrix_t n1 n3{disjoint a c /\ disjoint b c}
  -> i:size_t{v i < v n1}
  -> k:size_t{v k < v n3}
  -> f:(k:nat{k < v n3}
       -> GTot (res:uint16{res == M.sum #(v n2) (fun l -> get h0 a (v i) l *. get_s h0 b l k)}))
  -> Stack unit
    (requires fun h2 -> mul_inner_inv h0 h1 h2 a b c f i (v k))
    (ensures  fun _ _ h2 -> mul_inner_inv h0 h1 h2 a b c f i (v k + 1))

let mul_inner1_s #n1 #n2 #n3 h0 h1 a b c i k f =
  (* Proof hint: state [M.mul_inner_s] as the explicit sum used in the type of [f]. *)
  assert (M.mul_inner_s (as_matrix h0 a) (as_matrix h0 b) (v i) (v k) ==
          M.sum #(v n2) (fun l -> get h0 a (v i) l *. get_s h0 b l (v k)));
  mset c i k (mul_inner_s a b i k);
  let h2 = ST.get () in
  (* Proof hint: the entry just written is the value expected by [f]. *)
  assert (get h2 c (v i) (v k) == f (v k))


/// Matrix multiplication in which the entries of the second operand are read
/// with the alternative indexing scheme of [mget_s]: [c] receives the product
/// of [a] ([n1] by [n2]) and [b] ([n2] by [n3]).
///
/// All three buffers must be live and [c] must be disjoint from [a] and [b].
/// Only [c] is modified, and afterwards it equals [M.mul_s] of the initial
/// contents of [a] and [b].
val matrix_mul_s:
    #n1:size_t
  -> #n2:size_t{v n1 * v n2 <= max_size_t}
  -> #n3:size_t{v n2 * v n3 <= max_size_t /\ v n1 * v n3 <= max_size_t}
  -> a:matrix_t n1 n2
  -> b:matrix_t n2 n3
  -> c:matrix_t n1 n3
  -> Stack unit
    (requires fun h ->
      live h a /\ live h b /\ live h c /\
      disjoint a c /\ disjoint b c)
    (ensures  fun h0 _ h1 -> modifies1 c h0 h1 /\
      as_matrix h1 c == M.mul_s (as_matrix h0 a) (as_matrix h0 b))

[@"c_inline"]
let matrix_mul_s #n1 #n2 #n3 a b c =
  let h0 = ST.get () in
  (* Ghost function giving the expected value of entry (i, k) of the result,
     as a sum of products over the contents in [h0], with [b] read through
     [get_s]. *)
  let f (i:nat{i < v n1}) (k:nat{k < v n3}) :
    GTot (res:uint16{res == M.sum #(v n2) (fun l -> get h0 a i l *. get_s h0 b l k)})
  = M.sum #(v n2) (fun l -> get h0 a i l *. get_s h0 b l k)
  in
  (* Outer loop over rows of [c]. Invariant: rows before [i] agree with [f],
     rows from [i] on are unchanged since [h0]. *)
  Lib.Loops.for (size 0) n1
    (fun h1 i ->
      live h1 a /\ live h1 b /\ live h1 c /\
      modifies1 c h0 h1 /\ i <= v n1 /\
      (forall (i1:nat{i1 < i}) (k:nat{k < v n3}). get h1 c i1 k == f i1 k) /\
      (forall (i1:nat{i <= i1 /\ i1 < v n1}) (k:nat{k < v n3}). get h1 c i1 k == get h0 c i1 k))
    (fun i ->
      let h1 = ST.get() in
      (* Inner loop over the columns of row [i]. *)
      Lib.Loops.for (size 0) n3
        (fun h2 k -> mul_inner_inv h0 h1 h2 a b c (f (v i)) i k)
        (fun k -> mul_inner1_s h0 h1 a b c i k (f (v i)));
      (* Combine "rows before [i] agree with [f]" with "row [i] agrees with [f]"
         to re-establish the outer invariant at [i + 1]. The column binder is
         annotated explicitly so that [q] has the same shape as the loop
         invariant and the assertion below (and as in [matrix_mul]). *)
      let h1 = ST.get() in
      let q i1 = forall (k:nat{k < v n3}). get h1 c i1 k == f i1 k in
      onemore (fun i1 -> i1 < v n1) q (v i);
      assert (forall (i1:nat{i1 < v n1 /\ i1 <= v i}) (k:nat{k < v n3}). get h1 c i1 k == f i1 k)
    );
  (* Turn the entry-wise facts into an equality of whole matrices. *)
  let h2 = ST.get() in
  M.extensionality (as_matrix h2 c) (M.mul_s (as_matrix h0 a) (as_matrix h0 b))

(* the end of the special matrix multiplication *)

/// Compares two matrices of the same dimensions and returns a [uint16] result
/// specified to equal [M.matrix_eq (as_matrix h0 a) (as_matrix h0 b)].
///
/// [a] and [b] must be live; nothing is modified from the caller's point of
/// view.
val matrix_eq:
    #n1:size_t
  -> #n2:size_t{v n1 * v n2 <= max_size_t}
  -> a:matrix_t n1 n2
  -> b:matrix_t n1 n2
  -> Stack uint16
    (requires fun h -> live h a /\ live h b)
    (ensures  fun h0 r h1 -> modifies0 h0 h1 /\
      r == M.matrix_eq #(v n1) #(v n2) (as_matrix h0 a) (as_matrix h0 b))

[@"c_inline"]
let matrix_eq #n1 #n2 a b =
  push_frame();
  (* One-element scratch buffer, initialised to [ones U16 SEC], handed to
     [buf_eq_mask] (presumably as its accumulator - confirm against Lib.ByteBuffer). *)
  let res = create 1ul (ones U16 SEC) in
  let r = buf_eq_mask a b (n1 *! n2) res in
  pop_frame ();
  r


/// Serialises the matrix [m] into the byte buffer [res], two bytes per entry.
///
/// [m] and [res] must be live and disjoint; [res] has length [2 * n1 * n2].
/// Only [res] is modified, and afterwards its contents equal
/// [M.matrix_to_lbytes (as_matrix h0 m)].
val matrix_to_lbytes:
    #n1:size_t
  -> #n2:size_t{2 * v n1 <= max_size_t /\ 2 * v n1 * v n2 <= max_size_t}
  -> m:matrix_t n1 n2
  -> res:lbytes (2ul *! n1 *! n2)
  -> Stack unit
    (requires fun h -> live h m /\ live h res /\ disjoint m res)
    (ensures  fun h0 r h1 -> modifies1 res h0 h1 /\
      as_seq h1 res == M.matrix_to_lbytes #(v n1) #(v n2) (as_matrix h0 m))

[@"c_inline"]
let matrix_to_lbytes #n1 #n2 m res =
  let h0 = ST.get () in
  (* Fill [res] in [n1 * n2] blocks of 2 bytes: block [i] is the 2-byte
     sub-buffer at offset [2 * i], written by [uint_to_bytes_le] from the
     [i]-th entry of [m]. The second argument is the specification of each
     block. *)
  fill_blocks_simple h0 2ul (n1 *! n2) res
  (fun h -> M.matrix_to_lbytes_f #(v n1) #(v n2) (as_matrix h0 m))
  (fun i -> uint_to_bytes_le (sub res (2ul *! i) 2ul) m.(i))


/// Deserialises the byte buffer [b] into the matrix [res], two bytes per entry.
///
/// [b] and [res] must be live and disjoint; [b] has length [2 * n1 * n2].
/// Only [res] is modified, and afterwards it equals
/// [M.matrix_from_lbytes (v n1) (v n2) (as_seq h0 b)].
val matrix_from_lbytes:
    #n1:size_t
  -> #n2:size_t{2 * v n1 <= max_size_t /\ 2 * v n1 * v n2 <= max_size_t}
  -> b:lbytes (size 2 *! n1 *! n2)
  -> res:matrix_t n1 n2
  -> Stack unit
    (requires fun h -> live h b /\ live h res /\ disjoint b res)
    (ensures  fun h0 _ h1 -> modifies1 res h0 h1 /\
      as_matrix h1 res == M.matrix_from_lbytes (v n1) (v n2) (as_seq h0 b))

[@"c_inline"]
let matrix_from_lbytes #n1 #n2 b res =
  let h0 = ST.get () in
  (* Proof hint: the entry count fits in a [size_t]. *)
  assert (v n1 * v n2 <= max_size_t);
  (* Fill the [n1 * n2] entries of [res]: entry [i] is decoded by
     [uint_from_bytes_le] from the 2-byte sub-buffer of [b] at offset [2 * i]. *)
  fill h0 (n1 *! n2) res
  (fun h -> M.matrix_from_lbytes_f (v n1) (v n2) (as_seq h0 b))
  (fun i ->
    (* Proof hint: the 2-byte window for entry [i] lies within [b]. *)
    assert (2 * v i + 2 <= 2 * v n1 * v n2);
    uint_from_bytes_le #U16 (sub b (2ul *! i) 2ul));
  (* Turn the entry-wise facts into an equality of whole matrices. *)
  let h1 = ST.get () in
  M.extensionality (as_matrix h1 res) (M.matrix_from_lbytes (v n1) (v n2) (as_seq h0 b))
(* back to top *)