(* linear_algebra.ml *)
(** Linear algebra module, copied from {{:
    https://gitlab.com/nomadic-labs/privacy-team/-/blob/9e4050cb4a304848901c3434d61a8d7f0c7107c4/nuplompiler/linear_algebra.ml
    } Nomadic Labs privacy team repository } *)
(** Operations required of the scalars stored in matrices: the interface of a
    ring. Implementations are expected to satisfy the ring laws; nothing in
    this file checks them. *)
module type Ring_sig = sig
  (** Type of the ring elements *)
  type t

  (** Ring addition *)
  val add : t -> t -> t

  (** Ring multiplication *)
  val mul : t -> t -> t

  (** Additive inverse: the elimination code relies on
      [add a (negate a)] being [zero] *)
  val negate : t -> t

  (** Additive identity; used to fill zero matrices, to start sums and as the
      off-diagonal entries of identity matrices *)
  val zero : t

  (** Multiplicative identity; used as the diagonal entries of identity
      matrices *)
  val one : t

  (** Equality test on ring elements; used for matrix equality and to detect
      zero entries *)
  val eq : t -> t -> bool
end

(** A ring whose elements can additionally be inverted for multiplication *)
module type Field_sig = sig
  include Ring_sig

  (** [inverse_exn x] is the multiplicative inverse of [x].

      NOTE(review): the [_exn] suffix suggests that it raises when [x] has no
      inverse (presumably [zero]); the exception is not specified here, so
      confirm against the implementation supplied to the functor. *)
  val inverse_exn : t -> t
end

(** Matrix operations that only need ring arithmetic on the entries.

    "Module" is meant in the mathematical sense: the generalization of a
    vector space in which the field of scalars is replaced by a ring. *)
module type Module_sig = sig
  (** Type of the matrix entries (the scalars) *)
  type t

  (** A matrix is an array of rows; [m.(i).(j)] is the entry at row [i],
      column [j]. The representation is exposed, so nothing prevents rows of
      different lengths; the operations below are written for rectangular
      matrices. *)
  type matrix = t array array

  (** [zeros r c] is a matrix with [r] rows and [c] columns filled with zeros *)
  val zeros : int -> int -> matrix

  (** [identity n] is the identity matrix of dimension [n] *)
  val identity : int -> matrix

  (** [equal m1 m2] tests entrywise equality of [m1] and [m2] *)
  val equal : matrix -> matrix -> bool

  (** [add m1 m2] is the entrywise sum of [m1] and [m2], which must have the
      same dimensions *)
  val add : matrix -> matrix -> matrix

  (** [mul m1 m2] is the matrix product of [m1] and [m2]; the number of
      columns of [m1] must equal the number of rows of [m2] *)
  val mul : matrix -> matrix -> matrix

  (** [transpose m] is the transpose of [m]: rows become columns *)
  val transpose : matrix -> matrix

  (** [row_add ~coeff i j m] adds to the i-th row of [m] the j-th row times
      [coeff]. [m] is updated in place. *)
  val row_add : ?coeff:t -> int -> int -> matrix -> unit

  (** [row_swap i j m] swaps the i-th and j-th rows of [m], in place *)
  val row_swap : int -> int -> matrix -> unit

  (** [row_mul coeff i m] multiplies the i-th row of [m] by [coeff], in place *)
  val row_mul : t -> int -> matrix -> unit

  (** [filter_cols f m] is a copy of [m] without the columns whose index does
      not satisfy [f] *)
  val filter_cols : (int -> bool) -> matrix -> matrix

  (** [split_n n m] is [(left, right)] where [left] holds the first [n]
      columns of [m] and [right] holds the remaining ones *)
  val split_n : int -> matrix -> matrix * matrix
end

(** Matrix operations that additionally need division on the entries, i.e.
    matrices over a field *)
module type VectorSpace_sig = sig
  include Module_sig

  (** [reduced_row_echelon_form m] is the reduced row echelon form of [m] *)
  val reduced_row_echelon_form : matrix -> matrix

  (** [inverse m] is the inverse matrix of [m]

      @raise Invalid_argument if [m] is not invertible *)
  val inverse : matrix -> matrix
end

(** [Make_Module (Ring)] implements dense matrices over [Ring], stored as
    arrays of rows. Matrices are expected to be rectangular (all rows of the
    same length); this is not checked. *)
module Make_Module (Ring : Ring_sig) : Module_sig with type t = Ring.t = struct
  type t = Ring.t

  type matrix = t array array

  (* [Array.make_matrix] allocates a distinct array per row, so rows can be
     updated independently of each other. *)
  let zeros r c = Array.make_matrix r c Ring.zero

  let identity n =
    Array.init n (fun i ->
        Array.init n (fun j -> if Int.equal i j then Ring.one else Ring.zero))

  (* [Array.for_all2] raises [Invalid_argument] on arrays of different
     lengths, so the dimensions are compared first: matrices of different
     shapes are simply not equal. *)
  let equal m1 m2 =
    let row_equal r1 r2 =
      Array.length r1 = Array.length r2 && Array.for_all2 Ring.eq r1 r2
    in
    Array.length m1 = Array.length m2 && Array.for_all2 row_equal m1 m2

  (* Raises [Invalid_argument] (from [Array.map2]) when the dimensions of the
     two matrices differ. *)
  let add m1 m2 = Array.map2 (Array.map2 Ring.add) m1 m2

  let mul m1 m2 =
    let nb_rows = Array.length m1 in
    (* Inner dimension: the number of columns of [m1], which must match the
       number of rows of [m2]. A matrix without rows carries no column
       count, so any [m2] height is accepted in that case. *)
    let n = if nb_rows = 0 then Array.length m2 else Array.length m1.(0) in
    assert (Array.length m2 = n) ;
    (* Guard the [m2.(0)] access: when [n = 0], [m2] has no rows. *)
    let nb_cols = if n = 0 then 0 else Array.length m2.(0) in
    Array.init nb_rows (fun i ->
        Array.init nb_cols (fun j ->
            (* dot product of row [i] of [m1] with column [j] of [m2] *)
            let acc = ref Ring.zero in
            for k = 0 to n - 1 do
              acc := Ring.add !acc (Ring.mul m1.(i).(k) m2.(k).(j))
            done ;
            !acc))

  let transpose m =
    let nb_rows = Array.length m in
    (* A matrix without rows has no first row to measure. *)
    let nb_cols = if nb_rows = 0 then 0 else Array.length m.(0) in
    Array.init nb_cols (fun i -> Array.init nb_rows (fun j -> m.(j).(i)))

  (* The three row operations below install a fresh row array in [m] rather
     than writing into the existing one, so a row array shared with another
     matrix is never modified. *)

  let row_add ?(coeff = Ring.one) i j m =
    m.(i) <-
      Array.map2 (fun a b -> Ring.add a (Ring.mul coeff b)) m.(i) m.(j)

  let row_swap i j m =
    let row_i = m.(i) in
    m.(i) <- m.(j) ;
    m.(j) <- row_i

  let row_mul coeff i m = m.(i) <- Array.map (Ring.mul coeff) m.(i)

  let filter_cols f m =
    let keep_selected row =
      Array.to_list row |> List.filteri (fun i _ -> f i) |> Array.of_list
    in
    Array.map keep_selected m

  let split_n n m =
    (filter_cols (fun i -> i < n) m, filter_cols (fun i -> i >= n) m)
end

(** [Make_VectorSpace (Field)] extends the matrix operations of
    [Make_Module (Field)] with Gauss-Jordan elimination and matrix inversion,
    both of which need multiplicative inverses of the entries. *)
module Make_VectorSpace (Field : Field_sig) :
  VectorSpace_sig with type t = Field.t = struct
  include Make_Module (Field)

  (* Gauss-Jordan elimination. Columns are scanned from left to right; for
     each column a pivot is looked for among the rows not yet holding a
     pivot. When one is found it is moved right below the previous pivot
     row, scaled to 1 and used to clear the column in every other row.
     Pivot rows therefore end up consecutively at the top, in staircase
     order, and all-zero rows at the bottom.

     [m] is expected to be rectangular. The input is not modified. *)
  let reduced_row_echelon_form m =
    (* Work on a private copy: the row operations update their argument in
       place and the caller's matrix must be left untouched. *)
    let m = Array.map Array.copy m in
    let nb_rows = Array.length m in
    let nb_cols = if nb_rows = 0 then 0 else Array.length m.(0) in
    let is_zero x = Field.(eq zero x) in
    (* index of the first row at or below [row] whose entry in column [col]
       is non-zero, if any *)
    let rec find_pivot_row col row =
      if row >= nb_rows then None
      else if is_zero m.(row).(col) then find_pivot_row col (row + 1)
      else Some row
    in
    (* [row] is the position the next pivot row will take, [col] the column
       currently examined. *)
    let rec aux row col =
      if row >= nb_rows || col >= nb_cols then m
      else
        match find_pivot_row col row with
        | None -> aux row (col + 1)
        | Some pivot_row ->
            row_swap row pivot_row m ;
            row_mul (Field.inverse_exn m.(row).(col)) row m ;
            (* The pivot is now 1, so subtracting [m.(i).(col)] times the
               pivot row cancels column [col] in row [i]. *)
            Array.iteri
              (fun i _ ->
                if i <> row then
                  row_add ~coeff:(Field.negate m.(i).(col)) i row m)
              m ;
            aux (row + 1) (col + 1)
    in
    aux 0 0

  (* Reduces the augmented matrix [m | I]. When [m] is invertible the result
     is [I | inverse of m]; otherwise the left half contains a zero row. *)
  let inverse m =
    let n = Array.length m in
    let not_invertible () =
      raise @@ Invalid_argument "matrix [m] is not invertible"
    in
    (* Only square matrices can be inverted. *)
    if Array.exists (fun row -> Array.length row <> n) m then
      not_invertible ()
    else
      let augmented = Array.map2 Array.append m (identity n) in
      let residue, inv = split_n n (reduced_row_echelon_form augmented) in
      let is_zero_row = Array.for_all Field.(eq zero) in
      if Array.exists is_zero_row residue then not_invertible () else inv
end
(* end of linear_algebra.ml *)