Revision 5b2fbf3c4989a9b0587a00578f69f3041df3f957 authored by Jonathan Protzenko on 08 April 2020, 18:59:46 UTC, committed by Jonathan Protzenko on 08 April 2020, 18:59:46 UTC
1 parent d4ca892
Raw File
Spec.AES.fst
module Spec.AES

open FStar.Mul
open Lib.IntTypes
open Lib.RawIntTypes
open Lib.Sequence
open Lib.ByteSequence
open Lib.LoopCombinators
open Spec.GaloisField


/// Constants and Types

(* The AES field GF(2^8). [irred] = 0x1b is the AES reduction polynomial
   x^8 + x^4 + x^3 + x + 1 with its x^8 term dropped (presumably the encoding
   Spec.GaloisField.gf expects — confirm there). [elem] is the field element
   type and [to_elem] converts a literal into it. *)
let irred = u8 0x1b
let gf8 = gf U8 irred
let elem = felem gf8
let to_elem = to_felem #gf8
(* Frequently used field constants: 0, and the MixColumns coefficients 2 and 3. *)
let zero = to_elem 0
let two = to_elem 2
let three = to_elem 3

(* Bit rotation and XOR on secret 8-bit integers. These operations are
   normalized here to avoid a huge blowup in generated code size.
   Callers in this file always parenthesize rotations when they mix them
   with ^. (for example sub_byte). *)
let ( <<<. ) x y = normalize_term (rotate_left #U8 #SEC x y)
let ( ^. ) x y = normalize_term (logxor #U8 #SEC x y)

(* AES state types. A [word] is 4 field elements. A [block] is the 16-element
   AES state, stored column by column: element (row r, column c) lives at
   index r + 4 * c (see shiftRow and mixColumn). *)
type word = lseq elem 4
type block = lseq elem 16

(* The supported AES variants, distinguished by key length (see key_size). *)
type variant = | AES128 | AES256

(* Number of cipher rounds Nr for a variant: 14 for AES-256, 10 for AES-128.
   The expanded key holds Nr + 1 round keys (see aes_xkey). *)
let num_rounds (v:variant) =
  match v with
  | AES256 -> 14
  | AES128 -> 10

(* Length in bytes of the raw cipher key: 32 for AES-256, 16 for AES-128. *)
let key_size (v:variant) =
  match v with
  | AES256 -> 32
  | AES128 -> 16

(* Raw key bytes for variant [v]. *)
let aes_key (v:variant) = lbytes (key_size v)
(* Expanded key schedule: num_rounds v + 1 round keys of 16 elements each. *)
let aes_xkey (v:variant) = lseq elem ((num_rounds v+1) * 16)
(* The num_rounds v - 1 "middle" round keys, i.e. the expanded key without its
   first and last round key (this is the slice aes_encrypt_block passes to
   aes_enc_rounds). *)
let aes_ikey (v:variant) = lseq elem ((num_rounds v-1) * 16)


(* Forward S-box (SubBytes on one byte): take the multiplicative inverse in
   GF(2^8), then apply the affine transform
     s ^ (s <<< 1) ^ (s <<< 2) ^ (s <<< 3) ^ (s <<< 4) ^ 0x63
   (to_elem 99 = 0x63).
   NOTE(review): AES requires the inverse of 0 to be 0; this assumes finv
   behaves that way — confirm in Spec.GaloisField. *)
let sub_byte (input:elem) =
  let s = finv input in
  s ^.
  (s <<<. size 1) ^.
  (s <<<. size 2) ^.
  (s <<<. size 3) ^.
  (s <<<. size 4) ^.
  (to_elem 99)

(* Inverse S-box (InvSubBytes on one byte). It first undoes the affine
   transform with (s <<< 1) ^ (s <<< 3) ^ (s <<< 6) ^ 0x05, then takes the
   multiplicative inverse in GF(2^8). *)
let inv_sub_byte (input:elem) =
  let s = input in
  let s:elem =
    (s <<<. size 1) ^.
    (s <<<. size 3) ^.
    (s <<<. size 6) ^.
    (u8 5)
  in
  finv s

(* SubBytes: apply the forward S-box to every byte of the state. *)
let subBytes (state:block) : Tot block =
  map sub_byte state

(* InvSubBytes: apply the inverse S-box to every byte of the state. *)
let inv_subBytes (state:block) : Tot block =
  map inv_sub_byte state

(* Rotate row [i] of the column-major state left by [shift] positions. After
   the call, the element in column c is the one previously in column
   (c + shift) % 4. Only indices i, i+4, i+8 and i+12 change, and all four are
   read before any is written.
   Note: the refinement on [shift] only restates i < 4. [shift] itself is
   reduced mod 4, so any value is accepted. *)
let shiftRow (i:size_nat{i < 4}) (shift:size_nat{i < 4}) (state:block) : Tot block =
  let tmp0 = state.[i + (4 * (shift % 4))] in
  let tmp1 = state.[i + (4 * ((shift + 1) % 4))] in
  let tmp2 = state.[i + (4 * ((shift + 2) % 4))] in
  let tmp3 = state.[i + (4 * ((shift + 3) % 4))] in
  let state = state.[i] <- tmp0 in
  let state = state.[i+4] <- tmp1 in
  let state = state.[i+8] <- tmp2 in
  let state = state.[i+12] <- tmp3 in
  state

(* ShiftRows: rotate row r left by r positions for r = 1, 2, 3. Row 0 is
   unchanged. Rows are processed in the order 1, 2, 3. *)
let shiftRows (state: block) : Tot block =
  shiftRow 3 3 (shiftRow 2 2 (shiftRow 1 1 state))

(* InvShiftRows: rotate row r right by r positions, expressed as left
   rotations by 4 - r (rows 1, 2, 3 shifted by 3, 2, 1). *)
let inv_shiftRows (state: block) : Tot block =
  shiftRow 3 1 (shiftRow 2 2 (shiftRow 1 3 state))

(* One output byte of MixColumns: 2*s0 + 3*s1 + s2 + s3 in GF(2^8), i.e. the
   first row (02 03 01 01) of the MixColumns matrix. mixColumn obtains the
   other three output bytes by rotating the arguments. *)
let mix4 (s0:elem) (s1:elem) (s2:elem) (s3:elem) : Tot elem =
  (s0 `fmul` two) `fadd`
  (s1 `fmul` three) `fadd`
  s2 `fadd` s3

(* One output byte of InvMixColumns: 14*s0 + 11*s1 + 13*s2 + 9*s3 in GF(2^8),
   i.e. the first row (0e 0b 0d 09) of the InvMixColumns matrix.
   inv_mixColumn obtains the other three output bytes by rotating the
   arguments. *)
let inv_mix4 (s0:elem) (s1:elem) (s2:elem) (s3:elem) : Tot elem =
  (s0 `fmul` to_elem 14) `fadd`
  (s1 `fmul` to_elem 11) `fadd`
  (s2 `fmul` to_elem 13) `fadd`
  (s3 `fmul` to_elem 9)

(* Apply MixColumns to column [c], which occupies indices 4c .. 4c+3. All four
   inputs are read before any write, so every output byte is computed from the
   original column. *)
let mixColumn (c:size_nat{c < 4}) (state:block) : Tot block =
  let i0 = 4 * c in
  let s0 = state.[i0] in
  let s1 = state.[i0 + 1] in
  let s2 = state.[i0 + 2] in
  let s3 = state.[i0 + 3] in
  (* Output byte k uses the column rotated by k. *)
  let state = state.[i0] <- mix4 s0 s1 s2 s3 in
  let state = state.[i0+1] <- mix4 s1 s2 s3 s0 in
  let state = state.[i0+2] <- mix4 s2 s3 s0 s1 in
  let state = state.[i0+3] <- mix4 s3 s0 s1 s2 in
  state

(* MixColumns on the whole state: transform columns 0, 1, 2, 3 in turn. *)
let mixColumns (state:block) : Tot block =
  mixColumn 3 (mixColumn 2 (mixColumn 1 (mixColumn 0 state)))

(* Apply InvMixColumns to column [c] (indices 4c .. 4c+3). Inputs are read
   before any write, as in mixColumn. *)
let inv_mixColumn (c:size_nat{c < 4}) (state:block) : Tot block =
  let i0 = 4 * c in
  let s0 = state.[i0] in
  let s1 = state.[i0 + 1] in
  let s2 = state.[i0 + 2] in
  let s3 = state.[i0 + 3] in
  (* Output byte k uses the column rotated by k. *)
  let state = state.[i0] <- inv_mix4 s0 s1 s2 s3 in
  let state = state.[i0+1] <- inv_mix4 s1 s2 s3 s0 in
  let state = state.[i0+2] <- inv_mix4 s2 s3 s0 s1 in
  let state = state.[i0+3] <- inv_mix4 s3 s0 s1 s2 in
  state

(* InvMixColumns on the whole state: transform columns 0, 1, 2, 3 in turn. *)
let inv_mixColumns (state:block) : Tot block =
  inv_mixColumn 3 (inv_mixColumn 2 (inv_mixColumn 1 (inv_mixColumn 0 state)))

(* Element-wise XOR of two 16-element blocks. *)
let xor_block (b1:block) (b2:block) : Tot block =
  map2 (logxor #U8) b1 b2

(* AddRoundKey: XOR the 16-element round key into the state. *)
let addRoundKey (key:block) (state:block) : Tot block =
  xor_block state key

(* One full encryption round: SubBytes, ShiftRows, MixColumns, then
   AddRoundKey with [key]. *)
let aes_enc (key:block) (state:block) : Tot block =
  addRoundKey key (mixColumns (shiftRows (subBytes state)))

(* Final encryption round: the same as aes_enc but without MixColumns. *)
let aes_enc_last (key:block) (state:block) : Tot block =
  addRoundKey key (shiftRows (subBytes state))

(* One decryption round in "equivalent inverse cipher" form: InvSubBytes,
   InvShiftRows, InvMixColumns, then AddRoundKey. It expects [key] to have
   already been passed through InvMixColumns, as aes_dec_key_expansion does
   for the middle round keys. *)
let aes_dec (key:block) (state:block) : Tot block =
  addRoundKey key (inv_mixColumns (inv_shiftRows (inv_subBytes state)))

(* Final decryption round: the same as aes_dec but without InvMixColumns. *)
let aes_dec_last (key:block) (state:block) : Tot block =
  addRoundKey key (inv_shiftRows (inv_subBytes state))

(* RotWord: cyclic left rotation of a 4-element word by one position,
   [a; b; c; d] -> [b; c; d; a]. *)
let rotate_word (w:word) : Tot word =
  of_list [w.[1]; w.[2]; w.[3]; w.[0]]

(* SubWord: apply the forward S-box to each element of a word. *)
let sub_word (w:word) : Tot word =
  map sub_byte w

(* Round constant: rcon_spec 1 = 1, rcon_spec i = 2 * rcon_spec (i - 1) in
   GF(2^8) for i > 1, and rcon_spec 0 = 0x8d by convention. *)
val rcon_spec: i:size_nat -> Tot elem
let rec rcon_spec i =
  if i > 1 then two `fmul` rcon_spec (i - 1)
  else if i = 1 then to_elem 1
  else to_elem 0x8d

(* The values rcon_spec 0 .. rcon_spec 10 as a literal list: 0x8d, then the
   powers of 2 in GF(2^8) starting at 1. *)
let rcon_l : list elem = [
  to_elem 0x8d; to_elem 0x01; to_elem 0x02; to_elem 0x04;
  to_elem 0x08; to_elem 0x10; to_elem 0x20; to_elem 0x40;
  to_elem 0x80; to_elem 0x1b; to_elem 0x36
]

(* rcon_l as a fixed-length sequence. The assert_norm discharges the length
   obligation by normalization. This constant is not referenced by the key
   expansions in this module, which call rcon_spec directly. *)
let rcon_seq : lseq elem 11  =
  assert_norm (List.Tot.length rcon_l == 11);
  of_list rcon_l

#reset-options "--z3rlimit 100"

(* Models the x86 AESKEYGENASSIST instruction on state [s]. Let X1 = s.[4..7]
   and X3 = s.[12..15]. The result words are:
     word 0 = SubWord(X1)
     word 1 = RotWord(SubWord(X1)) with [rcon] xored into its first byte
     word 2 = SubWord(X3)
     word 3 = RotWord(SubWord(X3)) with [rcon] xored into its first byte
   Words 1 and 3 follow exactly the same byte pattern. *)
let aes_keygen_assist (rcon:elem) (s:block) : Tot block =
  let st = create 16 (to_elem 0) in
  (* word 0: SubWord(X1) *)
  let st = st.[0] <- sub_byte s.[4] in
  let st = st.[1] <- sub_byte s.[5] in
  let st = st.[2] <- sub_byte s.[6] in
  let st = st.[3] <- sub_byte s.[7] in

  (* word 1: RotWord(SubWord(X1)) ^ rcon. Every byte 4..7 is written exactly
     once. *)
  let st = st.[4] <- rcon ^. sub_byte s.[5] in
  let st = st.[5] <- sub_byte s.[6] in
  let st = st.[6] <- sub_byte s.[7] in
  let st = st.[7] <- sub_byte s.[4] in

  (* word 2: SubWord(X3) *)
  let st = st.[8]  <- sub_byte s.[12] in
  let st = st.[9]  <- sub_byte s.[13] in
  let st = st.[10] <- sub_byte s.[14] in
  let st = st.[11] <- sub_byte s.[15] in

  (* word 3: RotWord(SubWord(X3)) ^ rcon *)
  let st = st.[12] <- rcon ^. sub_byte s.[13] in
  let st = st.[13] <- sub_byte s.[14] in
  let st = st.[14] <- sub_byte s.[15] in
  let st = st.[15] <- sub_byte s.[12] in
  st

(* Broadcast word 3 (bytes 12..15) of aes_keygen_assist rcon s into all four
   words. That word is RotWord(SubWord(s.[12..15])) ^ rcon. This appears to
   mirror the `pshufd 0xff` that follows AESKEYGENASSIST in the usual key
   schedule. *)
let keygen_assist0 (rcon:elem) (s:block) : Tot block =
  let st = aes_keygen_assist rcon s in
  (* copy word 3 into word 2 ... *)
  let st = update_sub st 8 4 (sub st 12 4) in
  (* ... then words 2-3 into words 0-1 *)
  let st = update_sub st 0 8 (sub st 8 8) in
  st

(* Broadcast word 2 (bytes 8..11) of aes_keygen_assist zero s into all four
   words. That word is SubWord(s.[12..15]) with no rotation and no rcon. This
   appears to mirror `pshufd 0xaa`, and is used for the second half of each
   AES-256 key-schedule step. *)
let keygen_assist1 (s:block) : Tot block =
  let st = aes_keygen_assist zero s in
  (* copy word 2 into word 3 ... *)
  let st = update_sub st 12 4 (sub st 8 4) in
  (* ... then words 2-3 into words 0-1 *)
  let st = update_sub st 0 8 (sub st 8 8) in
  st

(* One key-schedule step. Three times, xor the running key with a copy of
   itself moved up by one 4-byte word (zero-filled at indices 0..3). This
   produces the prefix-xor of p's words. Finally xor in the broadcast [assist]
   word. *)
let key_expansion_step (p:block) (assist:block) : Tot block =
  (* bytes 0..11 of x moved to indices 4..15, zeros at 0..3 *)
  let shift_word (x:block) : Tot block =
    update_sub (create 16 (to_elem 0)) 4 12 (sub x 0 12) in
  let k1 = xor_block p (shift_word p) in
  let k2 = xor_block k1 (shift_word k1) in
  let k3 = xor_block k2 (shift_word k2) in
  xor_block k3 assist

(* AES-128 key schedule: 11 round keys of 16 bytes. Round key 0 is the raw key.
   Round key i+1 is derived from round key i with rcon_spec (i+1), for
   i = 0 .. 9. *)
let aes128_key_expansion (key:lbytes 16) : Tot (lseq elem (11 * 16)) =
  let key_ex = create (11 * 16) (to_elem 0) in
  let key_ex = update_sub key_ex 0 16 key in
  (* iteration i reads round key i and writes round key i+1 *)
  let key_ex =
    repeati #(lseq elem (11 * 16)) 10
      (fun i kex ->
	     let p = sub kex (i * 16) 16 in
	     let a = keygen_assist0 (rcon_spec (i+1)) p in
	     let n = key_expansion_step p a in
	     update_sub kex ((i+1) * 16) 16 n)
    key_ex in
  key_ex

(* AES-256 key schedule: 15 round keys of 16 bytes. Round keys 0 and 1 are the
   two halves of the raw key. *)
let aes256_key_expansion (key:lbytes 32) : Tot (lseq elem (15 * 16)) =
  let key_ex = create (15 * 16) (to_elem 0) in
  let key_ex = update_sub key_ex 0 32 key in
  (* Iteration i reads round keys 2i and 2i+1 and writes 2i+2 and 2i+3. The
     first new key uses rcon (keygen_assist0); the second uses SubWord only
     (keygen_assist1). Six iterations fill round keys 2 .. 13. *)
  let key_ex =
    repeati #(lseq elem (15 * 16)) 6
      (fun i key_ex ->
	     let p0 = sub key_ex (2 * i * 16) 16 in
	     let p1 = sub key_ex (((2*i)+1) * 16) 16 in
	     let a0 = keygen_assist0 (rcon_spec (i+1)) p1 in
	     let n0 = key_expansion_step p0 a0 in
	     let a1 = keygen_assist1 n0 in
	     let n1 = key_expansion_step p1 a1 in
	     let key_ex = update_sub key_ex (((2*i)+2) * 16) 16 n0 in
	     update_sub key_ex (((2*i)+3) * 16) 16 n1)
    key_ex in
  (* The last round key (14) needs only the rcon half of a step. *)
  let p0 = sub key_ex (12 * 16) 16 in
  let p1 = sub key_ex (13 * 16) 16 in
  let a14 = keygen_assist0 (rcon_spec 7) p1 in
  let n14 = key_expansion_step p0 a14 in
  update_sub key_ex (14 * 16) 16 n14

(* Encryption key schedule for variant [v]: dispatch to the 128- or 256-bit
   expansion. The result has num_rounds v + 1 round keys. *)
let aes_key_expansion (v:variant) (key: aes_key v) : aes_xkey v =
  match v with
  | AES128 -> aes128_key_expansion key
  | AES256 -> aes256_key_expansion key

(* Decryption key schedule for the equivalent inverse cipher used by aes_dec
   and aes_decrypt_block. Output round key i is encryption round key
   num_rounds v - i, so the keys come out in reverse order. The first and last
   output keys are copied unchanged because they are used by plain
   AddRoundKey. Every middle key gets InvMixColumns, so aes_dec can add it
   after inv_mixColumns. *)
let aes_dec_key_expansion (v:variant) (key:aes_key v): aes_xkey v =
  let ekey_ex : aes_xkey v = aes_key_expansion v key in
  let _,key_ex = generate_blocks 16 (num_rounds v + 1) (num_rounds v + 1)
    (fun i -> unit)
    (fun i a ->
      (* encryption round key num_rounds v - i *)
      let b = sub ekey_ex ((num_rounds v - i) * 16) 16 in
      if i = 0 then (), b
      else if i < num_rounds v then
        (),inv_mixColumns b
      else (),b) () in
  key_ex

(* Apply the num_rounds v - 1 middle encryption rounds. Round i uses the
   16-element key at offset 16*i of [key]. *)
let aes_enc_rounds (v:variant) (key:aes_ikey v) (state:block) : Tot block =
  repeati (num_rounds v-1) (fun i -> aes_enc (sub key (16*i) 16)) state

(* Encrypt one block under an expanded key. The steps are: AddRoundKey with
   round key 0, then num_rounds v - 1 full rounds, then the final round without
   MixColumns using round key num_rounds v. *)
let aes_encrypt_block (v:variant) (key:aes_xkey v) (input:block) : Tot block =
  let first_key = slice key 0 16 in
  let middle_keys = sub key 16 ((num_rounds v-1) * 16) in
  let last_key = sub key (num_rounds v * 16) 16 in
  aes_enc_last last_key (aes_enc_rounds v middle_keys (addRoundKey first_key input))

(* Apply the num_rounds v - 1 middle decryption rounds. Round i uses the
   16-element key at offset 16*i of [key]. *)
let aes_dec_rounds (v:variant) (key:aes_ikey v) (state:block) : Tot block =
  repeati (num_rounds v-1) (fun i -> aes_dec (sub key (16*i) 16)) state

(* Decrypt one block. [key] is expected to be a decryption schedule in the
   layout produced by aes_dec_key_expansion: keys in reverse order, with the
   middle keys passed through InvMixColumns. *)
let aes_decrypt_block (v:variant) (key:aes_xkey v) (input:block) : Tot block =
  let first_key = slice key 0 16 in
  let middle_keys = sub key 16 ((num_rounds v-1) * 16) in
  let last_key = sub key (num_rounds v * 16) 16 in
  aes_dec_last last_key (aes_dec_rounds v middle_keys (addRoundKey first_key input))

(* Keystream block for a 12-byte nonce [n] and 32-bit counter [c]: encrypt
   n || big-endian(c). size_nat presumably keeps c below 2^32, which
   nat_to_bytes_be 4 requires. *)
let aes_ctr_key_block (v:variant) (k:aes_xkey v) (n:lbytes 12) (c:size_nat) : Tot block =
  let ctrby = nat_to_bytes_be 4 c in
  let input = create 16 (u8 0) in
  (* bytes 0..11: nonce; bytes 12..15: counter *)
  let input = repeati #(lbytes 16) 12 (fun i b -> b.[i] <- n.[i]) input in
  let input = repeati #(lbytes 16) 4 (fun i b -> b.[12+i] <- (Seq.index ctrby i)) input in
  aes_encrypt_block v k input

(* CTR-mode state: the expanded encryption key plus the current 16-byte
   counter block. *)
noeq type aes_ctr_state (v:variant) = {
  key_ex: lbytes ((num_rounds v + 1) * 16);
  block:  lbytes 16;
}

(* Advance the counter by [incr]. The whole 16-byte block is treated as a
   big-endian integer and wraps mod 2^128.
   NOTE(review): this differs from GCM's inc32, which wraps only the low
   32 bits. Confirm this is the intended spec. *)
let aes_ctr_add_counter (v:variant) (st:aes_ctr_state v) (incr:size_nat) : Tot (aes_ctr_state v) =
  let ctr = (nat_from_bytes_be st.block + incr) % pow2 128 in
  {st with block = nat_to_bytes_be 16 ctr}

(* Build the initial CTR state. The nonce is copied into the low-index bytes of
   a zero block, the key is expanded, and the counter is advanced by [c0]. *)
let aes_ctr_init (v:variant) (k:aes_key v) (n_len:size_nat{n_len <= 16}) (n:lbytes n_len) (c0:size_nat) : Tot (aes_ctr_state v) =
  let zeros = create 16 (u8 0) in
  let ctr_block = repeati #(lbytes 16) n_len (fun j blk -> blk.[j] <- n.[j]) zeros in
  aes_ctr_add_counter v { key_ex = aes_key_expansion v k; block = ctr_block } c0

(* Keystream block for the current counter: encrypt the counter block. *)
let aes_ctr_current_key_block (v:variant) (st:aes_ctr_state v) : Tot block =
  aes_encrypt_block v st.key_ex st.block

(* Keystream block for counter offset 0 from the nonce block. *)
let aes_ctr_key_block0 (v:variant) (k:aes_key v) (n_len:size_nat{n_len <= 16}) (n:lbytes n_len) : Tot block =
  aes_ctr_current_key_block v (aes_ctr_init v k n_len n 0)

(* Keystream block for counter offset 1 from the nonce block. *)
let aes_ctr_key_block1 (v:variant) (k:aes_key v) (n_len:size_nat{n_len <= 16}) (n:lbytes n_len) : Tot block =
  aes_ctr_current_key_block v (aes_ctr_init v k n_len n 1)

(* Encrypt (or decrypt) one full block. The keystream is computed at counter
   st0 + incr and xored into [b]. [st0] itself is not advanced. *)
let aes_ctr_encrypt_block
  (v:variant)
  (st0:aes_ctr_state v)
  (incr:size_nat)
  (b:block) :
  Tot block =
  let keystream = aes_ctr_current_key_block v (aes_ctr_add_counter v st0 incr) in
  map2 (^.) b keystream

(* Encrypt a trailing partial block of [len] < 16 bytes. The steps are:
   zero-pad to 16 bytes, encrypt as a full block at counter st0 + incr, and
   keep the first [len] bytes. *)
let aes_ctr_encrypt_last
  (v:variant)
  (st0:aes_ctr_state v)
  (incr:size_nat)
  (len:size_nat{len < 16})
  (b:lbytes len):
  Tot (lbytes len) =
  let padded = update_sub (create 16 (u8 0)) 0 (length b) b in
  let encrypted = aes_ctr_encrypt_block v st0 incr padded in
  sub encrypted 0 (length b)


(* CTR-mode encryption of an arbitrary-length message from state [st]. Full
   16-byte blocks go through aes_ctr_encrypt_block and a trailing partial
   block through aes_ctr_encrypt_last. map_blocks presumably passes the block
   index as [incr] — confirm in Lib.Sequence. *)
val aes_ctr_encrypt_stream:
    v:variant
  -> st:aes_ctr_state v
  -> msg:bytes{length msg / 16 <= max_size_t} ->
  Tot (ciphertext:bytes{length ciphertext == length msg})
let aes_ctr_encrypt_stream v st  msg =
  map_blocks 16 msg
    (aes_ctr_encrypt_block v st)
    (aes_ctr_encrypt_last v st)


(* CTR-mode encryption from a raw key, nonce and initial counter [c]. The
   precondition bounds the block count plus the starting counter. *)
val aes_ctr_encrypt_bytes:
    v:variant
  -> key:aes_key v
  -> n_len:size_nat{n_len <= 16}
  -> nonce:lbytes n_len
  -> c:size_nat
  -> msg:bytes{length msg / 16 + c <= max_size_t} ->
  Tot (ciphertext:bytes{length ciphertext == length msg})

let aes_ctr_encrypt_bytes v key n_len nonce ctr0 msg =
  aes_ctr_encrypt_stream v (aes_ctr_init v key n_len nonce ctr0) msg


(* AES-128 instance of aes_ctr_encrypt_bytes. *)
let aes128_ctr_encrypt_bytes key n_len nonce ctr0 msg =
  aes_ctr_encrypt_bytes AES128 key n_len nonce ctr0 msg

(* AES-128 instance of aes_ctr_key_block0 (keystream at counter offset 0). *)
let aes128_ctr_key_block0 key n_len n =
  aes_ctr_key_block0 AES128 key n_len n

(* AES-128 instance of aes_ctr_key_block1 (keystream at counter offset 1). *)
let aes128_ctr_key_block1 key n_len n =
  aes_ctr_key_block1 AES128 key n_len n

(* AES-256 instance of aes_ctr_encrypt_bytes. *)
let aes256_ctr_encrypt_bytes key n_len nonce ctr0 msg =
  aes_ctr_encrypt_bytes AES256 key n_len nonce ctr0 msg
back to top