Revision e9e27cce8a171ce033cd42abad598c27b07df555 authored by Valentin Chaboche on 13 September 2023, 10:43:35 UTC, committed by Marge Bot on 22 September 2023, 07:41:44 UTC
1 parent 262d346
Raw File
mod_arith_gates.ml
(*****************************************************************************)
(*                                                                           *)
(* MIT License                                                               *)
(* Copyright (c) 2023 Nomadic Labs <contact@nomadic-labs.com>                *)
(*                                                                           *)
(* Permission is hereby granted, free of charge, to any person obtaining a   *)
(* copy of this software and associated documentation files (the "Software"),*)
(* to deal in the Software without restriction, including without limitation *)
(* the rights to use, copy, modify, merge, publish, distribute, sublicense,  *)
(* and/or sell copies of the Software, and to permit persons to whom the     *)
(* Software is furnished to do so, subject to the following conditions:      *)
(*                                                                           *)
(* The above copyright notice and this permission notice shall be included   *)
(* in all copies or substantial portions of the Software.                    *)
(*                                                                           *)
(* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR*)
(* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,  *)
(* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL   *)
(* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER*)
(* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING   *)
(* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER       *)
(* DEALINGS IN THE SOFTWARE.                                                 *)
(*                                                                           *)
(*****************************************************************************)

open Bls
open Identities
module L = Plompiler.LibCircuit
open Gates_common

(* Modular addition over a non-native modulus:
   Non Arith
   degree : 2
   nb identities : 1 + |MOD_ARITH.moduli_add|
   advice selectors : None
   equations : see lib_plompiler/gadget_mod_arith.ml *)
module Make_ModAdd (MOD_ARITH : Plompiler__Gadget_mod_arith.MOD_ARITH) :
  Base_sig = struct
  module M = MOD_ARITH (L)

  let q_label = "q_mod_add_" ^ M.label

  let ( %! ) = Z.rem

  (* In the first row:
     M.nb_limbs correspond to the first input.
     M.nb_limbs correspond to the second input.

     In the second row:
     M.nb_limbs correspond to the output.
     1 corresponds to qm variable (quotient by the main modulus).
     M.moduli_add correspond to tj (quotionts by the auxiliary moduli). *)
  let nb_used_wires =
    let used_fst_row = 2 * M.nb_limbs in
    let used_snd_row = M.nb_limbs + 1 + List.length M.moduli_add in
    let nb_used_wires = Int.max used_fst_row used_snd_row in
    assert (nb_used_wires <= Plompiler.Csir.nb_wires_arch) ;
    nb_used_wires

  (* powers of the base modulo the modulus *)
  let bs_mod_m =
    List.init M.nb_limbs (fun i -> Z.pow M.base i %! M.modulus) |> List.rev

  let (qm_shift, _), ts_bounds = M.bounds_add

  (* There are as many identities as moduli + 1, as we also have an identity
     on the native modulus *)
  let identity = (q_label, 1 + List.length M.moduli_add)

  let index_com = None

  let nb_advs = 0

  let nb_buffers = 3

  let gx_composition = true

  let polynomials_degree =
    (q_label, 2) :: List.init nb_used_wires (fun i -> (wire_name i, 2))
    |> SMap.of_list

  let get_values wires wires_g =
    let xs = List.init M.nb_limbs (fun i -> wires.(i)) in
    let ys = List.init M.nb_limbs (fun i -> wires.(M.nb_limbs + i)) in
    let zs = List.init M.nb_limbs (fun i -> wires_g.(i)) in
    let qm = wires_g.(M.nb_limbs) in
    let ts = List.mapi (fun i _ -> wires_g.(M.nb_limbs + 1 + i)) M.moduli_add in
    let t_infos =
      List.map2 (fun tj (t_shift, _) -> Some (tj, t_shift)) ts ts_bounds
    in
    (xs, ys, zs, qm, t_infos)

  let equations ~q:q_mod_add ~wires ~wires_g ?precomputed_advice:_ () =
    (* z = (x + y) mod m
       let k := nb_limbs and n := |moduli|

       PlonK wires distribution:
        row i   : x0 ... x_{k-1} y0 ... y_{k-1}
        row i+1 : z0 ... z_{k-1} qm t1 ... t_{n}
    *)
    let xs, ys, zs, qm, t_infos = get_values wires wires_g in
    let sum = List.fold_left Scalar.add Scalar.zero in
    List.map2
      (fun mj t_info ->
        (* \sum_i ((B^i mod m) mod mj) * (x_i + y_i - z_i)
           - qm * (m mod mj) - ((qm_shift * m) mod mj) = (tj + tj_shift) * mj *)
        let tj, tj_shift =
          match t_info with
          | Some (tj, tj_shift) -> (tj, tj_shift)
          | None -> (Scalar.zero, Z.zero)
        in
        let id_mj =
          let open Scalar in
          sum
            (List.map2
               (fun bi_mod_m ((xi, yi), zi) ->
                 of_z (bi_mod_m %! mj) * (xi + yi + negate zi))
               bs_mod_m
               (List.combine (List.combine xs ys) zs))
          + negate (qm * of_z (M.modulus %! mj))
          + negate (of_z Z.(qm_shift * M.modulus %! mj))
          + negate ((tj + of_z tj_shift) * of_z mj)
        in
        Scalar.(q_mod_add * id_mj))
      (Scalar.order :: M.moduli_add)
      (None :: t_infos)

  let prover_identities ~prefix_common ~prefix ~public:_ ~domain :
      prover_identities =
   fun evaluations ->
    let domain_size = Domain.length domain in
    let tmps, ids = get_buffers ~nb_buffers ~nb_ids:(snd identity) in
    let ({q; wires} : witness) =
      get_evaluations ~q_label ~prefix ~prefix_common evaluations
    in
    let q_mod_add = q in
    (* Note that in the prover we do not have wires_g, so we will need to
       compose qm & ts with gX *)
    let xs, ys, _zs, qm, t_infos = get_values wires wires in
    List.mapi
      (fun i (mj, t_info) ->
        (* id_mj :=
           \sum_i ((B^i mod m) mod mj) * (x_i + y_i - z_i)
           - qm * (m mod mj) - ((qm_shift * m) mod mj) - (tj + tj_shift) * mj *)
        let id_mj_without_sum =
          (* In the case of the native modulus, we can ignore the
             (tj + tj_shift) component *)
          let tj, tj_coeff, tj_shift, tj_comp =
            match t_info with
            | Some (tj, tj_shift) ->
                ([tj], Scalar.[negate (of_z mj)], tj_shift, [1])
            | None -> ([], [], Z.zero, [])
          in
          Evaluations.linear_c
            ~res:tmps.(0)
            ~evaluations:(qm :: tj)
            ~composition_gx:(1 :: tj_comp, domain_size)
            ~linear_coeffs:(Scalar.(negate (of_z (M.modulus %! mj))) :: tj_coeff)
            ~add_constant:
              Scalar.(
                negate (of_z Z.((qm_shift * M.modulus %! mj) + (tj_shift * mj))))
            ()
        in
        let id_mj =
          List.fold_left2
            (fun acc bi_mod_m (xi, yi) ->
              (* zi is just xi composed with gX *)
              let zi = xi in
              let xi_plus_yi_minus_zi =
                Evaluations.linear_c
                  ~res:tmps.(2)
                  ~evaluations:[xi; yi; zi]
                  ~linear_coeffs:[one; one; mone]
                  ~composition_gx:([0; 0; 1], domain_size)
                  ()
              in
              let acc =
                Evaluations.linear_c
                  ~res:tmps.(1)
                  ~evaluations:[acc; xi_plus_yi_minus_zi]
                  ~linear_coeffs:[one; Scalar.of_z @@ (bi_mod_m %! mj)]
                  ()
              in
              Evaluations.copy ~res:tmps.(0) acc)
            id_mj_without_sum
            bs_mod_m
            (List.combine xs ys)
        in
        let identity =
          Evaluations.mul_c ~res:ids.(i) ~evaluations:[q_mod_add; id_mj] ()
        in
        (prefix @@ q_label ^ "." ^ string_of_int i, identity))
      ((Scalar.order, None) :: List.combine M.moduli_add t_infos)
    |> SMap.of_list

  let verifier_identities ~prefix_common ~prefix ~public:_ ~generator:_
      ~size_domain:_ : verifier_identities =
   fun _ answers ->
    let {q; wires; wires_g} =
      get_answers ~gx:true ~q_label ~prefix ~prefix_common answers
    in
    List.mapi
      (fun i id -> (prefix @@ q_label ^ "." ^ string_of_int i, id))
      (equations ~q ~wires ~wires_g ())
    |> SMap.of_list

  let cs ~q:q_mod_add ~wires ~wires_g ?precomputed_advice:_ () =
    (* z = (x + y) mod m
       let k := nb_limbs and n := |moduli|

       PlonK wires distribution:
        row i   : x0 ... x_{k-1} y0 ... y_{k-1}
        row i+1 : z0 ... z_{k-1} qm t1 ... t_{n}
    *)
    let open L in
    let xs, ys, zs, qm, t_infos = get_values wires wires_g in
    let* zero = Num.zero in
    map2M
      (fun mj t_info ->
        (* \sum_i ((B^i mod m) mod mj) * (x_i + y_i - z_i)
           - qm * (m mod mj) - ((qm_shift * m) mod mj) = (tj + tj_shift) * mj *)
        let tj, tj_shift =
          match t_info with
          | Some (tj, tj_shift) -> (tj, tj_shift)
          | None -> (zero, Z.zero)
        in
        let* id_mj =
          let sum_terms =
            List.map2
              (fun bi_mod_m ((xi, yi), zi) ->
                let c = Scalar.of_z (bi_mod_m %! mj) in
                [(c, xi); (c, yi); (Scalar.negate c, zi)])
              bs_mod_m
              (List.combine (List.combine xs ys) zs)
            |> List.concat
          in
          let qc =
            Scalar.of_z Z.(-(qm_shift * M.modulus %! mj) - (tj_shift * mj))
          in
          let coeffs, vars =
            List.split
            @@ [
                 Scalar.(negate @@ of_z (M.modulus %! mj), qm);
                 Scalar.(negate @@ of_z mj, tj);
               ]
            @ sum_terms
          in
          Num.add_list ~qc ~coeffs (to_list vars)
        in
        Num.mul q_mod_add id_mj)
      (Scalar.order :: M.moduli_add)
      (None :: t_infos)
end

(* Modular multiplication over a non-native modulus:
   Non Arith
   degree : 3
   nb identities : 1 + |MOD_ARITH.moduli_mul|
   advice selectors : None
   equations : see lib_plompiler/gadget_mod_arith.ml *)
module Make_ModMul (MOD_ARITH : Plompiler__Gadget_mod_arith.MOD_ARITH) :
  Base_sig = struct
  module M = MOD_ARITH (L)

  let q_label = "q_mod_mul_" ^ M.label

  let ( %! ) = Z.rem

  (* In the first row:
     M.nb_limbs correspond to the first input.
     M.nb_limbs correspond to the second input.

     In the second row:
     M.nb_limbs correspond to the output.
     1 corresponds to qm variable (quotient by the main modulus).
     M.moduli_mul correspond to tj (quotionts by the auxiliary moduli). *)
  let nb_used_wires =
    let used_fst_row = 2 * M.nb_limbs in
    let used_snd_row = M.nb_limbs + 1 + List.length M.moduli_mul in
    let nb_used_wires = Int.max used_fst_row used_snd_row in
    assert (nb_used_wires <= Plompiler.Csir.nb_wires_arch) ;
    nb_used_wires

  (* powers of the base modulo the modulus *)
  let bs_mod_m =
    List.init M.nb_limbs (fun i -> Z.pow M.base i %! M.modulus) |> List.rev

  let bij_mod_m =
    List.init M.nb_limbs (fun i ->
        List.init M.nb_limbs (fun j -> Z.pow M.base (i + j) %! M.modulus))
    |> List.concat |> List.rev

  let (qm_shift, _), ts_bounds = M.bounds_mul

  (* There are as many identities as moduli + 1, as we also have an identity
     on the native modulus *)
  let identity = (q_label, 1 + List.length M.moduli_mul)

  let index_com = None

  let nb_advs = 0

  let nb_buffers = 3

  let gx_composition = true

  let polynomials_degree =
    (q_label, 3) :: List.init nb_used_wires (fun i -> (wire_name i, 3))
    |> SMap.of_list

  let get_values wires wires_g =
    let xs = List.init M.nb_limbs (fun i -> wires.(i)) in
    let ys = List.init M.nb_limbs (fun i -> wires.(M.nb_limbs + i)) in
    let zs = List.init M.nb_limbs (fun i -> wires_g.(i)) in
    let qm = wires_g.(M.nb_limbs) in
    let ts = List.mapi (fun i _ -> wires_g.(M.nb_limbs + 1 + i)) M.moduli_mul in
    let t_infos =
      List.map2 (fun tj (t_shift, _) -> Some (tj, t_shift)) ts ts_bounds
    in
    (xs, ys, zs, qm, t_infos)

  let equations ~q:q_mod_mul ~wires ~wires_g ?precomputed_advice:_ () =
    (* z = (x * y) mod m
       let k := nb_limbs and n := |moduli|

       PlonK wires distribution:
        row i   : x0 ... x_{k-1} y0 ... y_{k-1}
        row i+1 : z0 ... z_{k-1} qm t1 ... t_{n}
    *)
    let xs, ys, zs, qm, t_infos = get_values wires wires_g in
    let sum = List.fold_left Scalar.add Scalar.zero in
    let x_times_y =
      List.concat_map (fun xi -> List.map (fun yj -> Scalar.(xi * yj)) ys) xs
    in
    List.map2
      (fun mj t_info ->
        (* \sum_i (\sum_j (B^{i+j} mod m) * x_i * y_j) - (\sum_i (B^i mod m) * z_i)
           - qm * (m mod mj) - ((qm_shift * m) mod mj) = (tj + tj_shift) * mj *)
        let tj, tj_shift =
          match t_info with
          | Some (tj, tj_shift) -> (tj, tj_shift)
          | None -> (Scalar.zero, Z.zero)
        in
        let id_mj =
          let open Scalar in
          let mod_mj v = v %! mj |> of_z in
          let bs_mod_m_mod_mj = List.map mod_mj bs_mod_m in
          let bij_mod_m_mod_mj = List.map mod_mj bij_mod_m in
          let sum_xy = sum @@ List.map2 Scalar.mul bij_mod_m_mod_mj x_times_y in
          let sum_z = sum @@ List.map2 Scalar.mul bs_mod_m_mod_mj zs in
          sum_xy + negate sum_z
          + negate (qm * mod_mj M.modulus)
          + negate (mod_mj Z.(qm_shift * M.modulus))
          + negate ((tj + of_z tj_shift) * of_z mj)
        in
        Scalar.(q_mod_mul * id_mj))
      (Scalar.order :: M.moduli_mul)
      (None :: t_infos)

  let prover_identities ~prefix_common ~prefix ~public:_ ~domain :
      prover_identities =
   fun evaluations ->
    let domain_size = Domain.length domain in
    let tmps, ids = get_buffers ~nb_buffers ~nb_ids:(snd identity) in
    let ({q; wires} : witness) =
      get_evaluations ~q_label ~prefix ~prefix_common evaluations
    in
    let q_mod_mul = q in
    (* Note that in the prover we do not have wires_g, so we will need to
       compose qm & ts with gX *)
    let xs, ys, _zs, qm, t_infos = get_values wires wires in
    (* zi is just xi composed with gX *)
    let zs = xs in
    let xy_pairs =
      List.concat_map (fun xi -> List.map (fun yj -> (xi, yj)) ys) xs
    in
    List.mapi
      (fun i (mj, t_info) ->
        (* id_mj :=
           \sum_i (\sum_j (B^{i+j} mod m) * x_i * y_j) - (\sum_i (B^i mod m) * z_i)
           - qm * (m mod mj) - ((qm_shift * m) mod mj) - (tj + tj_shift) * mj *)
        let id_mj_without_sums =
          (* In the case of the native modulus, we can ignore the
             (tj + tj_shift) component *)
          let tj, tj_coeff, tj_shift, tj_comp =
            match t_info with
            | Some (tj, tj_shift) ->
                ([tj], Scalar.[negate (of_z mj)], tj_shift, [1])
            | None -> ([], [], Z.zero, [])
          in
          Evaluations.linear_c
            ~res:tmps.(0)
            ~evaluations:(qm :: tj)
            ~composition_gx:(1 :: tj_comp, domain_size)
            ~linear_coeffs:(Scalar.(negate (of_z (M.modulus %! mj))) :: tj_coeff)
            ~add_constant:
              Scalar.(
                negate (of_z Z.((qm_shift * M.modulus %! mj) + (tj_shift * mj))))
            ()
        in
        let id_mj_without_sum_z =
          List.fold_left2
            (fun acc bij_mod_m (xi, yj) ->
              let xiyj =
                Evaluations.mul_c ~res:tmps.(2) ~evaluations:[xi; yj] ()
              in
              let acc =
                Evaluations.linear_c
                  ~res:tmps.(1)
                  ~evaluations:[acc; xiyj]
                  ~linear_coeffs:[one; Scalar.of_z (bij_mod_m %! mj)]
                  ()
              in
              Evaluations.copy ~res:tmps.(0) acc)
            id_mj_without_sums
            bij_mod_m
            xy_pairs
        in
        let id_mj =
          List.fold_left2
            (fun acc bi_mod_m zi ->
              let acc =
                Evaluations.linear_c
                  ~res:tmps.(1)
                  ~evaluations:[acc; zi]
                  ~linear_coeffs:[one; Scalar.(negate @@ of_z (bi_mod_m %! mj))]
                  ~composition_gx:([0; 1], domain_size)
                  ()
              in
              Evaluations.copy ~res:tmps.(0) acc)
            id_mj_without_sum_z
            bs_mod_m
            zs
        in
        let identity =
          Evaluations.mul_c ~res:ids.(i) ~evaluations:[q_mod_mul; id_mj] ()
        in
        (prefix @@ q_label ^ "." ^ string_of_int i, identity))
      ((Scalar.order, None) :: List.combine M.moduli_mul t_infos)
    |> SMap.of_list

  let verifier_identities ~prefix_common ~prefix ~public:_ ~generator:_
      ~size_domain:_ : verifier_identities =
   fun _ answers ->
    let {q; wires; wires_g} =
      get_answers ~gx:true ~q_label ~prefix ~prefix_common answers
    in
    List.mapi
      (fun i id -> (prefix @@ q_label ^ "." ^ string_of_int i, id))
      (equations ~q ~wires ~wires_g ())
    |> SMap.of_list

  let cs ~q:q_mod_mul ~wires ~wires_g ?precomputed_advice:_ () =
    (* z = (x * y) mod m
       let k := nb_limbs and n := |moduli|

       PlonK wires distribution:
        row i   : x0 ... x_{k-1} y0 ... y_{k-1}
        row i+1 : z0 ... z_{k-1} qm t1 ... t_{n}
    *)
    let open L in
    let xs, ys, zs, qm, t_infos = get_values wires wires_g in
    let xy_pairs =
      List.concat_map (fun xi -> List.map (fun yj -> (xi, yj)) ys) xs
    in
    let* xys = mapM (fun (xi, yj) -> Num.mul xi yj) xy_pairs in
    let* zero = Num.constant Scalar.zero in
    map2M
      (fun mj t_info ->
        (* \sum_i (\sum_j (B^{i+j} mod m) * x_i * y_j) - (\sum_i (B^i mod m) * z_i)
           - qm * (m mod mj) - ((qm_shift * m) mod mj) = (tj + tj_shift) * mj *)
        let tj, tj_shift =
          match t_info with
          | Some (tj, tj_shift) -> (tj, tj_shift)
          | None -> (zero, Z.zero)
        in
        let* id_mj =
          let xy_terms =
            List.map2
              (fun bij_mod_m xiyj -> (Scalar.of_z (bij_mod_m %! mj), xiyj))
              bij_mod_m
              xys
          in
          let z_terms =
            List.map2
              (fun bi_mod_m zi ->
                (Scalar.(negate @@ of_z (bi_mod_m %! mj)), zi))
              bs_mod_m
              zs
          in
          let sum_terms = xy_terms @ z_terms in
          let qc =
            Scalar.of_z Z.(-(qm_shift * M.modulus %! mj) - (tj_shift * mj))
          in
          let coeffs, vars =
            List.split
            @@ [
                 Scalar.(negate @@ of_z (M.modulus %! mj), qm);
                 Scalar.(negate @@ of_z mj, tj);
               ]
            @ sum_terms
          in
          Num.add_list ~qc ~coeffs (to_list vars)
        in
        Num.mul q_mod_mul id_mj)
      (Scalar.order :: M.moduli_mul)
      (None :: t_infos)
end

module AddMod25519 = Make_ModAdd (Plompiler.ArithMod25519)
module MulMod25519 = Make_ModMul (Plompiler.ArithMod25519)
module AddMod64 = Make_ModAdd (Plompiler.ArithMod64)
module MulMod64 = Make_ModMul (Plompiler.ArithMod64)
back to top