Revision d0a99151704ed9575dbe9d8422ed25f86972bbc3 authored by Marge Bot on 19 January 2024, 11:29:07 UTC, committed by Marge Bot on 19 January 2024, 11:29:07 UTC
Co-authored-by: Eugen Zalinescu <eugen.zalinescu@nomadic-labs.com>

Approved-by: Raphaël Cauderlier <raphael.cauderlier@nomadic-labs.com>
Approved-by: Mohamed IGUERNLALA <iguer@functori.com>

See merge request https://gitlab.com/tezos/tezos/-/merge_requests/11544
2 parent s ca356f6 + 9073022
Raw File
fixed_point_transform.ml
(*****************************************************************************)
(*                                                                           *)
(* Open Source License                                                       *)
(* Copyright (c) 2020 Nomadic Labs, <contact@nomadic-labs.com>               *)
(* Copyright (c) 2023 DaiLambda, Inc., <contact@dailambda.jp>                *)
(*                                                                           *)
(* Permission is hereby granted, free of charge, to any person obtaining a   *)
(* copy of this software and associated documentation files (the "Software"),*)
(* to deal in the Software without restriction, including without limitation *)
(* the rights to use, copy, modify, merge, publish, distribute, sublicense,  *)
(* and/or sell copies of the Software, and to permit persons to whom the     *)
(* Software is furnished to do so, subject to the following conditions:      *)
(*                                                                           *)
(* The above copyright notice and this permission notice shall be included   *)
(* in all copies or substantial portions of the Software.                    *)
(*                                                                           *)
(* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR*)
(* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,  *)
(* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL   *)
(* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER*)
(* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING   *)
(* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER       *)
(* DEALINGS IN THE SOFTWARE.                                                 *)
(*                                                                           *)
(*****************************************************************************)

(* Transform multiplications by constants in a costlang expression into
   fixed-point arithmetic. This allows cost functions to be made
   protocol-compatible. *)

(* Modes of casting of float to int.
   - [Ceil]: the float is passed through [Float.ceil] before conversion.
   - [Floor]: the float is passed through [Float.floor] before conversion.
   - [Round]: the float is passed through [Float.round] before conversion.
   The same modes also select how the truncated mantissa is rounded in
   [Fixed_point_arithmetic.approx_mult]. *)
type cast_mode = Ceil | Floor | Round

(* Parameters for conversion to fixed point *)
type options = {
  (* Number of most significant bits of a float constant that are kept when
     a multiplication by that constant is turned into shifts and additions
     (see [approx_mult], which asserts that it is strictly positive). *)
  precision : int;
  (* Relative error (as a fraction, not a percentage) above which
     [cast_to_int] prints a warning on stderr. The cast is still performed. *)
  max_relative_error : float;
  (* Rounding mode used by [cast_to_int] when a float constant is converted
     to an integer. *)
  cast_mode : cast_mode;
  (* Divisor applied to [log10 x] in [snap_to_grid] to obtain the power of
     ten of the grid step. *)
  inverse_scaling : int;
  (* Multiplier of the power of ten in [snap_to_grid]: the grid step is
     [resolution * 10 ^ (log10 x / inverse_scaling)]. *)
  resolution : int;
}

(* Handling bad floating point values.
   - [Bad_fpclass c]: the float belongs to a rejected class [c]
     ([assert_fp_is_correct] rejects subnormals, infinities and nans).
   - [Negative_or_zero_fp]: despite its name, [assert_fp_is_correct] only
     raises it for normal floats that are [<= 0.0], i.e. negative ones;
     zeros are classified [FP_zero] and accepted. *)
type fp_error = Bad_fpclass of Float.fpclass | Negative_or_zero_fp

(* Handling codegen errors.
   [Term_is_not_closed v]: the free variable [v] was encountered while
   transforming a term that is expected to be closed. *)
type fixed_point_transform_error = Term_is_not_closed of Free_variable.t

(* Raised when a float constant cannot be handled by the transformation. *)
exception Bad_floating_point_number of fp_error

(* Raised when the term being transformed is ill-formed for this pass. *)
exception Fixed_point_transform_error of fixed_point_transform_error

(* ------------------------------------------------------------------------- *)

(* Default conversion parameters: 6 bits of precision, a warning threshold of
   0.1 (i.e. 10%) on the relative error of integer casts, rounding to the
   nearest integer, and a grid driven by [inverse_scaling = 10] and
   [resolution = 5]. *)
let default_options =
  let precision = 6
  and max_relative_error = 0.1
  and cast_mode = Round
  and inverse_scaling = 10
  and resolution = 5 in
  {precision; max_relative_error; cast_mode; inverse_scaling; resolution}

(* ------------------------------------------------------------------------- *)
(* Printers, encodings, etc. *)

(* Pretty-printer for [fixed_point_transform_error]: reports the offending
   free variable using [Free_variable.pp]. *)
let pp_fixed_point_transform_error fmtr (err : fixed_point_transform_error) =
  (* The error type has a single constructor, so it can be destructured
     directly. *)
  let (Term_is_not_closed free_var) = err in
  Format.fprintf
    fmtr
    "Fixed_point_transform: Term is not closed (free variable %a \
     encountered)"
    Free_variable.pp
    free_var

(* Encoding of [cast_mode] as a tagged union of constants: tags 0, 1 and 2
   stand for [Ceil], [Floor] and [Round] respectively. *)
let cast_mode_encoding =
  let open Data_encoding in
  (* One payload-free case per constructor: [name] is used both as the case
     title and as the encoded constant. *)
  let nullary_case tag name mode =
    case
      ~title:name
      (Tag tag)
      (constant name)
      (fun candidate -> if candidate = mode then Some () else None)
      (fun () -> mode)
  in
  union
    [
      nullary_case 0 "Ceil" Ceil;
      nullary_case 1 "Floor" Floor;
      nullary_case 2 "Round" Round;
    ]

(* Encoding of [options] as a 5-field object. Every field is optional and
   falls back to the corresponding value of [default_options]. *)
let options_encoding =
  let open Data_encoding in
  (* Record -> tuple projection used when encoding. *)
  let to_tuple
      {precision; max_relative_error; cast_mode; inverse_scaling; resolution} =
    (precision, max_relative_error, cast_mode, inverse_scaling, resolution)
  in
  (* Tuple -> record injection used when decoding. *)
  let of_tuple
      (precision, max_relative_error, cast_mode, inverse_scaling, resolution) =
    {precision; max_relative_error; cast_mode; inverse_scaling; resolution}
  in
  let fields =
    obj5
      (dft "precision" int31 default_options.precision)
      (dft "max_relative_error" float default_options.max_relative_error)
      (dft "cast_mode" cast_mode_encoding default_options.cast_mode)
      (dft "inverse_scaling" int31 default_options.inverse_scaling)
      (dft "resolution" int31 default_options.resolution)
  in
  conv to_tuple of_tuple fields

(* ------------------------------------------------------------------------- *)
(* Error registration *)

(* Registers a printer for the two exceptions declared in this file.

   The printer is total: every [Float.fpclass] is given a name, so that a
   [Bad_fpclass] carrying an unexpected class (e.g. raised by other code) is
   still reported with the "Fixed_point_transform" prefix instead of making
   the printer itself fail. Exceptions not declared here yield [None] so that
   other registered printers get a chance to handle them. *)
let () =
  let string_of_fpclass : Float.fpclass -> string = function
    | FP_normal -> "FP_normal"
    | FP_subnormal -> "FP_subnormal"
    | FP_zero -> "FP_zero"
    | FP_infinite -> "FP_infinite"
    | FP_nan -> "FP_nan"
  in
  Printexc.register_printer (function
      | Bad_floating_point_number error ->
          let s =
            match error with
            | Bad_fpclass fpcl -> string_of_fpclass fpcl
            | Negative_or_zero_fp -> "<= 0"
          in
          Some
            (Printf.sprintf
               "Fixed_point_transform: Bad floating point number: %s"
               s)
      | Fixed_point_transform_error err ->
          Some (Format.asprintf "%a" pp_fixed_point_transform_error err)
      | _ -> None)

(* ------------------------------------------------------------------------- *)
(* Constant prettification *)

(* Integer order-of-magnitude estimate used to size the snapping grid.

   Returns 1 for any [x] in [1; 10], and one more for each integer division
   by 10 needed to bring [x] into that range. Because of the truncating
   divisions this is neither the exact logarithm nor the digit count
   (e.g. [log10 100 = 2] but [log10 110 = 3]).

   Raises [Invalid_argument "log10"] if [x <= 0]. *)
let log10 x =
  if x <= 0 then invalid_arg "log10"
  else
    let rec count steps remaining =
      if remaining <= 10 then steps else count (steps + 1) (remaining / 10)
    in
    count 1 x

(* [pow x n] is [x] raised to the non-negative integer power [n], computed
   with native (wrapping) integer multiplication. [pow x 0 = 1] for every
   [x].

   Raises [Invalid_argument "pow"] if [n < 0]. *)
let pow x n =
  if n < 0 then invalid_arg "pow"
  else
    let rec multiply product remaining =
      if remaining = 0 then product else multiply (product * x) (remaining - 1)
    in
    multiply 1 n

(* [snap_to_grid ~inverse_scaling ~resolution x] rounds [x] up to the next
   multiple of a grid step that depends on the magnitude of [x]:
   [resolution * 10 ^ (log10 x / inverse_scaling)].

   [0] is returned unchanged. A negative [x] is rejected by [log10], which
   raises [Invalid_argument]. *)
let snap_to_grid ~inverse_scaling ~resolution x =
  match x with
  | 0 -> 0
  | _ ->
      let step = resolution * pow 10 (log10 x / inverse_scaling) in
      (* Ceiling division, then back to a multiple of [step]. *)
      let cells = (x + step - 1) / step in
      cells * step

(* ------------------------------------------------------------------------- *)
(* Helpers *)

(* [int_of_float mode x] rounds [x] according to [mode] and converts the
   result to an [int]. This definition shadows the standard [int_of_float],
   which is the one called in its body. *)
let int_of_float mode x =
  let rounded =
    match mode with
    | Ceil -> Float.ceil x
    | Floor -> Float.floor x
    | Round -> Float.round x
  in
  int_of_float rounded

(* Checks that a floating point number is 'good' and returns its class.

   Subnormal, infinite and nan values are rejected with [Bad_fpclass], and
   normal values that are [<= 0.0] with [Negative_or_zero_fp]. The classes
   that can be returned are therefore [FP_normal] (for a positive number)
   and [FP_zero]. *)
let assert_fp_is_correct (x : float) =
  match Float.classify_float x with
  | (FP_subnormal | FP_infinite | FP_nan) as rejected ->
      raise (Bad_floating_point_number (Bad_fpclass rejected))
  | FP_zero -> FP_zero
  | FP_normal ->
      if x <= 0.0 then raise (Bad_floating_point_number Negative_or_zero_fp)
      else FP_normal

(* [cast_to_int max_relative_error mode f] converts [f] to an integer using
   the rounding [mode]. When the relative error introduced by the conversion
   exceeds [max_relative_error], a warning is printed on stderr; the integer
   is returned in all cases. *)
let cast_to_int max_relative_error mode f : int =
  let i = int_of_float mode f in
  let relative_error = Float.abs (f -. float_of_int i) /. Float.abs f in
  let () =
    if relative_error > max_relative_error then
      Format.eprintf
        "Warning: Fixed_point_transform: Imprecise integer cast of %f to %d: %f \
         %% relative error@."
        f
        i
        (relative_error *. 100.)
  in
  i

(* ------------------------------------------------------------------------- *)

(* A minimal language in which to perform fixed-point multiplication of
   a 'size repr' by a float.

   [Fixed_point_arithmetic] is parameterised by an implementation of this
   signature and only builds expressions out of the operations below. *)
module type Fixed_point_lang_sig = sig
  (* Representation of an expression of the target language. *)
  type 'a repr

  (* Type of the quantities being multiplied. *)
  type size

  (* [shift_right e n] is used by [approx_mult] to represent the contribution
     of a fractional bit of weight 2^-n, i.e. presumably [e] divided by 2^n
     -- the exact semantics is that of the implementation. *)
  val shift_right : size repr -> int -> size repr

  (* Addition of two expressions. *)
  val ( + ) : size repr -> size repr -> size repr

  (* Multiplication of two expressions. *)
  val ( * ) : size repr -> size repr -> size repr

  (* Integer literal. *)
  val int : int -> size repr
end

module Fixed_point_arithmetic (Lang : Fixed_point_lang_sig) : sig
  (** [approx_mult mode precision i f] generates an integer-only expression
      approximating the multiplication [i * f] by the non-negative constant
      [f]. Only the [precision] most significant bits of [f] (including the
      implicit leading one) are kept; the discarded bits are taken into
      account according to [mode]. The result is [i] multiplied by an
      integer constant, plus one right shift of [i] per kept fractional bit.

      [precision] must be strictly positive (checked by an assertion) and at
      most 53, the number of significant bits of a float; a larger value
      makes the function fail with [Failure "take"].

      @raise Bad_floating_point_number if [f] is negative, subnormal,
      infinite or nan.
  *)
  val approx_mult :
    cast_mode -> int -> Lang.size Lang.repr -> float -> Lang.size Lang.repr
end = struct
  (* IEEE754 redux
     -------------
     Format of a (full-precision, ie 64 bit) floating point number:
     . 1 bit of sign
     . 11 bits of exponent, implicitly centered on 0 (-1022 to 1023)
     . 52 bits of mantissa (+ 1 implicit)

     value of a fp number = sign * mantissa * 2^{exponent - 1023}
     (with the exceptions of nans, infinities and denormalised numbers) *)

  (* Extract ith bit from a float. Bit 0 is the least significant bit of the
     64-bit pattern of [x]; the result is [0L] or [1L]. *)
  let bit (x : float) (i : int) =
    assert (not (i < 0 || i > 63)) ;
    let bits = Int64.bits_of_float x in
    Int64.(logand (shift_right bits i) one)

  (* All bits of a float, most significant first:
     all_bits x = [sign] @ exponent @ mantissa *)
  let all_bits (x : float) : int64 list =
    List.init ~when_negative_length:() 64 (fun i -> bit x i)
    |> (* 64 >= 0 *) WithExceptions.Result.get_ok ~loc:__LOC__
    |> List.rev

  (* [take n l] returns the [n] first elements of [l] together with the
     remaining elements. Fails with [Failure "take"] if [l] has fewer than
     [n] elements. *)
  let take n l =
    let rec take n l acc =
      if n <= 0 then (List.rev acc, l)
      else
        match l with
        | [] -> Stdlib.failwith "take"
        | hd :: tl -> take (n - 1) tl (hd :: acc)
    in
    take n l []

  (* Split the 64 bits of a float (most significant first) into
     sign (1 bit) / exponent (11 bits) / mantissa (52 bits). *)
  let split bits =
    let sign, rest = take 1 bits in
    let expo, rest = take 11 rest in
    let mant, _ = take 52 rest in
    (sign, expo, mant)

  (* Convert bits of exponent (most significant first) to the unbiased
     exponent: the list is reversed so that [index] is the weight of each
     bit, and the bias of 1023 is subtracted by the base case. *)
  let exponent_bits_to_int (l : int64 list) =
    let rec exponent_to_int (l : int64 list) (index : int) : int64 =
      match l with
      | [] -> -1023L
      | bit :: tail ->
          let tail = exponent_to_int tail (index + 1) in
          Int64.(add (shift_left bit index) tail)
    in
    exponent_to_int (List.rev l) 0

  (* Decompose a float into sign/exponent/mantissa *)
  let decompose (x : float) = split (all_bits x)

  (* [increment_bits exp bits] adds one to the binary number [bits] (most
     significant bit first). The inner function returns the carry coming out
     of the suffix it processed; the empty suffix yields the initial carry.
     If the carry propagates out of the most significant bit, a leading one
     is prepended (the list becomes one bit longer) and the exponent is
     incremented so that the leading bit keeps weight 2^exp. *)
  let increment_bits exp bits =
    let rec f = function
      | [] -> (true, [])
      | 0L :: rest ->
          let up, rest = f rest in
          (false, (if up then 1L else 0L) :: rest)
      | 1L :: rest ->
          let up, rest = f rest in
          if up then (true, 0L :: rest) else (false, 1L :: rest)
      | _ -> assert false
    in
    let up, bits = f bits in
    if up then (exp + 1, 1L :: bits) else (exp, bits)

  (* Generate fixed-precision multiplication by positive constants.
     A zero constant yields the literal [Lang.int 0]. *)
  let approx_mult mode (precision : int) (i : Lang.size Lang.repr) (x : float) :
      Lang.size Lang.repr =
    assert (precision > 0) ;
    let fpcl = assert_fp_is_correct x in
    match fpcl with
    | FP_zero -> Lang.int 0
    | _ ->
        let _sign, exp, mant = decompose x in
        let exp = Int64.to_int @@ exponent_bits_to_int exp in
        (* The mantissa is always implicitly prefixed by one (except for
           denormalized numbers, excluded here). *)
        let bits = 1L :: mant in
        (* Get the top [precision] bits. [bits] has 53 elements, so [take]
           fails if [precision] exceeds 53. *)
        let bits, rest = take precision bits in
        (* Rounding. [bits_rounded] has [precision+1] bits at most.
           The number of ones in it is at most [precision].
           - [Ceil] increments as soon as one discarded bit is set.
           - [Floor] simply drops the discarded bits.
           - [Round] increments when the first discarded bit is set. *)
        let exp, bits_rounded =
          match mode with
          | Ceil ->
              if List.for_all (fun x -> x = 0L) rest then (exp, bits)
              else increment_bits exp bits
          | Floor -> (exp, bits)
          | Round -> (
              match rest with
              | 1L :: _ -> increment_bits exp bits
              | [] | 0L :: _ -> (exp, bits)
              | _ -> assert false)
        in
        (* convert bits for < 1.0 to sum of powers of 2 computed with shifts.
           The k-th bit (starting at 0) has weight 2^(exp - k): bits of
           non-negative weight are accumulated in the constant [integer],
           the others become a right shift of [i] by [k - exp] in [fracs]
           (accumulated in reverse order).
           NOTE(review): [1 lsl (exp - k)] is not guarded against
           [exp - k] reaching the width of native integers; presumably the
           constants handled here are small enough -- to be confirmed. *)
        let _, integer, fracs =
          List.fold_left
            (fun (k, integer, fracs) bit ->
              let integer, fracs =
                if bit = 1L then
                  if exp - k < 0 then
                    (integer, Lang.shift_right i (k - exp) :: fracs)
                  else (integer + (1 lsl (exp - k)), fracs)
                else (integer, fracs)
              in
              (k + 1, integer, fracs))
            (0, 0, [])
            bits_rounded
        in
        (* Assemble the final expression. Without integer part, the result is
           the sum of the shifts only; the leading bit is always set, hence
           [fracs] cannot be empty in that case. With an integer part, the
           shifts are added to [i * integer] (or just [i] when the integer
           part is 1). *)
        if integer = 0 then
          match List.rev fracs with
          | [] -> assert false
          | f :: fracs -> List.fold_left (fun sum t -> Lang.(sum + t)) f fracs
        else
          (* Note: in this fold the accumulator is the parameter named [t],
             so each shift is placed on the left of the running sum. *)
          List.fold_left
            (fun t sum -> Lang.(sum + t))
            (if integer = 1 then i else Lang.(i * int integer))
            (List.rev fracs)
end

(* [Convert_mult] replaces [float] values by integer-only expressions:

   - Products of the form [float * term] or [term * float] are turned into
     fixed-point multiplications built by [Fixed_point_arithmetic].
   - Products of two floats are folded into a single float constant.
   - Any float reaching another operator, or still present when a term is
     projected with [prj], is cast to an integer and snapped to the grid.

   It is assumed that the term is _closed_, i.e. contains no free variables:
   [free] raises [Fixed_point_transform_error].
*)
module Convert_mult (P : sig
  val options : options
end)
(X : Costlang.S) : sig
  include Costlang.S with type size = X.size

  val prj : 'a repr -> 'a X.repr
end = struct
  type size = X.size

  let size_ty = X.size_ty

  (* A value is either a term of the underlying language or a float constant
     that has not been converted yet. *)
  type 'a repr = Term : 'a X.repr -> 'a repr | Float : float -> X.size repr

  let {precision; max_relative_error; cast_mode; inverse_scaling; resolution} =
    P.options

  module FPA = Fixed_point_arithmetic (X)

  (* Cast to int according to [cast_mode], then snap the result to the grid
     with [snap_to_grid], and embed it as an integer literal. *)
  let cast_and_snap f =
    cast_to_int max_relative_error cast_mode f
    |> snap_to_grid ~inverse_scaling ~resolution
    |> X.int

  (* Projection to the underlying language: terms are returned as they are,
     any float left is converted with [cast_and_snap]. *)
  let prj : type a. a repr -> a X.repr = function
    | Term t -> t
    | Float f -> cast_and_snap f

  (* Lifts a unary operator of [X]; a float operand is converted first. *)
  let lift_unop op x = Term (op (prj x))

  (* Lifts a binary operator of [X]; float operands are converted first. *)
  let lift_binop op x y = Term (op (prj x) (prj y))

  (* Generator of fresh variable names "v0", "v1", ... *)
  let gensym : unit -> string =
    let counter = ref 0 in
    fun () ->
      let id = !counter in
      incr counter ;
      "v" ^ string_of_int id

  let false_ = Term X.false_

  let true_ = Term X.true_

  (* Floats are kept symbolic until an operator consumes them. *)
  let float f = Float f

  (* Integers are kept as they are *)
  let int i = Term (X.int i)

  let ( + ) = lift_binop X.( + )

  let sat_sub = lift_binop X.sat_sub

  let ( * ) x y =
    (* Fixed-point product of [term] by the constant [coeff]. The
       non-constant term is let-bound to avoid copying it in every shift.
       The rounding mode is always [Ceil] here, independently of
       [cast_mode]. *)
    let scale term coeff =
      Term
        (X.let_ ~name:(gensym ()) term (fun bound ->
             FPA.approx_mult Ceil precision bound coeff))
    in
    match (x, y) with
    | Term x, Term y -> Term X.(x * y)
    | Term term, Float coeff -> scale term coeff
    | Float coeff, Term term -> scale term coeff
    | Float x, Float y -> Float (x *. y)

  let ( / ) = lift_binop X.( / )

  let max = lift_binop X.max

  let min = lift_binop X.min

  let shift_left term amount =
    lift_unop (fun t -> X.shift_left t amount) term

  let shift_right term amount =
    lift_unop (fun t -> X.shift_right t amount) term

  let log2 = lift_unop X.log2

  let sqrt = lift_unop X.sqrt

  (* Terms are expected to be closed: meeting a free variable is an error. *)
  let free ~name = raise (Fixed_point_transform_error (Term_is_not_closed name))

  let lt = lift_binop X.lt

  let eq = lift_binop X.eq

  (* The body is run on the bound variable and its result is projected, so
     that a float result is converted to an integer literal. *)
  let lam' (type a b) ~name (ty : a Costlang.Ty.t) (f : a repr -> b repr) :
      (a -> b) repr =
    Term (X.lam' ~name ty (fun x -> prj (f (Term x))))

  let lam ~name body = lam' ~name size_ty body

  (* A float cannot be a function (hence the assertion); a float argument is
     converted before the application. *)
  let app (type a b) (fn : (a -> b) repr) (arg : a repr) : b repr =
    match fn with
    | Term fn -> Term (X.app fn (prj arg))
    | Float _ -> assert false

  (* Both the bound expression and the result of the body are projected. *)
  let let_ (type a b) ~name (m : a repr) (fn : a repr -> b repr) : b repr =
    Term (X.let_ ~name (prj m) (fun x -> prj (fn (Term x))))

  let if_ cond ift iff = Term (X.if_ (prj cond) (prj ift) (prj iff))
end

(* [Apply (P)] packages [Convert_mult], instantiated with the options
   [P.options], as a [Costlang.Transform]: given an implementation [X] of
   [Costlang.S], it returns [Convert_mult (P) (X)]. *)
module Apply (P : sig
  val options : options
end) : Costlang.Transform =
functor (X : Costlang.S) -> Convert_mult (P) (X)
back to top