(* Hacl.Spec.Karatsuba.Lemmas.fst *)
module Hacl.Spec.Karatsuba.Lemmas

open FStar.Mul
open Lib.IntTypes

// Solver settings for the declarations below: Z3 resource limit of 50,
// and zero fuel / inversion fuel, so every arithmetic fact that needs a
// library lemma has to be supplied explicitly in the proofs.
#set-options "--z3rlimit 50 --fuel 0 --ifuel 0"

(* Two-valued sign tag. It is produced by sign_abs, which yields Negative
   when its first argument is strictly smaller than its second and
   Positive otherwise (including when the two arguments are equal). *)
type sign =
  | Positive
  | Negative

(* Absolute difference of two naturals: the larger operand minus the
   smaller one, so the result is itself a natural number. *)
let abs (a:nat) (b:nat) : nat =
  if a >= b then a - b else b - a

(* Compares a and b and returns a pair (s, res):
   - res is the absolute difference abs a b;
   - s is Negative when a < b and Positive otherwise (a >= b). *)
val sign_abs: a:nat -> b:nat ->
  Pure (tuple2 sign nat)
  (requires True)
  (ensures  fun (s, res) -> res == abs a b /\
    s == (if a < b then Negative else Positive))

let sign_abs a b =
  if a >= b
  then (Positive, a - b)
  else (Negative, b - a)


(* Squaring the half-length radix yields the full-length radix:
   for even aLen and p = pow2 (aLen / 2 * pbits),
   p * p == pow2 (pbits * aLen). *)
val lemma_double_p: pbits:pos -> aLen:nat{aLen % 2 = 0} ->
  Lemma (let p = pow2 (aLen / 2 * pbits) in p * p == pow2 (pbits * aLen))
let lemma_double_p pbits aLen =
  let p = pow2 (aLen / 2 * pbits) in
  // Proof outline, one calc step each:
  //  1. pow2_plus turns the product of two powers of two into a sum of exponents;
  //  2. distributivity factors pbits out of the two equal summands;
  //  3. lemma_div_exact uses aLen % 2 = 0 to replace aLen / 2 * 2 by aLen.
  calc (==) {
    p * p;
    (==) { Math.Lemmas.pow2_plus (aLen / 2 * pbits) (aLen / 2 * pbits) }
    pow2 (aLen / 2 * pbits + aLen / 2 * pbits);
    (==) { Math.Lemmas.distributivity_add_left (aLen / 2) (aLen / 2) pbits }
    pow2 ((aLen / 2 * 2) * pbits);
    (==) { Math.Lemmas.lemma_div_exact aLen 2 }
    pow2 (aLen * pbits);
    }


(* Splitting a number a < pow2 (pbits * aLen) at the half-length radix
   p = pow2 (aLen / 2 * pbits), for even aLen:
   the high part a / p and the low part a % p are both below p, and
   a is recovered as a / p * p + a % p. *)
val lemma_bn_halves: pbits:pos -> aLen:nat{aLen % 2 = 0} -> a:nat{a < pow2 (pbits * aLen)} ->
  Lemma (let p = pow2 (aLen / 2 * pbits) in a / p < p /\ a % p < p /\ a == a / p * p + a % p)
// The only explicit hint is p * p == pow2 (pbits * aLen); the remaining
// division/modulo reasoning is left to the SMT solver.
let lemma_bn_halves pbits aLen a = lemma_double_p pbits aLen


(* Karatsuba middle-term identity.
   With (s0, t0) = sign_abs a0 a1 and (s1, t1) = sign_abs b0 b1, the cross
   term a0 * b1 + a1 * b0 is obtained from a0 * b0 + a1 * b1 by subtracting
   t0 * t1 when the two signs agree and by adding t0 * t1 when they differ. *)
val lemma_middle_karatsuba: a0:nat -> a1:nat -> b0:nat -> b1:nat ->
  Lemma
   (let s0, t0 = sign_abs a0 a1 in
    let s1, t1 = sign_abs b0 b1 in
    let t23 = t0 * t1 in
    let t01 = a0 * b0 + a1 * b1 in
    let t45 = if s0 = s1 then t01 - t23 else t01 + t23 in
    t45 == a0 * b1 + a1 * b0)

let lemma_middle_karatsuba a0 a1 b0 b1 =
  // Re-establish the same local definitions as in the lemma statement.
  let s0, t0 = sign_abs a0 a1 in
  let s1, t1 = sign_abs b0 b1 in
  let t23 = t0 * t1 in
  let t01 = a0 * b0 + a1 * b1 in
  let t45 = if s0 = s1 then t01 - t23 else t01 + t23 in
  // Each branch restates t45 with the product of absolute differences
  // written as a product of signed differences; the asserts serve as
  // hints from which the solver derives the stated equality.
  if s0 = s1 then
    assert (t45 = a0 * b0 + a1 * b1 - (a0 - a1) * (b0 - b1))
  else
    assert (t45 = a0 * b0 + a1 * b1 + (a1 - a0) * (b0 - b1))


(* Recombination identity used by Karatsuba multiplication.
   For even aLen, p = pow2 (pbits * (aLen / 2)), a = a1 * p + a0 and
   b = b1 * p + b0:
     a1 * b1 * pow2 (pbits * aLen) + (a0 * b1 + a1 * b0) * p + a0 * b0 == a * b. *)
val lemma_karatsuba: pbits:pos -> aLen:nat{aLen % 2 = 0} -> a0:nat -> a1:nat -> b0:nat -> b1:nat -> Lemma
  (let aLen2 = aLen / 2 in
   let p = pow2 (pbits * aLen2) in
   let a = a1 * p + a0 in
   let b = b1 * p + b0 in
   a1 * b1 * pow2 (pbits * aLen) + (a0 * b1 + a1 * b0) * pow2 (pbits * aLen2) + a0 * b0 == a * b)

let lemma_karatsuba pbits aLen a0 a1 b0 b1 =
  let aLen2 = aLen / 2 in
  let p = pow2 (pbits * aLen2) in
  let a = a1 * p + a0 in
  let b = b1 * p + b0 in

  // Proof outline: expand the product (a1 * p + a0) * (b1 * p + b0),
  // re-associate the factors so that p * p appears explicitly, replace
  // p * p by pow2 (pbits * aLen) via lemma_double_p, and finally collect
  // the two terms that carry a single factor p.
  calc (==) {
    a * b;
    (==) { }
    (a1 * p + a0) * (b1 * p + b0);
    (==) { }
    a1 * p * (b1 * p + b0) + a0 * (b1 * p + b0);
    (==) { }
    a1 * p * (b1 * p) + a1 * p * b0 + a0 * (b1 * p) + a0 * b0;
    (==) { Math.Lemmas.paren_mul_right a0 b1 p }
    a1 * p * (b1 * p) + a1 * p * b0 + a0 * b1 * p + a0 * b0;
    (==) { Math.Lemmas.paren_mul_right a1 p (b1 * p); Math.Lemmas.paren_mul_right p p b1 }
    a1 * (b1 * (p * p)) + a1 * p * b0 + a0 * b1 * p + a0 * b0;
    (==) { Math.Lemmas.paren_mul_right a1 b1 (p * p) }
    a1 * b1 * (p * p) + a1 * p * b0 + a0 * b1 * p + a0 * b0;
    (==) { lemma_double_p pbits aLen }
    a1 * b1 * pow2 (pbits * aLen) + a1 * p * b0 + a0 * b1 * p + a0 * b0;
    (==) { Math.Lemmas.paren_mul_right a1 p b0; Math.Lemmas.paren_mul_right a1 b0 p }
    a1 * b1 * pow2 (pbits * aLen) + a1 * b0 * p + a0 * b1 * p + a0 * b0;
    (==) { Math.Lemmas.distributivity_add_left (a1 * b0) (a0 * b1) p }
    a1 * b1 * pow2 (pbits * aLen) + (a1 * b0 + a0 * b1) * p + a0 * b0;
   }


(* Specification-level Karatsuba multiplication on natural numbers.
   Inputs a and b are both below pow2 (pbits * aLen); the result type is
   refined so that the returned value equals a * b. Recursion is on aLen
   (the termination measure), which is halved at every recursive call. *)
val karatsuba:
    pbits:pos // pbits = bits t
  -> aLen:nat
  -> a:nat{a < pow2 (pbits * aLen)}
  -> b:nat{b < pow2 (pbits * aLen)} ->
  Tot (res:nat{res == a * b}) (decreases aLen)

let rec karatsuba pbits aLen a b =
  // Base case: below the length threshold, or when aLen is odd and so
  // cannot be split into two equal halves, multiply directly.
  if aLen < 16 || aLen % 2 = 1 then a * b
  else begin
    let aLen2 = aLen / 2 in
    let p = pow2 (aLen2 * pbits) in

    // Split both operands into low (x0) and high (x1) halves at radix p.
    let a0 = a % p in let a1 = a / p in
    let b0 = b % p in let b1 = b / p in
    // Both halves are below p and recompose to the operand; this also
    // provides the bounds required by the recursive calls.
    lemma_bn_halves pbits aLen a;
    lemma_bn_halves pbits aLen b;

    // Signs and absolute values of the differences of the halves.
    let s0, t0 = sign_abs a0 a1 in
    let s1, t1 = sign_abs b0 b1 in

    // Three half-length products, each computed recursively.
    let t23 = karatsuba pbits aLen2 t0 t1 in assert (t23 == t0 * t1);
    let r01 = karatsuba pbits aLen2 a0 b0 in assert (r01 == a0 * b0);
    let r23 = karatsuba pbits aLen2 a1 b1 in assert (r23 == a1 * b1);

    // Middle term: combine the low and high products with t23, subtracting
    // when the two signs agree and adding when they differ.
    let t01 = r01 + r23 in assert (t01 == a0 * b0 + a1 * b1);
    let t45 = if s0 = s1 then t01 - t23 else t01 + t23 in
    lemma_middle_karatsuba a0 a1 b0 b1;
    assert (t45 == a0 * b1 + a1 * b0);

    // Recombine: high product at pow2 (pbits * aLen), middle term at p,
    // low product at weight 1; lemma_karatsuba relates this sum to a * b.
    let res = r23 * pow2 (pbits * aLen) + t45 * p + r01 in
    lemma_karatsuba pbits aLen a0 a1 b0 b1;
    res end
(* End of module Hacl.Spec.Karatsuba.Lemmas. *)