(* Hacl.Impl.Curve25519.Field64.Hacl.fst *)
module Hacl.Impl.Curve25519.Field64.Hacl

open FStar.HyperStack
open FStar.HyperStack.All
open FStar.Mul

open Lib.Sequence
open Lib.IntTypes
open Lib.Buffer

module B = Lib.Buffer
module ST = FStar.HyperStack.ST
module BSeq = Lib.ByteSequence
module LSeq = Lib.Sequence

module CD = Hacl.Spec.Curve25519.Field64.Definition
module CC = Hacl.Spec.Curve25519.Field64.Core

module BN = Hacl.Bignum
module BD = Hacl.Bignum.Definitions
module SB = Hacl.Spec.Bignum
module SD = Hacl.Spec.Bignum.Definitions

#set-options "--z3rlimit 50 --fuel 0 --ifuel 0"


(* add1_: add the single 64-bit limb f2 to the 4-limb bignum f1.
   The 4-limb sum is written to out, and the carry returned by BN.bn_add1
   is returned to the caller.
   The assert ties the Low* result (carry and out) to the pure spec CC.add1
   applied to the initial contents of f1. By their name, the two
   CD.bn_v_is_as_nat calls presumably bridge the Hacl.Bignum interpretation
   (bn_v) and the Curve25519 interpretation (as_nat). That bridge would be
   needed by the postcondition declared in the interface, which is not
   visible here. *)
[@CInline]
let add1_ out f1 f2 =
  let h0 = ST.get () in
  let c = BN.bn_add1 4ul f1 f2 out in
  let h1 = ST.get () in
  (* proof-only: carry and output agree with the functional spec *)
  assert (let c1, r1 = CC.add1 (as_seq h0 f1) f2 in c == c1 /\ as_seq h1 out == r1);
  CD.bn_v_is_as_nat (as_seq h0 f1);
  CD.bn_v_is_as_nat (as_seq h1 out);
  c


(* fadd_: field addition of two 4-limb elements, out = f1 + f2.
   Step 1: out = f1 + f2 over 4 limbs; c0 is the carry out of the top limb.
   Step 2: a carry out of the top limb is worth 2^256. Since
   2^256 = 2 * 2^255 is congruent to 2 * 19 = 38 modulo 2^255 - 19, the
   value c0 * 38 is added back into out with BN.bn_add1 (out is both input
   and output). This produces a second carry c.
   Step 3: the second carry is folded into limb 0 the same way
   (out[0] += c * 38, using wrapping addition +.).
   The assert equates out with the functional spec CC.fadd4. The lemma calls
   relate bn_v and as_nat for the postcondition in the interface (not
   visible here).
   NOTE(review): whether out ends up fully reduced below 2^255 - 19 is not
   established in this block. *)
[@CInline]
let fadd_ out f1 f2 =
  let h0 = ST.get () in
  let c0 = BN.bn_add_eq_len 4ul f1 f2 out in
  let c = BN.bn_add1 4ul out (c0 *. u64 38) out in
  out.(0ul) <- out.(0ul) +. c *. u64 38;
  let h1 = ST.get () in
  assert (as_seq h1 out == CC.fadd4 (as_seq h0 f1) (as_seq h0 f2));
  CD.bn_v_is_as_nat (as_seq h0 f1);
  CD.bn_v_is_as_nat (as_seq h0 f2);
  CD.bn_v_is_as_nat (as_seq h1 out)


(* fsub_: field subtraction of two 4-limb elements, out = f1 - f2.
   Step 1: out = f1 - f2 over 4 limbs; c0 is the borrow out of the top limb.
   Step 2: a borrow means 2^256 was implicitly added to the result. Because
   2^256 is congruent to 38 modulo 2^255 - 19, the value c0 * 38 is
   subtracted from out with BN.bn_sub1 (out is both input and output).
   This produces a second borrow c.
   Step 3: the second borrow is folded into limb 0
   (out[0] -= c * 38, using wrapping subtraction -.).
   The assert equates out with the functional spec CC.fsub4.
   NOTE(review): step 2 multiplies with the checked operator *! while the
   sibling functions use the wrapping *. operator. This is presumably
   harmless because c0 is a 0/1 borrow. However, the assert is checked
   against CC.fsub4, so keep this operator in sync with that spec if either
   is changed. *)
[@CInline]
let fsub_ out f1 f2 =
  let h0 = ST.get () in
  let c0 = BN.bn_sub_eq_len 4ul f1 f2 out in
  let c = BN.bn_sub1 4ul out (c0 *! u64 38) out in
  out.(0ul) <- out.(0ul) -. c *. u64 38;
  let h1 = ST.get () in
  assert (as_seq h1 out == CC.fsub4 (as_seq h0 f1) (as_seq h0 f2));
  CD.bn_v_is_as_nat (as_seq h0 f1);
  CD.bn_v_is_as_nat (as_seq h0 f2);
  CD.bn_v_is_as_nat (as_seq h1 out)


(* fmul_: field multiplication of two 4-limb elements, out = f1 * f2.
   tmp is scratch space. Only its first 8 limbs (tmp0) are used here; its
   declared length is fixed in the interface, which is not visible here.
   Step 1: tmp0 = f1 * f2 as a full 8-limb product.
   Step 2: the high 4 limbs of tmp0 are multiplied by 38 and accumulated
   into the low 4 limbs in place (bn_mul1_lshift_add_in_place with a shift
   of 0). This works because 2^256 is congruent to 38 modulo 2^255 - 19.
   c0 is the limb that overflows past the low 4 limbs.
   Step 3: c0 * 38 is added to the low half with BN.bn_add1, writing the
   result to out and producing a carry c.
   Step 4: c * 38 is folded into limb 0 of out.
   The assert equates out with the functional spec CC.fmul4. tmp is left
   holding intermediate values. *)
[@CInline]
let fmul_ out f1 f2 tmp =
  let h0 = ST.get () in
  let tmp0 = sub tmp 0ul 8ul in
  BN.bn_mul 4ul 4ul f1 f2 tmp0;
  (* low half += high half * 38 *)
  let c0 = BN.bn_mul1_lshift_add_in_place 4ul (sub tmp0 4ul 4ul) (u64 38) 4ul 0ul (sub tmp0 0ul 4ul) in
  let c = BN.bn_add1 4ul (sub tmp0 0ul 4ul) (c0 *. u64 38) out in
  out.(0ul) <- out.(0ul) +. c *. u64 38;
  let h1 = ST.get () in
  assert (as_seq h1 out == CC.fmul4 (as_seq h0 f1) (as_seq h0 f2));
  CD.bn_v_is_as_nat (as_seq h0 f1);
  CD.bn_v_is_as_nat (as_seq h0 f2);
  CD.bn_v_is_as_nat (as_seq h1 out)


(* fmul2_: two independent field multiplications on packed pairs.
   Each 8-limb buffer holds two 4-limb elements side by side, in limbs 0..3
   and limbs 4..7. Limbs 0..3 of out receive the fmul_ product of the low
   halves of f1 and f2, and limbs 4..7 receive the fmul_ product of the
   high halves. The same scratch buffer tmp is handed to both fmul_ calls,
   one after the other. *)
[@CInline]
let fmul2_ out f1 f2 tmp =
  (* low halves: limbs 0..3 *)
  let a_lo = B.sub f1 0ul 4ul in
  let b_lo = B.sub f2 0ul 4ul in
  let r_lo = B.sub out 0ul 4ul in
  (* high halves: limbs 4..7 *)
  let a_hi = B.sub f1 4ul 4ul in
  let b_hi = B.sub f2 4ul 4ul in
  let r_hi = B.sub out 4ul 4ul in
  fmul_ r_lo a_lo b_lo tmp;
  fmul_ r_hi a_hi b_hi tmp


(* fmul1_: multiply the 4-limb field element f1 by the single limb f2.
   Step 1: out = f1 * f2 over 4 limbs (BN.bn_mul1); c0 is the limb that
   overflows past the top limb.
   Step 2: that overflow is worth c0 * 2^256, which is congruent to c0 * 38
   modulo 2^255 - 19. So c0 * 38 is added back into out with BN.bn_add1,
   producing a carry c.
   Step 3: c * 38 is folded into limb 0 of out.
   The assert equates out with the functional spec CC.fmul14. *)
[@CInline]
let fmul1_ out f1 f2 =
  let h0 = ST.get () in
  let c0 = BN.bn_mul1 4ul f1 f2 out in
  let c = BN.bn_add1 4ul out (c0 *. u64 38) out in
  out.(0ul) <- out.(0ul) +. c *. u64 38;
  let h1 = ST.get () in
  assert (as_seq h1 out == CC.fmul14 (as_seq h0 f1) f2);
  CD.bn_v_is_as_nat (as_seq h0 f1);
  CD.bn_v_is_as_nat (as_seq h1 out)


(* fsqr_: field squaring of a 4-limb element, out = f1 * f1.
   Step 1: BN.bn_sqr writes the full 2 * 4 = 8 limb square into tmp. Unlike
   fmul_, the whole of tmp is passed here, so tmp presumably has exactly
   8 limbs; confirm against the interface.
   SB.bn_sqr_lemma is a proof step about the squaring spec. Presumably it
   relates squaring to multiplication so that the final assert against
   CC.fsqr4 can be discharged.
   Steps 2-4 are the same reduction as in fmul_: the high half times 38 is
   accumulated into the low half (carry limb c0), c0 * 38 is added with
   BN.bn_add1 into out (carry c), and c * 38 is folded into limb 0. *)
[@CInline]
let fsqr_ out f1 tmp =
  let h0 = ST.get () in
  BN.bn_sqr 4ul f1 tmp;
  SB.bn_sqr_lemma (as_seq h0 f1);
  (* low half += high half * 38 *)
  let c0 = BN.bn_mul1_lshift_add_in_place 4ul (sub tmp 4ul 4ul) (u64 38) 4ul 0ul (sub tmp 0ul 4ul) in
  let c = BN.bn_add1 4ul (sub tmp 0ul 4ul) (c0 *. u64 38) out in
  out.(0ul) <- out.(0ul) +. c *. u64 38;
  let h1 = ST.get () in
  assert (as_seq h1 out == CC.fsqr4 (as_seq h0 f1));
  CD.bn_v_is_as_nat (as_seq h0 f1);
  CD.bn_v_is_as_nat (as_seq h1 out)


(* fsqr2_: square two packed 4-limb field elements.
   f holds two elements in limbs 0..3 and 4..7. Each half is squared into
   the corresponding half of out, reusing the same scratch buffer tmp for
   both calls.
   NOTE(review): squaring is done with fmul_ x x rather than fsqr_. That is
   arithmetically equivalent, but it does not use the dedicated bn_sqr path.
   fsqr_ passes its whole tmp to bn_sqr (apparently 8 limbs), whereas fmul_
   only slices the first 8 limbs of tmp. Check the tmp length and aliasing
   preconditions in the interface before switching these calls to fsqr_. *)
[@CInline]
let fsqr2_ out f tmp =
  let out1 = B.sub out 0ul 4ul in
  let out2 = B.sub out 4ul 4ul in
  let f1 = B.sub f 0ul 4ul in
  let f2 = B.sub f 4ul 4ul in
  fmul_ out1 f1 f1 tmp;
  fmul_ out2 f2 f2 tmp


(* cswap2_: conditionally swap the 8-limb buffers p1 and p2 according to bit.
   When v bit = 1 the contents are exchanged; otherwise both are left
   unchanged. The loop body never branches on bit. Instead it derives
   mask = 0 - bit with wrapping subtraction (all ones when bit = 1, zero
   when bit = 0) and applies the XOR-swap trick limb by limb.
   NOTE(review): the invariant only distinguishes v bit = 1 from everything
   else, so bit is presumably restricted to 0 or 1 by the interface
   precondition; confirm. *)
[@CInline]
let cswap2_ bit p1 p2 =
  let h0 = ST.get () in
  let mask = u64 0 -. bit in

  (* Loop invariant after i iterations:
     - limbs k < i are swapped (v bit = 1) or unchanged (otherwise),
     - limbs k >= i still hold their original values,
     - only p1 and p2 have been modified since h0. *)
  [@ inline_let]
  let inv h1 (i:nat{i <= 8}) =
    (forall (k:nat{k < i}).
      if v bit = 1
      then (as_seq h1 p1).[k] == (as_seq h0 p2).[k] /\ (as_seq h1 p2).[k] == (as_seq h0 p1).[k]
      else (as_seq h1 p1).[k] == (as_seq h0 p1).[k] /\ (as_seq h1 p2).[k] == (as_seq h0 p2).[k]) /\
    (forall (k:nat{i <= k /\ k < 8}).
      (as_seq h1 p1).[k] == (as_seq h0 p1).[k] /\ (as_seq h1 p2).[k] == (as_seq h0 p2).[k]) /\
    modifies (loc p1 |+| loc p2) h0 h1 in

  Lib.Loops.for 0ul 8ul inv
    (fun i ->
      (* dummy = p1[i] xor p2[i] when mask is all ones, 0 otherwise *)
      let dummy = mask &. (p1.(i) ^. p2.(i)) in
      p1.(i) <- p1.(i) ^. dummy;
      p2.(i) <- p2.(i) ^. dummy;
      (* proof-only: the limb i update matches the invariant; limb i still
         held its h0 value before this iteration, so h0 values are used *)
      Hacl.Spec.Bignum.Lib.lemma_cswap2_step bit ((as_seq h0 p1).[v i]) ((as_seq h0 p2).[v i])
    );
  let h1 = ST.get () in
  (* proof-only: lift the pointwise invariant to whole-sequence equality *)
  assert (if v bit = 1 then (eq_intro (as_seq h1 p1) (as_seq h0 p2); as_seq h1 p1 == as_seq h0 p2)
    else (eq_intro (as_seq h1 p1) (as_seq h0 p1); as_seq h1 p1 == as_seq h0 p1));
  assert (if v bit = 1 then (eq_intro (as_seq h1 p2) (as_seq h0 p1); as_seq h1 p2 == as_seq h0 p1)
    else (eq_intro (as_seq h1 p2) (as_seq h0 p2); as_seq h1 p2 == as_seq h0 p2))
back to top