https://github.com/project-everest/hacl-star
Tip revision: 979490eae06f54ec3a4768cdefdef741b6aa97b4 authored by Ann Weine on 23 September 2021, 10:43:48 UTC
Radix/Ml separation
Radix/Ml separation
Tip revision: 979490e
Hacl.Impl.Matrix.fst
module Hacl.Impl.Matrix
open FStar.HyperStack.ST
open FStar.Mul
open LowStar.BufferOps
open LowStar.Buffer
open Lib.IntTypes
open Lib.Buffer
open Lib.ByteBuffer
module HS = FStar.HyperStack
module ST = FStar.HyperStack.ST
module LSeq = Lib.Sequence
module BSeq = Lib.ByteSequence
module Loops = Lib.LoopCombinators
module M = Spec.Matrix
module Lemmas = Spec.Frodo.Lemmas
#set-options "--z3rlimit 50 --fuel 0 --ifuel 0"
(* Element type used for every matrix entry in this module: an alias for [uint16]. *)
unfold
let elem = uint16
(* [lbytes len] abbreviates [lbuffer uint8 len], a buffer of [len] bytes. *)
unfold
let lbytes len = lbuffer uint8 len
(* [matrix_t n1 n2] is the type of a matrix stored as a single flat buffer of
   [n1 *! n2] elements of type [elem]. The refinement on [n2] keeps the
   element count within [max_size_t]. *)
unfold
let matrix_t (n1:size_t) (n2:size_t{v n1 * v n2 <= max_size_t}) =
lbuffer elem (n1 *! n2)
(* Ghost view of the contents of [m] in memory [h] as a specification-level
   matrix of type [M.matrix (v n1) (v n2)]; it is exactly [as_seq h m]. *)
unfold
let as_matrix #n1 #n2 h (m:matrix_t n1 n2) : GTot (M.matrix (v n1) (v n2)) =
as_seq h m
(* [matrix_create n1 n2] allocates a matrix of [v n1 * v n2] elements, each
   initialised to [u16 0].
   - The refinement on [n2] demands a strictly positive element count that is
     at most [max_size_t].
   - No precondition on the initial memory.
   - Effect is [StackInline]; the postcondition states that the result is
     [stack_allocated] and that its contents equal [M.create (v n1) (v n2)]. *)
inline_for_extraction noextract
val matrix_create:
n1:size_t
-> n2:size_t{0 < v n1 * v n2 /\ v n1 * v n2 <= max_size_t}
-> StackInline (matrix_t n1 n2)
(requires fun h0 -> True)
(ensures fun h0 a h1 ->
stack_allocated a h0 h1 (as_matrix h1 a) /\
as_matrix h1 a == M.create (v n1) (v n2))
let matrix_create n1 n2 =
(* NOTE(review): [normalize_term] presumably forces the product to be reduced
   during type-checking so that the length is a literal after inlining - confirm. *)
[@inline_let]
let len = size (normalize_term (v n1 * v n2)) in
create len (u16 0)
(* [mget a i j] reads entry (i, j) of the matrix [a].
   - Requires [a] to be live.
   - Modifies nothing ([modifies loc_none]).
   - The result equals [M.mget (as_matrix h0 a) (v i) (v j)].
   The entry is read from flat offset [i *! n2 +! j]. *)
inline_for_extraction noextract
val mget:
#n1:size_t
-> #n2:size_t{v n1 * v n2 <= max_size_t}
-> a:matrix_t n1 n2
-> i:size_t{v i < v n1}
-> j:size_t{v j < v n2}
-> Stack elem
(requires fun h0 -> live h0 a)
(ensures fun h0 x h1 -> modifies loc_none h0 h1 /\
x == M.mget (as_matrix h0 a) (v i) (v j))
let mget #n1 #n2 a i j =
(* Lemma from Spec.Matrix; presumably establishes that the flat offset is in
   bounds so that the machine arithmetic below cannot overflow - confirm. *)
M.index_lt (v n1) (v n2) (v i) (v j);
a.(i *! n2 +! j)
(* [mset a i j x] writes [x] into entry (i, j) of the matrix [a].
   - Requires [a] to be live.
   - Modifies only [a] ([modifies1 a]).
   - Afterwards the contents of [a] equal
     [M.mset (as_matrix h0 a) (v i) (v j) x].
   The entry is written at flat offset [i *! n2 +! j]. *)
inline_for_extraction noextract
val mset:
#n1:size_t
-> #n2:size_t{v n1 * v n2 <= max_size_t}
-> a:matrix_t n1 n2
-> i:size_t{v i < v n1}
-> j:size_t{v j < v n2}
-> x:elem
-> Stack unit
(requires fun h0 -> live h0 a)
(ensures fun h0 _ h1 -> modifies1 a h0 h1 /\
as_matrix h1 a == M.mset (as_matrix h0 a) (v i) (v j) x)
let mset #n1 #n2 a i j x =
(* Lemma from Spec.Matrix; presumably establishes that the flat offset is in
   bounds so that the machine arithmetic below cannot overflow - confirm. *)
M.index_lt (v n1) (v n2) (v i) (v j);
a.(i *! n2 +! j) <- x
(* Read operator for matrices: [m.[i,j]] reads entry (i, j) of [m].
   The index pair [ij] is refined so that both components are in bounds.
   Requires [m] to be live, modifies nothing, and returns
   [M.mget (as_matrix h0 m) (v i) (v j)]. Delegates to [mget]. *)
noextract unfold
val op_String_Access (#n1:size_t) (#n2:size_t{v n1 * v n2 <= max_size_t}) (m:matrix_t n1 n2) (ij:(size_t & size_t){let i, j = ij in v i < v n1 /\ v j < v n2})
  : Stack elem
    (requires fun h0 -> live h0 m)
    (ensures fun h0 x h1 -> let i, j = ij in modifies loc_none h0 h1 /\ x == M.mget (as_matrix h0 m) (v i) (v j))
let op_String_Access #n1 #n2 m ij =
  (* Split the index pair explicitly, then delegate to [mget]. *)
  let i, j = ij in
  mget m i j
(* Write operator for matrices: [m.[i,j] <- x] stores [x] in entry (i, j) of [m].
   The index pair [ij] is refined so that both components are in bounds.
   Requires [m] to be live; modifies only [m], keeps it live, and leaves its
   contents equal to [M.mset (as_matrix h0 m) (v i) (v j) x].
   Delegates to [mset]. *)
noextract unfold
val op_String_Assignment (#n1:size_t) (#n2:size_t{v n1 * v n2 <= max_size_t}) (m:matrix_t n1 n2) (ij:(size_t & size_t){let i, j = ij in v i < v n1 /\ v j < v n2}) (x:elem)
  : Stack unit
    (requires fun h0 -> live h0 m)
    (ensures fun h0 _ h1 -> let i, j = ij in modifies1 m h0 h1 /\ live h1 m /\ as_matrix h1 m == M.mset (as_matrix h0 m) (v i) (v j) x)
let op_String_Assignment #n1 #n2 m ij x =
  (* Split the index pair explicitly, then delegate to [mset]. *)
  let i, j = ij in
  mset m i j x
(* Specification-level read: [get h m i j] is entry (i, j) of the contents of
   [m] in memory [h], i.e. [M.mget (as_matrix h m) i j]. Used in the loop
   invariants of this module. *)
unfold
let get #n1 #n2 h (m:matrix_t n1 n2) i j = M.mget (as_matrix h m) i j
(* Invariant of the inner (column) loop of [map] while row [i] is processed.
   Memories, as they are supplied by [map]:
   - [h0]: memory on entry to [map];
   - [h1]: memory captured just before the inner loop for row [i] starts;
   - [h2]: current memory.
   With [j] the number of columns of row [i] already processed, it states:
   - [a] and [c] are live in [h2] and only [c] changed between [h1] and [h2];
   - [j <= v n2];
   - rows below [i] of [c] are as they were in [h1];
   - columns below [j] of row [i] of [c] hold [f] applied to the [h0] value of [a];
   - the remaining columns of row [i] of [c] still hold their [h0] value;
   - rows above [i] of [c] still hold their [h0] value.
   The destination binder in the signature is named [c], matching the
   definition and the callers. *)
private unfold
val map_inner_inv:
    #n1:size_t
  -> #n2:size_t{v n1 * v n2 <= max_size_t}
  -> h0:HS.mem
  -> h1:HS.mem
  -> h2:HS.mem
  -> f:(elem -> elem)
  -> a:matrix_t n1 n2
  -> c:matrix_t n1 n2
  -> i:size_t{v i < v n1}
  -> j:size_nat
  -> Type0
let map_inner_inv #n1 #n2 h0 h1 h2 f a c i j =
  live h2 a /\ live h2 c /\
  modifies1 c h1 h2 /\
  j <= v n2 /\
  (forall (i0:nat{i0 < v i}) (j:nat{j < v n2}). get h2 c i0 j == get h1 c i0 j) /\
  (forall (j0:nat{j0 < j}). get h2 c (v i) j0 == f (get h0 a (v i) j0)) /\
  (forall (j0:nat{j <= j0 /\ j0 < v n2}). get h2 c (v i) j0 == get h0 c (v i) j0) /\
  (forall (i0:nat{v i < i0 /\ i0 < v n1}) (j:nat{j < v n2}). get h2 c i0 j == get h0 c i0 j)
(* One step of the inner loop of [map]: stores [f a.[i,j]] into [c.[i,j]].
   [c] is refined to be equal to [a], so the update is in place.
   [h0] and [h1] are only used by the invariant.
   Takes [map_inner_inv] at column [v j] to [map_inner_inv] at column [v j + 1]. *)
inline_for_extraction noextract private
val map_inner:
#n1:size_t
-> #n2:size_t{v n1 * v n2 <= max_size_t}
-> h0:HS.mem
-> h1:HS.mem
-> f:(elem -> elem)
-> a:matrix_t n1 n2
-> c:matrix_t n1 n2{a == c}
-> i:size_t{v i < v n1}
-> j:size_t{v j < v n2}
-> Stack unit
(requires fun h2 -> map_inner_inv h0 h1 h2 f a c i (v j))
(ensures fun _ _ h2 -> map_inner_inv h0 h1 h2 f a c i (v j + 1))
let map_inner #n1 #n2 h0 h1 f a c i j =
c.[i,j] <- f a.[i,j]
(* In-place element-wise map: applies [f] to every entry of the matrix.
   - Requires [a] and [c] to be live and [a == c].
   - Modifies only [c].
   - Afterwards the contents of [c] equal [M.map f] of the initial contents of [a].
   Implemented as a loop over rows [i] containing a loop over columns [j]. *)
inline_for_extraction
val map:
#n1:size_t
-> #n2:size_t{v n1 * v n2 <= max_size_t}
-> f:(uint16 -> uint16)
-> a:matrix_t n1 n2
-> c:matrix_t n1 n2
-> Stack unit
(requires fun h0 ->
live h0 a /\ live h0 c /\ a == c)
(ensures fun h0 _ h1 -> modifies1 c h0 h1 /\
as_matrix h1 c == M.map #(v n1) #(v n2) f (as_matrix h0 a))
let map #n1 #n2 f a c =
(* Memory on entry; every invariant below refers back to it. *)
let h0 = ST.get () in
Lib.Loops.for (size 0) n1
(* Outer invariant: rows below i already hold f of the initial values,
   rows from i onwards are untouched. *)
(fun h1 i -> live h1 a /\ live h1 c /\
modifies1 c h0 h1 /\ i <= v n1 /\
(forall (i0:nat{i0 < i}) (j:nat{j < v n2}).
get h1 c i0 j == f (get h0 a i0 j)) /\
(forall (i0:nat{i <= i0 /\ i0 < v n1}) (j:nat{j < v n2}).
get h1 c i0 j == get h0 c i0 j) )
(fun i ->
(* Memory at the start of row i, passed to the inner invariant. *)
let h1 = ST.get() in
Lib.Loops.for (size 0) n2
(fun h2 j -> map_inner_inv h0 h1 h2 f a c i j)
(fun j -> map_inner h0 h1 f a c i j)
);
let h2 = ST.get () in
(* NOTE(review): presumably lifts the entry-wise equalities of the final
   invariant to equality of the two matrices - confirm in Spec.Matrix. *)
M.extensionality (as_matrix h2 c) (M.map f (as_matrix h0 a))
(* [mod_pow2_felem logq a] is [a] bitwise-and-ed with the mask
   [(u16 1 <<. logq) -. u16 1].
   [logq] is refined to lie strictly between 0 and 16.
   The result is specified to equal [M.mod_pow2_felem (v logq) a]. *)
inline_for_extraction noextract
val mod_pow2_felem:
logq:size_t{0 < v logq /\ v logq < 16}
-> a:uint16
-> Pure uint16
(requires True)
(ensures fun r -> r == M.mod_pow2_felem (v logq) a)
let mod_pow2_felem logq a =
a &. ((u16 1 <<. logq) -. u16 1)
(* [mod_pow2 logq a] updates [a] in place so that its contents equal
   [M.mod_pow2 (v logq)] of its initial contents.
   - [logq] ranges over 1..16 inclusive.
   - Requires [a] to be live; modifies only [a]. *)
val mod_pow2:
#n1:size_t
-> #n2:size_t{v n1 * v n2 <= max_size_t}
-> logq:size_t{0 < v logq /\ v logq <= 16}
-> a:matrix_t n1 n2
-> Stack unit
(requires fun h -> live h a)
(ensures fun h0 _ h1 -> modifies1 a h0 h1 /\
as_matrix h1 a == M.mod_pow2 (v logq) (as_matrix h0 a))
[@"c_inline"]
let mod_pow2 #n1 #n2 logq a =
(* Only the case logq < 16 does any work; for logq = 16 the buffer is left as is.
   NOTE(review): presumably M.mod_pow2 is the identity for 16 - confirm in Spec.Matrix. *)
if logq <. 16ul then begin
let h0 = ST.get () in
map (mod_pow2_felem logq) a a;
(* Relates the map over the implementation-level function to the map over the
   specification-level function M.mod_pow2_felem. *)
M.extensionality
(M.map #(v n1) #(v n2) (mod_pow2_felem logq) (as_matrix h0 a))
(M.map #(v n1) #(v n2) (M.mod_pow2_felem (v logq)) (as_matrix h0 a)) end
(* Invariant of the inner (column) loop of [map2] while row [i] is processed.
   Memories, as they are supplied by [map2]:
   - [h0]: memory on entry to [map2];
   - [h1]: memory captured just before the inner loop for row [i] starts;
   - [h2]: current memory.
   With [j] the number of columns of row [i] already processed, it states:
   - [a], [b] and [c] are live in [h2] and only [c] changed between [h1] and [h2];
   - [j <= v n2];
   - rows below [i] of [c] are as they were in [h1];
   - columns below [j] of row [i] of [c] hold [f] applied to the [h0] value of
     [a] and the [h2] value of [b];
   - the remaining columns of row [i] of [c] still hold their [h0] value;
   - rows above [i] of [c] still hold their [h0] value. *)
private unfold
val map2_inner_inv:
#n1:size_t
-> #n2:size_t{v n1 * v n2 <= max_size_t}
-> h0:HS.mem
-> h1:HS.mem
-> h2:HS.mem
-> f:(elem -> elem -> elem)
-> a:matrix_t n1 n2
-> b:matrix_t n1 n2
-> c:matrix_t n1 n2
-> i:size_t{v i < v n1}
-> j:size_nat
-> Type0
let map2_inner_inv #n1 #n2 h0 h1 h2 f a b c i j =
live h2 a /\ live h2 b /\ live h2 c /\
modifies1 c h1 h2 /\
j <= v n2 /\
(forall (i0:nat{i0 < v i}) (j:nat{j < v n2}). get h2 c i0 j == get h1 c i0 j) /\
(forall (j0:nat{j0 < j}). get h2 c (v i) j0 == f (get h0 a (v i) j0) (get h2 b (v i) j0)) /\
(forall (j0:nat{j <= j0 /\ j0 < v n2}). get h2 c (v i) j0 == get h0 c (v i) j0) /\
(forall (i0:nat{v i < i0 /\ i0 < v n1}) (j:nat{j < v n2}). get h2 c i0 j == get h0 c i0 j)
(* One step of the inner loop of [map2]: stores [f a.[i,j] b.[i,j]] into [c.[i,j]].
   [c] is refined to be equal to [a] and disjoint from [b], so the update is
   in place with respect to [a].
   [h0] and [h1] are only used by the invariant.
   Takes [map2_inner_inv] at column [v j] to [map2_inner_inv] at column [v j + 1]. *)
inline_for_extraction noextract private
val map2_inner:
#n1:size_t
-> #n2:size_t{v n1 * v n2 <= max_size_t}
-> h0:HS.mem
-> h1:HS.mem
-> f:(elem -> elem -> elem)
-> a:matrix_t n1 n2
-> b:matrix_t n1 n2
-> c:matrix_t n1 n2{a == c /\ disjoint b c}
-> i:size_t{v i < v n1}
-> j:size_t{v j < v n2}
-> Stack unit
(requires fun h2 -> map2_inner_inv h0 h1 h2 f a b c i (v j))
(ensures fun _ _ h2 -> map2_inner_inv h0 h1 h2 f a b c i (v j + 1))
let map2_inner #n1 #n2 h0 h1 f a b c i j =
c.[i,j] <- f a.[i,j] b.[i,j]
/// In-place [map2], a == map2 f a b
///
/// A non in-place variant can be obtained by weakening the pre-condition to disjoint a c,
/// or the two variants can be merged by requiring (a == c \/ disjoint a c) instead of a == c
/// See commit 91916b8372fa3522061eff5a42d0ebd1d19a8a49
(* Contract:
   - Requires [a], [b] and [c] to be live, [a == c], and [b] disjoint from [c].
   - Modifies only [c].
   - Afterwards the contents of [c] equal [M.map2 f] of the initial contents
     of [a] and [b].
   Implemented as a loop over rows [i] containing a loop over columns [j]. *)
inline_for_extraction
val map2:
#n1:size_t
-> #n2:size_t{v n1 * v n2 <= max_size_t}
-> f:(uint16 -> uint16 -> uint16)
-> a:matrix_t n1 n2
-> b:matrix_t n1 n2
-> c:matrix_t n1 n2
-> Stack unit
(requires fun h0 ->
live h0 a /\ live h0 b /\ live h0 c /\
a == c /\ disjoint b c)
(ensures fun h0 _ h1 -> modifies1 c h0 h1 /\
as_matrix h1 c == M.map2 #(v n1) #(v n2) f (as_matrix h0 a) (as_matrix h0 b))
let map2 #n1 #n2 f a b c =
(* Memory on entry; every invariant below refers back to it. *)
let h0 = ST.get () in
Lib.Loops.for (size 0) n1
(* Outer invariant: rows below i already hold f of the initial values of a and b,
   rows from i onwards are untouched. *)
(fun h1 i -> live h1 a /\ live h1 b /\ live h1 c /\
modifies1 c h0 h1 /\ i <= v n1 /\
(forall (i0:nat{i0 < i}) (j:nat{j < v n2}).
get h1 c i0 j == f (get h0 a i0 j) (get h0 b i0 j)) /\
(forall (i0:nat{i <= i0 /\ i0 < v n1}) (j:nat{j < v n2}).
get h1 c i0 j == get h0 c i0 j) )
(fun i ->
(* Memory at the start of row i, passed to the inner invariant. *)
let h1 = ST.get() in
Lib.Loops.for (size 0) n2
(fun h2 j -> map2_inner_inv h0 h1 h2 f a b c i j)
(fun j -> map2_inner h0 h1 f a b c i j)
);
let h2 = ST.get() in
(* NOTE(review): presumably lifts the entry-wise equalities of the final
   invariant to equality of the two matrices - confirm in Spec.Matrix. *)
M.extensionality (as_matrix h2 c) (M.map2 f (as_matrix h0 a) (as_matrix h0 b))
(* [matrix_add a b] adds [b] to [a] in place.
   - Requires [a] and [b] to be live and disjoint.
   - Modifies only [a].
   - Afterwards the contents of [a] equal [M.add] of the initial contents of
     [a] and [b].
   Implemented as the in-place [map2] with [add_mod], with [a] as both first
   operand and destination. *)
val matrix_add:
#n1:size_t
-> #n2:size_t{v n1 * v n2 <= max_size_t}
-> a:matrix_t n1 n2
-> b:matrix_t n1 n2
-> Stack unit
(requires fun h -> live h a /\ live h b /\ disjoint a b)
(ensures fun h0 r h1 -> modifies1 a h0 h1 /\
as_matrix h1 a == M.add (as_matrix h0 a) (as_matrix h0 b))
[@"c_inline"]
let matrix_add #n1 #n2 a b =
map2 add_mod a b a
(* [matrix_sub a b] computes a - b and stores the result in [b], not in [a].
   - Requires [a] and [b] to be live and disjoint.
   - Modifies only [b].
   - Afterwards the contents of [b] equal [M.sub] of the initial contents of
     [a] and [b]. *)
val matrix_sub:
#n1:size_t
-> #n2:size_t{v n1 * v n2 <= max_size_t}
-> a:matrix_t n1 n2
-> b:matrix_t n1 n2
-> Stack unit
(requires fun h -> live h a /\ live h b /\ disjoint a b)
(ensures fun h0 r h1 -> modifies1 b h0 h1 /\
as_matrix h1 b == M.sub (as_matrix h0 a) (as_matrix h0 b))
[@"c_inline"]
let matrix_sub #n1 #n2 a b =
(* Use the in-place variant above by flipping the arguments of [sub_mod] *)
(* Requires applying extensionality *)
let h0 = ST.get() in
(* The in-place map2 passes the destination entry first; since the destination
   is [b], the operands are swapped here so that each entry becomes
   [sub_mod] of the entry of [a] and the entry of [b]. *)
[@ inline_let ]
let sub_mod_flipped x y = sub_mod y x in
map2 sub_mod_flipped b a b;
let h1 = ST.get() in
M.extensionality (as_matrix h1 b) (M.sub (as_matrix h0 a) (as_matrix h0 b))
(* [mul_inner a b i k] computes one entry of a matrix product: the sum over
   [j] in [0, n2) of [a.[i,j] *. b.[j,k]].
   - Requires [a] and [b] to be live.
   - Modifies nothing visible to the caller ([modifies loc_none]).
   - The result equals [M.mul_inner (as_matrix h0 a) (as_matrix h0 b) (v i) (v k)].
   The definition is checked with fuel 1 (see the surrounding push/pop options). *)
#push-options "--fuel 1"
inline_for_extraction noextract private
val mul_inner:
#n1:size_t
-> #n2:size_t{v n1 * v n2 <= max_size_t}
-> #n3:size_t{v n2 * v n3 <= max_size_t /\ v n1 * v n3 <= max_size_t}
-> a:matrix_t n1 n2
-> b:matrix_t n2 n3
-> i:size_t{v i < v n1}
-> k:size_t{v k < v n3}
-> Stack uint16
(requires fun h -> live h a /\ live h b)
(ensures fun h0 r h1 -> modifies loc_none h0 h1 /\
r == M.mul_inner (as_matrix h0 a) (as_matrix h0 b) (v i) (v k))
let mul_inner #n1 #n2 #n3 a b i k =
push_frame();
let h0 = ST.get() in
(* Specification-level summand: the product of entry (i, l) of a and
   entry (l, k) of b, both read in h0. *)
[@ inline_let ]
let f l = get h0 a (v i) l *. get h0 b l (v k) in
(* One-element accumulator, initialised to zero, created inside the pushed frame. *)
let res = create #uint16 (size 1) (u16 0) in
let h1 = ST.get() in
Lib.Loops.for (size 0) n2
(* Invariant: the accumulator holds the partial sum M.sum_ f j. *)
(fun h2 j -> live h1 res /\ live h2 res /\
modifies1 res h1 h2 /\
bget h2 res 0 == M.sum_ #(v n2) f j)
(fun j ->
let aij = a.[i,j] in
let bjk = b.[j,k] in
let res0 = res.(size 0) in
res.(size 0) <- res0 +. aij *. bjk
);
(* Read the total out of the accumulator before the frame is popped. *)
let res = res.(size 0) in
(* NOTE(review): presumably identifies the sum over [f] with the sum over the
   separately written lambda, as needed for the assertion below - confirm. *)
M.sum_extensionality (v n2) f (fun l -> get h0 a (v i) l *. get h0 b l (v k)) (v n2);
assert (res == M.mul_inner (as_matrix h0 a) (as_matrix h0 b) (v i) (v k));
pop_frame();
res
#pop-options
(* Invariant of the inner loop (over columns [k] of the result) of the matrix
   multiplications in this module, while row [i] of [c] is being filled.
   Memories, as they are supplied by the callers:
   - [h0]: memory on entry to the multiplication;
   - [h1]: memory captured just before the inner loop for row [i] starts;
   - [h2]: current memory.
   [f] gives the expected value of each entry of row [i].
   With [k] the number of columns of row [i] already computed, it states:
   - [a], [b] and [c] are live in [h2] and only [c] changed between [h1] and [h2];
   - [k <= v n3];
   - rows below [i] of [c] are as they were in [h1];
   - columns below [k] of row [i] of [c] hold the value given by [f];
   - the remaining columns of row [i] of [c] still hold their [h0] value;
   - rows above [i] of [c] still hold their [h0] value;
   - the contents of [a] and of [b] are the same in [h0] and [h2]. *)
private unfold
val mul_inner_inv:
#n1:size_t
-> #n2:size_t{v n1 * v n2 <= max_size_t}
-> #n3:size_t{v n2 * v n3 <= max_size_t /\ v n1 * v n3 <= max_size_t}
-> h0:HS.mem
-> h1:HS.mem
-> h2:HS.mem
-> a:matrix_t n1 n2
-> b:matrix_t n2 n3
-> c:matrix_t n1 n3
-> f:(k:nat{k < v n3} -> GTot uint16)
-> i:size_t{v i < v n1}
-> k:size_nat
-> Type0
let mul_inner_inv #n1 #n2 #n3 h0 h1 h2 a b c f i k =
live h2 a /\ live h2 b /\ live h2 c /\
modifies1 c h1 h2 /\
k <= v n3 /\
(forall (i1:nat{i1 < v i}) (k:nat{k < v n3}). get h2 c i1 k == get h1 c i1 k) /\
(forall (k1:nat{k1 < k}). get h2 c (v i) k1 == f k1) /\
(forall (k1:nat{k <= k1 /\ k1 < v n3}). get h2 c (v i) k1 == get h0 c (v i) k1) /\
(forall (i1:nat{v i < i1 /\ i1 < v n1}) (k:nat{k < v n3}). get h2 c i1 k == get h0 c i1 k) /\
as_matrix h0 a == as_matrix h2 a /\
as_matrix h0 b == as_matrix h2 b
(* One step of the inner loop of [matrix_mul]: computes [mul_inner a b i k]
   and stores it in [c.[i,k]].
   - [c] is refined to be disjoint from both [a] and [b].
   - [f] is a ghost function whose result type pins it to the sum of products
     for row [i]; inside the type of [f], the name [k] is the argument of [f],
     not the loop index.
   - [h0] and [h1] are only used by the invariant.
   Takes [mul_inner_inv] at column [v k] to [mul_inner_inv] at column [v k + 1]. *)
inline_for_extraction noextract private
val mul_inner1:
#n1:size_t
-> #n2:size_t{v n1 * v n2 <= max_size_t}
-> #n3:size_t{v n2 * v n3 <= max_size_t /\ v n1 * v n3 <= max_size_t}
-> h0:HS.mem
-> h1:HS.mem
-> a:matrix_t n1 n2
-> b:matrix_t n2 n3
-> c:matrix_t n1 n3{disjoint a c /\ disjoint b c}
-> i:size_t{v i < v n1}
-> k:size_t{v k < v n3}
-> f:(k:nat{k < v n3}
-> GTot (res:uint16{res == M.sum #(v n2) (fun l -> get h0 a (v i) l *. get h0 b l k)}))
-> Stack unit
(requires fun h2 -> mul_inner_inv h0 h1 h2 a b c f i (v k))
(ensures fun _ _ h2 -> mul_inner_inv h0 h1 h2 a b c f i (v k + 1))
let mul_inner1 #n1 #n2 #n3 h0 h1 a b c i k f =
(* Proof hint: M.mul_inner on the h0 contents is the sum of products. *)
assert (M.mul_inner (as_matrix h0 a) (as_matrix h0 b) (v i) (v k) ==
M.sum #(v n2) (fun l -> get h0 a (v i) l *. get h0 b l (v k)));
c.[i,k] <- mul_inner a b i k;
let h2 = ST.get () in
(* Proof hint: the entry just written agrees with the expected value f (v k). *)
assert (get h2 c (v i) (v k) == f (v k))
(* Induction-step helper lemma: if [q] holds for every [i] satisfying [p]
   with [i < b], and [q b] holds, then [q] holds for every [i] satisfying [p]
   with [i <= b]. The proof term is the unit value. *)
private
val onemore: p:(nat -> Type0) -> q:(i:nat{p i} -> Type0) -> b:nat{p b} -> Lemma
(requires (forall (i:nat{p i /\ i < b}). q i) /\ q b)
(ensures forall (i:nat{p i /\ i <= b}). q i)
let onemore p q b = ()
(* Row-wise induction-step lemma on a specification-level matrix [c]: if every
   entry of the rows below [i] agrees with [f], and every entry of row [i]
   agrees with [f], then every entry of the rows up to and including [i]
   agrees with [f]. The proof term is the unit value. *)
val onemore1:
#n1:size_nat
-> #n3:size_nat{n1 * n3 <= max_size_t}
-> c:M.matrix n1 n3
-> f:(i:nat{i < n1} -> k:nat{k < n3} -> GTot uint16)
-> i:size_nat{i < n1} -> Lemma
(requires ((forall (i1:nat{i1 < i}) (k:nat{k < n3}). M.mget c i1 k == f i1 k) /\ (forall (k:nat{k < n3}). M.mget c i k == f i k)))
(ensures (forall (i1:nat{i1 <= i}) (k:nat{k < n3}). M.mget c i1 k == f i1 k))
let onemore1 #n1 #n3 c f i = ()
(* [matrix_mul a b c] stores the product of [a] (n1 by n2) and [b] (n2 by n3)
   in [c] (n1 by n3).
   - Requires [a], [b] and [c] to be live, and [c] disjoint from [a] and [b].
   - Modifies only [c].
   - Afterwards the contents of [c] equal [M.mul] of the initial contents of
     [a] and [b].
   Implemented as a loop over rows [i] of [c] containing a loop over columns [k]. *)
val matrix_mul:
#n1:size_t
-> #n2:size_t{v n1 * v n2 <= max_size_t}
-> #n3:size_t{v n2 * v n3 <= max_size_t /\ v n1 * v n3 <= max_size_t}
-> a:matrix_t n1 n2
-> b:matrix_t n2 n3
-> c:matrix_t n1 n3
-> Stack unit
(requires fun h ->
live h a /\ live h b /\ live h c /\
disjoint a c /\ disjoint b c)
(ensures fun h0 _ h1 -> modifies1 c h0 h1 /\
as_matrix h1 c == M.mul (as_matrix h0 a) (as_matrix h0 b))
[@"c_inline"]
let matrix_mul #n1 #n2 #n3 a b c =
let h0 = ST.get () in
(* Ghost function giving the expected value of entry (i, k) of the result,
   expressed over the contents of a and b in h0. *)
let f (i:nat{i < v n1}) (k:nat{k < v n3}) :
GTot (res:uint16{res == M.sum #(v n2) (fun l -> get h0 a i l *. get h0 b l k)})
= M.sum #(v n2) (fun l -> get h0 a i l *. get h0 b l k)
in
Lib.Loops.for (size 0) n1
(* Outer invariant: rows below i of c hold the expected values,
   rows from i onwards are untouched. *)
(fun h1 i ->
live h1 a /\ live h1 b /\ live h1 c /\
modifies1 c h0 h1 /\ i <= v n1 /\
(forall (i1:nat{i1 < i}) (k:nat{k < v n3}). get h1 c i1 k == f i1 k) /\
(forall (i1:nat{i <= i1 /\ i1 < v n1}) (k:nat{k < v n3}). get h1 c i1 k == get h0 c i1 k))
(fun i ->
let h1 = ST.get() in
Lib.Loops.for (size 0) n3
(fun h2 k -> mul_inner_inv h0 h1 h2 a b c (f (v i)) i k)
(fun k -> mul_inner1 h0 h1 a b c i k (f (v i)));
(* Shadows the earlier h1: this is the memory after row i has been filled. *)
let h1 = ST.get() in
(* Extend the fact about rows below i to rows up to and including i. *)
let q i1 = forall (k:nat{k < v n3}). get h1 c i1 k == f i1 k in
onemore (fun i1 -> i1 < v n1) q (v i);
assert (forall (i1:nat{i1 < v n1 /\ i1 <= v i}) (k:nat{k < v n3}). get h1 c i1 k == f i1 k)
);
let h2 = ST.get() in
(* NOTE(review): presumably lifts the entry-wise equalities of the final
   invariant to equality of the two matrices - confirm in Spec.Matrix. *)
M.extensionality (as_matrix h2 c) (M.mul (as_matrix h0 a) (as_matrix h0 b))
(* Special case of matrix multiplication, *)
(* in which the entries of the second matrix S are read with a different index computation (see mget_s) *)
(* [mget_s a i j] reads entry (i, j) of [a] using the alternative layout:
   the element is taken from flat offset [j *! n1 +! i], whereas [mget] uses
   [i *! n2 +! j].
   - Requires [a] to be live.
   - Modifies nothing ([modifies0]).
   - The result equals [M.mget_s (as_matrix h0 a) (v i) (v j)]. *)
inline_for_extraction noextract
val mget_s:
#n1:size_t
-> #n2:size_t{v n1 * v n2 <= max_size_t}
-> a:matrix_t n1 n2
-> i:size_t{v i < v n1}
-> j:size_t{v j < v n2}
-> Stack elem
(requires fun h0 -> live h0 a)
(ensures fun h0 x h1 -> modifies0 h0 h1 /\
x == M.mget_s (as_matrix h0 a) (v i) (v j))
let mget_s #n1 #n2 a i j =
(* Same lemma as in mget, with the roles of the two dimensions and of the two
   indices swapped to match the offset computed below. *)
M.index_lt (v n2) (v n1) (v j) (v i);
a.(j *! n1 +! i)
(* Specification-level counterpart of [mget_s]: [get_s h m i j] is
   [M.mget_s (as_matrix h m) i j]. Used in the invariants of the special-case
   multiplication. *)
unfold
let get_s #n1 #n2 h (m:matrix_t n1 n2) i j = M.mget_s (as_matrix h m) i j
(* [mul_inner_s a b i k] computes one entry of the special-case matrix product:
   the sum over [j] in [0, n2) of [mget a i j *. mget_s b j k]. It differs from
   [mul_inner] only in reading [b] through [mget_s].
   - Requires [a] and [b] to be live.
   - Modifies nothing visible to the caller ([modifies0]).
   - The result equals [M.mul_inner_s (as_matrix h0 a) (as_matrix h0 b) (v i) (v k)].
   The definition is checked with fuel 1 (see the surrounding push/pop options). *)
#push-options "--fuel 1"
inline_for_extraction noextract private
val mul_inner_s:
#n1:size_t
-> #n2:size_t{v n1 * v n2 <= max_size_t}
-> #n3:size_t{v n2 * v n3 <= max_size_t /\ v n1 * v n3 <= max_size_t}
-> a:matrix_t n1 n2
-> b:matrix_t n2 n3
-> i:size_t{v i < v n1}
-> k:size_t{v k < v n3}
-> Stack uint16
(requires fun h -> live h a /\ live h b)
(ensures fun h0 r h1 -> modifies0 h0 h1 /\
r == M.mul_inner_s (as_matrix h0 a) (as_matrix h0 b) (v i) (v k))
let mul_inner_s #n1 #n2 #n3 a b i k =
push_frame();
let h0 = ST.get() in
(* Specification-level summand: the product of entry (i, l) of a and the
   get_s view of entry (l, k) of b, both read in h0. *)
[@ inline_let ]
let f l = get h0 a (v i) l *. get_s h0 b l (v k) in
(* One-element accumulator, initialised to zero, created inside the pushed frame. *)
let res = create #uint16 (size 1) (u16 0) in
let h1 = ST.get() in
Lib.Loops.for (size 0) n2
(* Invariant: the accumulator holds the partial sum M.sum_ f j. *)
(fun h2 j -> live h1 res /\ live h2 res /\
modifies1 res h1 h2 /\
bget h2 res 0 == M.sum_ #(v n2) f j)
(fun j ->
let aij = mget a i j in
let bjk = mget_s b j k in
let res0 = res.(size 0) in
res.(size 0) <- res0 +. aij *. bjk
);
(* Read the total out of the accumulator before the frame is popped. *)
let res = res.(size 0) in
(* NOTE(review): presumably identifies the sum over [f] with the sum over the
   separately written lambda, as needed for the assertion below - confirm. *)
M.sum_extensionality (v n2) f (fun l -> get h0 a (v i) l *. get_s h0 b l (v k)) (v n2);
assert (res == M.mul_inner_s (as_matrix h0 a) (as_matrix h0 b) (v i) (v k));
pop_frame();
res
#pop-options
(* One step of the inner loop of [matrix_mul_s]: computes [mul_inner_s a b i k]
   and stores it in entry (i, k) of [c] with [mset].
   - [c] is refined to be disjoint from both [a] and [b].
   - [f] is a ghost function whose result type pins it to the sum of products
     for row [i], with [b] read through [get_s]; inside the type of [f], the
     name [k] is the argument of [f], not the loop index.
   - [h0] and [h1] are only used by the invariant.
   Takes [mul_inner_inv] at column [v k] to [mul_inner_inv] at column [v k + 1]. *)
inline_for_extraction noextract private
val mul_inner1_s:
#n1:size_t
-> #n2:size_t{v n1 * v n2 <= max_size_t}
-> #n3:size_t{v n2 * v n3 <= max_size_t /\ v n1 * v n3 <= max_size_t}
-> h0:HS.mem
-> h1:HS.mem
-> a:matrix_t n1 n2
-> b:matrix_t n2 n3
-> c:matrix_t n1 n3{disjoint a c /\ disjoint b c}
-> i:size_t{v i < v n1}
-> k:size_t{v k < v n3}
-> f:(k:nat{k < v n3}
-> GTot (res:uint16{res == M.sum #(v n2) (fun l -> get h0 a (v i) l *. get_s h0 b l k)}))
-> Stack unit
(requires fun h2 -> mul_inner_inv h0 h1 h2 a b c f i (v k))
(ensures fun _ _ h2 -> mul_inner_inv h0 h1 h2 a b c f i (v k + 1))
let mul_inner1_s #n1 #n2 #n3 h0 h1 a b c i k f =
(* Proof hint: M.mul_inner_s on the h0 contents is the sum of products. *)
assert (M.mul_inner_s (as_matrix h0 a) (as_matrix h0 b) (v i) (v k) ==
M.sum #(v n2) (fun l -> get h0 a (v i) l *. get_s h0 b l (v k)));
mset c i k (mul_inner_s a b i k);
let h2 = ST.get () in
(* Proof hint: the entry just written agrees with the expected value f (v k). *)
assert (get h2 c (v i) (v k) == f (v k))
(* [matrix_mul_s a b c] is the special-case matrix multiplication: like
   [matrix_mul], but the entries of [b] are read through [mget_s] / [get_s].
   - Requires [a], [b] and [c] to be live, and [c] disjoint from [a] and [b].
   - Modifies only [c].
   - Afterwards the contents of [c] equal [M.mul_s] of the initial contents of
     [a] and [b].
   Implemented as a loop over rows [i] of [c] containing a loop over columns [k]. *)
val matrix_mul_s:
#n1:size_t
-> #n2:size_t{v n1 * v n2 <= max_size_t}
-> #n3:size_t{v n2 * v n3 <= max_size_t /\ v n1 * v n3 <= max_size_t}
-> a:matrix_t n1 n2
-> b:matrix_t n2 n3
-> c:matrix_t n1 n3
-> Stack unit
(requires fun h ->
live h a /\ live h b /\ live h c /\
disjoint a c /\ disjoint b c)
(ensures fun h0 _ h1 -> modifies1 c h0 h1 /\
as_matrix h1 c == M.mul_s (as_matrix h0 a) (as_matrix h0 b))
[@"c_inline"]
let matrix_mul_s #n1 #n2 #n3 a b c =
let h0 = ST.get () in
(* Ghost function giving the expected value of entry (i, k) of the result,
   expressed over the contents of a and b in h0, with b read through get_s. *)
let f (i:nat{i < v n1}) (k:nat{k < v n3}) :
GTot (res:uint16{res == M.sum #(v n2) (fun l -> get h0 a i l *. get_s h0 b l k)})
= M.sum #(v n2) (fun l -> get h0 a i l *. get_s h0 b l k)
in
Lib.Loops.for (size 0) n1
(* Outer invariant: rows below i of c hold the expected values,
   rows from i onwards are untouched. *)
(fun h1 i ->
live h1 a /\ live h1 b /\ live h1 c /\
modifies1 c h0 h1 /\ i <= v n1 /\
(forall (i1:nat{i1 < i}) (k:nat{k < v n3}). get h1 c i1 k == f i1 k) /\
(forall (i1:nat{i <= i1 /\ i1 < v n1}) (k:nat{k < v n3}). get h1 c i1 k == get h0 c i1 k))
(fun i ->
let h1 = ST.get() in
Lib.Loops.for (size 0) n3
(fun h2 k -> mul_inner_inv h0 h1 h2 a b c (f (v i)) i k)
(fun k -> mul_inner1_s h0 h1 a b c i k (f (v i)));
(* Shadows the earlier h1: this is the memory after row i has been filled. *)
let h1 = ST.get() in
(* Extend the fact about rows below i to rows up to and including i.
   NOTE(review): the binder k is left unannotated here, whereas the matching
   line in matrix_mul writes its refinement explicitly; the type is inferred. *)
let q i1 = forall k. get h1 c i1 k == f i1 k in
onemore (fun i1 -> i1 < v n1) q (v i);
assert (forall (i1:nat{i1 < v n1 /\ i1 <= v i}) (k:nat{k < v n3}). get h1 c i1 k == f i1 k)
);
let h2 = ST.get() in
(* NOTE(review): presumably lifts the entry-wise equalities of the final
   invariant to equality of the two matrices - confirm in Spec.Matrix. *)
M.extensionality (as_matrix h2 c) (M.mul_s (as_matrix h0 a) (as_matrix h0 b))
(* End of the special-case matrix multiplication *)
(* [matrix_eq a b] compares two matrices of the same dimensions and returns a
   [uint16] equal to [M.matrix_eq (as_matrix h0 a) (as_matrix h0 b)].
   - Requires [a] and [b] to be live.
   - Modifies nothing visible to the caller ([modifies0]).
   NOTE(review): the result is presumably a mask (all ones when equal, zero
   otherwise), given the call to [buf_eq_mask] - confirm against Spec.Matrix. *)
val matrix_eq:
#n1:size_t
-> #n2:size_t{v n1 * v n2 <= max_size_t}
-> a:matrix_t n1 n2
-> b:matrix_t n1 n2
-> Stack uint16
(requires fun h -> live h a /\ live h b)
(ensures fun h0 r h1 -> modifies0 h0 h1 /\
r == M.matrix_eq #(v n1) #(v n2) (as_matrix h0 a) (as_matrix h0 b))
[@"c_inline"]
let matrix_eq #n1 #n2 a b =
push_frame();
(* One-element scratch buffer initialised to the all-ones value, handed to buf_eq_mask. *)
let res = create 1ul (ones U16 SEC) in
(* Compare all n1 * n2 elements of the two flat buffers. *)
let r = buf_eq_mask a b (n1 *! n2) res in
pop_frame ();
r
(* [matrix_to_lbytes m res] serialises the matrix [m] into the byte buffer
   [res] of length [2 * n1 * n2], two bytes per element.
   - Requires [m] and [res] to be live and disjoint.
   - Modifies only [res].
   - Afterwards the contents of [res] equal [M.matrix_to_lbytes (as_matrix h0 m)]. *)
val matrix_to_lbytes:
#n1:size_t
-> #n2:size_t{2 * v n1 <= max_size_t /\ 2 * v n1 * v n2 <= max_size_t}
-> m:matrix_t n1 n2
-> res:lbytes (2ul *! n1 *! n2)
-> Stack unit
(requires fun h -> live h m /\ live h res /\ disjoint m res)
(ensures fun h0 r h1 -> modifies1 res h0 h1 /\
as_seq h1 res == M.matrix_to_lbytes #(v n1) #(v n2) (as_matrix h0 m))
[@"c_inline"]
let matrix_to_lbytes #n1 #n2 m res =
let h0 = ST.get () in
(* Fill res in n1 * n2 blocks of 2 bytes; block i receives element i of the
   flat buffer m, written by uint_to_bytes_le (little-endian, per its name). *)
fill_blocks_simple h0 2ul (n1 *! n2) res
(fun h -> M.matrix_to_lbytes_f #(v n1) #(v n2) (as_matrix h0 m))
(fun i -> uint_to_bytes_le (sub res (2ul *! i) 2ul) m.(i))
(* [matrix_from_lbytes b res] deserialises the byte buffer [b] of length
   [2 * n1 * n2] into the matrix [res], two bytes per element.
   - Requires [b] and [res] to be live and disjoint.
   - Modifies only [res].
   - Afterwards the contents of [res] equal
     [M.matrix_from_lbytes (v n1) (v n2) (as_seq h0 b)]. *)
val matrix_from_lbytes:
#n1:size_t
-> #n2:size_t{2 * v n1 <= max_size_t /\ 2 * v n1 * v n2 <= max_size_t}
-> b:lbytes (size 2 *! n1 *! n2)
-> res:matrix_t n1 n2
-> Stack unit
(requires fun h -> live h b /\ live h res /\ disjoint b res)
(ensures fun h0 _ h1 -> modifies1 res h0 h1 /\
as_matrix h1 res == M.matrix_from_lbytes (v n1) (v n2) (as_seq h0 b))
[@"c_inline"]
let matrix_from_lbytes #n1 #n2 b res =
let h0 = ST.get () in
(* Proof hint: the element count itself fits in a size_t. *)
assert (v n1 * v n2 <= max_size_t);
(* Fill res element by element; element i is decoded from the two bytes of b
   starting at offset 2 * i by uint_from_bytes_le (little-endian, per its name). *)
fill h0 (n1 *! n2) res
(fun h -> M.matrix_from_lbytes_f (v n1) (v n2) (as_seq h0 b))
(fun i ->
(* Proof hint: the two-byte slice read below lies within b. *)
assert (2 * v i + 2 <= 2 * v n1 * v n2);
uint_from_bytes_le #U16 (sub b (2ul *! i) 2ul));
let h1 = ST.get () in
(* NOTE(review): presumably lifts the entry-wise description established by
   fill to equality of the two matrices - confirm in Spec.Matrix. *)
M.extensionality (as_matrix h1 res) (M.matrix_from_lbytes (v n1) (v n2) (as_seq h0 b))