Revision 322c73366a9198d5bd6be08e91b729c775761821 authored by Diane Gallois-Wong on 31 August 2022, 15:57:02 UTC, committed by Marge Bot on 06 September 2022, 08:21:04 UTC
Notably, remove plugin tests on 1M, since the plugin is no longer
responsible for enforcing 1M. Similar tests on 1M already exist
in tezt, and will be extended in the next commit to cover all
the cases of the removed tests.
1 parent 995112f
Raw File
costlang.ml
(*****************************************************************************)
(*                                                                           *)
(* Open Source License                                                       *)
(* Copyright (c) 2019 Nomadic Labs. <contact@nomadic-labs.com>               *)
(*                                                                           *)
(* Permission is hereby granted, free of charge, to any person obtaining a   *)
(* copy of this software and associated documentation files (the "Software"),*)
(* to deal in the Software without restriction, including without limitation *)
(* the rights to use, copy, modify, merge, publish, distribute, sublicense,  *)
(* and/or sell copies of the Software, and to permit persons to whom the     *)
(* Software is furnished to do so, subject to the following conditions:      *)
(*                                                                           *)
(* The above copyright notice and this permission notice shall be included   *)
(* in all copies or substantial portions of the Software.                    *)
(*                                                                           *)
(* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR*)
(* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,  *)
(* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL   *)
(* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER*)
(* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING   *)
(* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER       *)
(* DEALINGS IN THE SOFTWARE.                                                 *)
(*                                                                           *)
(*****************************************************************************)

(* Signature of interpreters for the cost-function DSL, in tagless-final
   style: each implementation chooses its own representation ['a repr] of
   terms of type ['a], and its own type [size] of numeric values. *)
module type S = sig
  (* Representation of a term of type ['a]. *)
  type 'a repr

  (* Type of numeric values manipulated by the arithmetic operations. *)
  type size

  (* Boolean literals. *)
  val true_ : bool repr

  val false_ : bool repr

  (* Numeric literals. *)
  val int : int -> size repr

  val float : float -> size repr

  (* Arithmetic on sizes. *)
  val ( + ) : size repr -> size repr -> size repr

  val ( - ) : size repr -> size repr -> size repr

  val ( * ) : size repr -> size repr -> size repr

  val ( / ) : size repr -> size repr -> size repr

  val max : size repr -> size repr -> size repr

  val min : size repr -> size repr -> size repr

  val log2 : size repr -> size repr

  val sqrt : size repr -> size repr

  (* A free variable. How it is handled depends on the interpreter (e.g.
     [Eval] rejects it, [Subst] replaces it by a constant). *)
  val free : name:Free_variable.t -> size repr

  (* Comparisons between sizes. *)
  val lt : size repr -> size repr -> bool repr

  val eq : size repr -> size repr -> bool repr

  (* Shifts by a constant [int] amount. *)
  val shift_left : size repr -> int -> size repr

  val shift_right : size repr -> int -> size repr

  (* Binders and control flow. The [~name] arguments name the bound
     variable, e.g. for pretty-printing ([Pp]). *)
  val lam : name:string -> ('a repr -> 'b repr) -> ('a -> 'b) repr

  val app : ('a -> 'b) repr -> 'a repr -> 'b repr

  val let_ : name:string -> 'a repr -> ('a repr -> 'b repr) -> 'b repr

  val if_ : bool repr -> 'a repr -> 'a repr -> 'a repr
end

(* ------------------------------------------------------------------------- *)
(* Various useful implementations of the signature [S] above. *)

(* Pretty-printer: every term is rendered as a string. Compound terms are
   fully parenthesised, so the output does not depend on precedence. *)
module Pp : S with type 'a repr = string and type size = string = struct
  type 'a repr = string

  type size = string

  (* [infix op x y] renders "(x op y)". *)
  let infix op x y = Format.asprintf "(%s %s %s)" x op y

  (* [apply1 f x] renders "(f x)". *)
  let apply1 f x = Format.asprintf "(%s %s)" f x

  (* [apply2 f x y] renders "(f x y)". *)
  let apply2 f x y = Format.asprintf "(%s %s %s)" f x y

  let true_ = "true"

  let false_ = "false"

  let int = string_of_int

  let float = string_of_float

  let ( + ) = infix "+"

  let ( - ) = infix "-"

  let ( * ) = infix "*"

  let ( / ) = infix "/"

  let max = apply2 "max"

  let min = apply2 "min"

  let shift_left x bits = infix "lsl" x (string_of_int bits)

  let shift_right x bits = infix "lsr" x (string_of_int bits)

  let log2 = apply1 "log2"

  let sqrt = apply1 "sqrt"

  let free ~name = Format.asprintf "free(%a)" Free_variable.pp name

  let lt = infix "<"

  let eq = infix "="

  (* The bound variable is printed as its [name]. *)
  let lam ~name body = Format.asprintf "fun %s -> %s" name (body name)

  let app fn arg = Format.asprintf "(%s) %s" fn arg

  let let_ ~name bound body =
    Format.asprintf "let %s = %s in %s" name bound (body name)

  let if_ c t e = Format.asprintf "(if %s then %s else %s)" c t e
end

(* Computes the set of free variables occurring in a term. Binders raise if
   a free variable of their body has the same name as the bound variable. *)
module Free_variables :
  S with type 'a repr = Free_variable.Set.t and type size = unit = struct
  open Free_variable

  exception Free_variable_captured_by_lambda of string

  exception Free_variable_captured_by_let of string

  type 'a repr = Set.t

  type size = unit

  (* Constants, and variables introduced by binders, contribute no free
     variable. *)
  let nothing = Set.empty

  (* Compound terms mention the free variables of all their sub-terms. *)
  let union = Set.union

  (* [check_not_captured mk_exn name body_fv] raises [mk_exn name] if the
     variable named [name] occurs in [body_fv]. *)
  let check_not_captured mk_exn name body_fv =
    if Set.mem (Free_variable.of_string name) body_fv then raise (mk_exn name)

  let true_ = nothing

  let false_ = nothing

  let int _ = nothing

  let float _ = nothing

  let free ~name = Set.singleton name

  let ( + ) = union

  let ( - ) = union

  let ( * ) = union

  let ( / ) = union

  let max = union

  let min = union

  let log2 fv = fv

  let sqrt fv = fv

  let shift_left fv _bits = fv

  let shift_right fv _bits = fv

  let lt = union

  let eq = union

  let lam ~name f =
    let body_fv = f nothing in
    check_not_captured (fun n -> Free_variable_captured_by_lambda n) name body_fv ;
    body_fv

  let app = union

  let let_ ~name m f =
    let body_fv = f nothing in
    check_not_captured (fun n -> Free_variable_captured_by_let n) name body_fv ;
    union m body_fv

  let if_ c t e = union c (union t e)
end

(* Direct evaluator: terms denote plain OCaml values, sizes are floats, and
   free variables are rejected with [Term_contains_free_variable]. *)
module Eval : S with type 'a repr = 'a and type size = float = struct
  exception Term_contains_free_variable of Free_variable.t

  type 'a repr = 'a

  type size = float

  let true_ = true

  let false_ = false

  let int = float_of_int

  let float x = x

  let ( + ) = ( +. )

  let ( - ) = ( -. )

  let ( * ) = ( *. )

  let ( / ) = ( /. )

  let max = max

  let min = min

  (* Shifts are computed in floating point: scaling by a power of two, with
     no truncation to an integer. *)
  let shift_left v bits = v *. (2. ** float_of_int bits)

  let shift_right v bits = v /. (2. ** float_of_int bits)

  let log2 v = log v /. log 2.

  let sqrt = sqrt

  let free ~name = raise (Term_contains_free_variable name)

  let lt a b = a < b

  let eq a b = a = b

  (* Binder names only matter to printers; evaluation ignores them. *)
  let lam ~name:_ body = body

  let app fn arg = fn arg

  let let_ ~name:_ bound body = body bound

  let if_ c t e = if c then t else e
end

(* Evaluating implementation (see [Eval_linear_combination_impl] below).
   Expects terms to evaluate to affine combinations of free variables
   (the free variables playing the role of coefficients); fails otherwise.
   Takes a substitution for free variables as a parameter. *)

(* An affine combination: a sparse linear combination of free variables plus
   a constant term. *)
type affine = {linear_comb : Free_variable.Sparse_vec.t; const : float}

(* Vector-space operations on [affine] values. *)
module Affine_ops = struct
  module V = Free_variable.Sparse_vec

  (* An affine combination is constant when it mentions no free variable. *)
  let is_const {linear_comb; _} = V.is_empty linear_comb

  let ( + ) {linear_comb = v1; const = c1} {linear_comb = v2; const = c2} =
    {linear_comb = V.add v1 v2; const = c1 +. c2}

  let ( - ) {linear_comb = v1; const = c1} {linear_comb = v2; const = c2} =
    {linear_comb = V.add v1 (V.neg v2); const = c1 -. c2}

  (* Scalar multiplication of both the linear part and the constant. *)
  let smul c a = {linear_comb = V.smul c a.linear_comb; const = c *. a.const}
end

(* Substitution for free variables: returns the value to use for a free
   variable, or [None] when no value is provided for it. *)
type subst = Free_variable.t -> float option

(* Raised with the name of the DSL construct that cannot be turned into an
   affine combination. *)
exception Eval_linear_combination of string

let () =
  Printexc.register_printer (function
    | Eval_linear_combination node ->
        Some
          (Format.asprintf
             "Eval_linear_combination: cannot convert node %s"
             node)
    | _ -> None)

module Eval_linear_combination_impl : sig
  include S

  val run : subst -> size repr -> affine
end
(* Terms of type [size repr] evaluate, under a substitution, to affine
   combinations of free variables with float coefficients (see [affine]);
   terms of type [bool repr] evaluate to booleans. Constructs leaving this
   fragment raise [Eval_linear_combination]: products or divisions that are
   not by a constant, non-constant arguments to max/min/log2/sqrt/lt/eq,
   shifts, lambdas and applications. *) =
struct
  type size = float

  type 'a repr = subst -> 'a result

  and 'a result = Affine : affine -> size result | Bool : bool -> bool result

  (* The helpers below are hidden by the signature. *)

  (* [constant c] is the affine combination with no free variable and
     constant part [c]. *)
  let constant c =
    Affine {linear_comb = Free_variable.Sparse_vec.zero; const = c}

  (* [eval_affine x subst] evaluates the size term [x] under [subst]. The
     pattern is exhaustive: [Affine] is the only constructor of type
     [size result]. *)
  let eval_affine (x : size repr) subst =
    let (Affine a) = x subst in
    a

  (* [const_unop node op x subst] applies [op] to the constant part of [x],
     raising [Eval_linear_combination node] if [x] mentions a free
     variable. *)
  let const_unop node op (x : size repr) subst =
    let a = eval_affine x subst in
    if Affine_ops.is_const a then op a.const
    else raise (Eval_linear_combination node)

  (* [const_binop node op x1 x2 subst] evaluates [x1] then [x2] and applies
     [op] to their constant parts, raising [Eval_linear_combination node]
     unless both are constant. *)
  let const_binop node op (x1 : size repr) (x2 : size repr) subst =
    let a1 = eval_affine x1 subst in
    let a2 = eval_affine x2 subst in
    if Affine_ops.is_const a1 && Affine_ops.is_const a2 then
      op a1.const a2.const
    else raise (Eval_linear_combination node)

  let true_ _ = Bool true

  let false_ _ = Bool false

  let int i _ = constant (float_of_int i)

  let float f _ = constant f

  let ( + ) (x1 : size repr) (x2 : size repr) subst =
    let a1 = eval_affine x1 subst in
    let a2 = eval_affine x2 subst in
    Affine Affine_ops.(a1 + a2)

  let ( - ) (x1 : size repr) (x2 : size repr) subst =
    let a1 = eval_affine x1 subst in
    let a2 = eval_affine x2 subst in
    Affine Affine_ops.(a1 - a2)

  (* A product stays affine only if one of the factors is constant. *)
  let ( * ) (x1 : size repr) (x2 : size repr) subst =
    let a1 = eval_affine x1 subst in
    let a2 = eval_affine x2 subst in
    if Affine_ops.is_const a1 then Affine (Affine_ops.smul a1.const a2)
    else if Affine_ops.is_const a2 then Affine (Affine_ops.smul a2.const a1)
    else raise (Eval_linear_combination "*")

  (* A quotient stays affine only if the divisor is constant. *)
  let ( / ) (x1 : size repr) (x2 : size repr) subst =
    let a1 = eval_affine x1 subst in
    let a2 = eval_affine x2 subst in
    if Affine_ops.is_const a2 then Affine (Affine_ops.smul (1. /. a2.const) a1)
    else raise (Eval_linear_combination "/")

  let max x1 x2 subst =
    const_binop "max" (fun c1 c2 -> constant (max c1 c2)) x1 x2 subst

  let min x1 x2 subst =
    const_binop "min" (fun c1 c2 -> constant (min c1 c2)) x1 x2 subst

  let log2 x subst =
    const_unop "log2" (fun c -> constant (log c /. log 2.)) x subst

  let sqrt x subst = const_unop "sqrt" (fun c -> constant (sqrt c)) x subst

  (* A substituted variable becomes a constant; otherwise it stays as the
     linear combination [1.0 * name]. *)
  let free ~name subst =
    match subst name with
    | Some c -> constant c
    | None ->
        Affine
          {
            const = 0.0;
            linear_comb = Free_variable.Sparse_vec.of_list [(name, 1.0)];
          }

  let lt x1 x2 subst =
    const_binop "lt" (fun c1 c2 -> Bool (c1 < c2)) x1 x2 subst

  let eq x1 x2 subst =
    const_binop "eq" (fun c1 c2 -> Bool (c1 = c2)) x1 x2 subst

  (* NOTE(review): shifts are rejected here, although [Eval] interprets
     them as scaling by a power of two. *)
  let shift_left _ _ = raise (Eval_linear_combination "shift_left")

  let shift_right _ _ = raise (Eval_linear_combination "shift_right")

  let lam ~name:_ _f _subst = raise (Eval_linear_combination "lambda")

  let app _bound _body _subst = raise (Eval_linear_combination "app")

  (* The bound term is passed unevaluated to [body]; it is evaluated at
     each use. *)
  let let_ ~name:_ bound body subst = body bound subst

  let if_ (cond : bool repr) (ift : 'a repr) (iff : 'a repr) : 'a repr =
   fun subst ->
    let (Bool b) = cond subst in
    if b then ift subst else iff subst

  let run : subst -> size repr -> affine =
   fun subst repr -> eval_affine repr subst
end

(* ------------------------------------------------------------------------- *)
(* Implementation _transformers_. *)

(* A transform maps an interpreter [X] to a new interpreter over the same
   [size] type, together with [prj], which converts terms of the new
   interpreter back into terms of [X]. *)
module type Transform = functor (X : S) -> sig
  include S with type size = X.size

  val prj : 'a repr -> 'a X.repr
end

(* A [Transform] packed as a first-class module, so that transforms can be
   passed around and composed as values (see [compose]). *)
type transform = (module Transform)

(* [compose f g] instantiated at [X] builds [G (F (X))]: terms are written in
   [G]'s interpreter and projected through [G], then through [F], back to
   [X]. *)
let compose (f : transform) (g : transform) : transform =
  let module F = (val f) in
  let module G = (val g) in
  let module Composed (X : S) = struct
    module Inner = F (X)
    module Outer = G (Inner)
    include Outer

    let prj term = Inner.prj @@ Outer.prj term
  end in
  (module Composed)

(* Identity transform: the interpreter [X] itself, with [prj] the identity. *)
module Identity : Transform =
functor
  (X : S)
  ->
  struct
    let prj term = term

    include X
  end

(* [Subst (P)] replaces every free variable [name] by the float constant
   [P.subst name]; all other constructs are those of [X]. *)
module Subst (P : sig
  val subst : Free_variable.t -> float
end) : Transform =
functor
  (X : S)
  ->
  struct
    include X

    let free ~name = name |> P.subst |> X.float

    let prj term = term
  end

module Hashtbl = Stdlib.Hashtbl

(* A term [repr] of an underlying interpreter, annotated for hash-consing:
   [hash] is the hash of the node's structural key ([-1] for nodes that are
   never shared) and [tag] identifies the node. *)
type 'a hash_consed = {repr : 'a; hash : int; tag : int}

(* [Hash_cons] shares identical sub-terms of type [size repr]. Building the
   same constant, free variable or arithmetic node twice (same operator and
   same operand tags) returns the node created the first time, and the
   underlying [X] term is only built then. Booleans, comparisons, shifts and
   binders get [hash = -1] and a fresh [tag], and are never shared. *)
module Hash_cons : Transform =
functor
  (X : S)
  ->
  struct
    type size = X.size

    type 'a repr = 'a X.repr hash_consed

    (* Structural key of a [size repr] node. Sub-terms are referred to by
       their [tag]; constants and free variables carry their data. *)
    type unique_term_identifier =
      | Int_tag of {i : int} (* not a tag, actual data! *)
      | Float_tag of {f : float} (* not a tag, actual data! *)
      | Add_tag of int * int
      | Sub_tag of int * int
      | Mul_tag of int * int
      | Div_tag of int * int
      | Max_tag of int * int
      | Min_tag of int * int
      | Log2_tag of int
      | Sqrt_tag of int
      | Free_tag of {name : Free_variable.t}

    let prj {repr; _} = repr

    (* A hashtable for memoizing terms of type [size repr], indexed by the
       hash of their key. Keys with the same hash share a bucket and are told
       apart by structural equality. We don't bother hash-consing the rest:
       this is the sublanguage where sharing is most useful. *)
    let size_table : (int, size repr * unique_term_identifier) Hashtbl.t =
      Hashtbl.create 101

    (* Tag generator: returns 0, 1, 2, ... on successive calls. *)
    let fresh =
      let c = ref ~-1 in
      fun () ->
        incr c ;
        !c

    (* [insert_if_not_present term_thunk uti] returns the node registered
       under key [uti] if there is one. Otherwise it forces [term_thunk],
       registers the resulting node under [uti] with a fresh tag, and
       returns it. *)
    let insert_if_not_present (term_thunk : unit -> size X.repr)
        (uti : unique_term_identifier) =
      let hash = Hashtbl.hash uti in
      let same_key (_, uti') = uti = uti' in
      match List.find_opt same_key (Hashtbl.find_all size_table hash) with
      | Some (res, _) -> res
      | None ->
          let hash_consed = {repr = term_thunk (); hash; tag = fresh ()} in
          Hashtbl.add size_table hash (hash_consed, uti) ;
          hash_consed

    (* Binary operation producing an unshared node. *)
    let lift2_nohash f x y =
      let hash = -1 in
      {repr = f x.repr y.repr; hash; tag = fresh ()}

    let false_ = {repr = X.false_; hash = -1; tag = fresh ()}

    let true_ = {repr = X.true_; hash = -1; tag = fresh ()}

    let float (f : float) =
      insert_if_not_present (fun () -> X.float f) (Float_tag {f})

    let int (i : int) = insert_if_not_present (fun () -> X.int i) (Int_tag {i})

    let ( + ) x y =
      insert_if_not_present
        X.(fun () -> x.repr + y.repr)
        (Add_tag (x.tag, y.tag))

    let ( - ) x y =
      insert_if_not_present
        X.(fun () -> x.repr - y.repr)
        (Sub_tag (x.tag, y.tag))

    let ( * ) x y =
      insert_if_not_present
        X.(fun () -> x.repr * y.repr)
        (Mul_tag (x.tag, y.tag))

    let ( / ) x y =
      insert_if_not_present
        X.(fun () -> x.repr / y.repr)
        (Div_tag (x.tag, y.tag))

    let max x y =
      insert_if_not_present
        X.(fun () -> max x.repr y.repr)
        (Max_tag (x.tag, y.tag))

    let min x y =
      insert_if_not_present
        X.(fun () -> min x.repr y.repr)
        (Min_tag (x.tag, y.tag))

    let log2 x =
      insert_if_not_present X.(fun () -> log2 x.repr) (Log2_tag x.tag)

    let sqrt x =
      insert_if_not_present X.(fun () -> sqrt x.repr) (Sqrt_tag x.tag)

    let free ~name =
      insert_if_not_present X.(fun () -> free ~name) (Free_tag {name})

    (* The functions below are _not_ hash-consed. *)
    let lt x y = {repr = X.lt x.repr y.repr; hash = -1; tag = fresh ()}

    let eq x y = {repr = X.eq x.repr y.repr; hash = -1; tag = fresh ()}

    let shift_left x i =
      let hash = -1 in
      {repr = X.shift_left x.repr i; hash; tag = fresh ()}

    let shift_right x i =
      let hash = -1 in
      {repr = X.shift_right x.repr i; hash; tag = fresh ()}

    (* [unlift_fun f] turns a function on wrapped terms into a function on
       [X] terms, wrapping its argument as an unshared node. *)
    let unlift_fun : type a b. (a repr -> b repr) -> a X.repr -> b X.repr =
     fun f x -> (f {repr = x; hash = -1; tag = fresh ()}).repr

    let lam ~name body =
      {repr = X.lam ~name (unlift_fun body); hash = -1; tag = fresh ()}

    let app f arg = lift2_nohash X.app f arg

    let let_ ~name bound body =
      {
        repr = X.let_ ~name bound.repr (unlift_fun body);
        hash = -1;
        tag = fresh ();
      }

    let if_ cond ift iff =
      {repr = X.if_ cond.repr ift.repr iff.repr; hash = -1; tag = fresh ()}
  end

(* [Beta_normalize] evaluates beta-redexes: applying a term built with [lam]
   runs the OCaml closure directly (see [app]) instead of emitting an
   [X.app] of an [X.lam]. Applications whose function part is not a
   statically known lambda are forwarded to [X.app]. *)
module Beta_normalize : Transform =
functor
  (X : S)
  ->
  struct
    type size = X.size

    (* A value is either a lambda that can be statically evaluated
       (case [Static_lam]) or any value that will be
       dynamically evaluated (case [Dynamic]). *)
    type 'a repr =
      | Static_lam : {
          name : string;
          lam : 'a X.repr -> 'b repr;
        }
          -> ('a -> 'b) repr
      | Dynamic : 'a X.repr -> 'a repr

    (* Inject an [X] term as a dynamic value. *)
    let dyn (x : 'a X.repr) : 'a repr = Dynamic x

    (* Convert back to an [X] term. A static lambda is emitted as an [X.lam]
       whose body is obtained by running the closure [lam]. *)
    let rec prj : type a. a repr -> a X.repr =
     fun x ->
      match x with
      | Static_lam {name; lam} -> X.lam ~name (fun arg -> prj (lam arg))
      | Dynamic d -> d

    (* [lift1]/[lift2] forward unary/binary operators to [X]. They are only
       applied to [size] operands, which cannot be built by [Static_lam]
       (whose type index is a function type), hence [assert false]. *)
    let lift1 f x = match x with Dynamic d -> dyn (f d) | _ -> assert false

    let lift2 f x y =
      match (x, y) with
      | Dynamic d, Dynamic e -> dyn (f d e)
      | _ -> assert false

    let false_ = dyn X.false_

    let true_ = dyn X.true_

    let float f = dyn (X.float f)

    let int i = dyn (X.int i)

    let ( + ) x y = lift2 X.( + ) x y

    let ( - ) x y = lift2 X.( - ) x y

    let ( * ) x y = lift2 X.( * ) x y

    let ( / ) x y = lift2 X.( / ) x y

    let max x y = lift2 X.max x y

    let min x y = lift2 X.min x y

    let shift_left x i = lift1 (fun x -> X.shift_left x i) x

    let shift_right x i = lift1 (fun x -> X.shift_right x i) x

    let log2 x = lift1 X.log2 x

    let sqrt x = lift1 X.sqrt x

    let free ~name = dyn (X.free ~name)

    let lt x y = lift2 X.lt x y

    let eq x y = lift2 X.eq x y

    (* The body [f] is not run here: it runs each time the lambda is applied
       by [app] or projected by [prj]. *)
    let lam : name:string -> ('a repr -> 'b repr) -> ('a -> 'b) repr =
     fun ~name f ->
      let lam arg = f (dyn arg) in
      Static_lam {name; lam}

    (* Static lambda: beta-reduce by running the closure on the projected
       argument. Dynamic function: emit a residual [X.app]. *)
    let app : type a b. (a -> b) repr -> a repr -> b repr =
     fun f arg ->
      match f with
      | Static_lam {lam; _} -> lam (prj arg)
      | Dynamic dyn_f -> Dynamic (X.app dyn_f (prj arg))

    (* Let-bindings are not inlined: a residual [X.let_] is emitted, and the
       bound variable is passed to [f] as a [Dynamic] value. *)
    let let_ : type a b. name:string -> a repr -> (a repr -> b repr) -> b repr =
     fun ~name m f -> Dynamic (X.let_ ~name (prj m) (fun x -> prj (f (dyn x))))

    (* Both branches are projected, so static lambdas in branches become
       [X.lam] terms. *)
    let if_ cond ift iff = Dynamic (X.if_ (prj cond) (prj ift) (prj iff))
  end

(* As the type indicates, this is a simplified CPS transform designed to
   lift let-bindings out of subexpressions. A [let_] occurring in an
   arithmetic or comparison operand, or in an [if_] condition, ends up
   binding around the enclosing expression. Lifting stops at [prj]
   boundaries: [lam] bodies, [app] function and argument, the bound term and
   body of [let_], and [if_] branches.
   Warning: this transformation does not check that the ~name arguments
   (used for pretty printing) are globally distinct for let bindings. *)
module Let_lift : Transform =
functor
  (X : S)
  ->
  struct
    type size = X.size

    (* A CPS term: given the rest of the computation [k], it produces the
       final [X] term. [cont] is polymorphic in the result type ['b]. *)
    type 'a cps = {cont : 'b. ('a -> 'b X.repr) -> 'b X.repr}

    type 'a repr = 'a X.repr cps

    (* Close a term off with the identity continuation. *)
    let prj term = term.cont (fun v -> v)

    (* A term with no binding to lift. *)
    let ret v = {cont = (fun k -> k v)}

    let lift_unop op t = {cont = (fun k -> t.cont (fun v -> k (op v)))}

    let lift_binop op t1 t2 =
      {cont = (fun k -> t1.cont (fun v1 -> t2.cont (fun v2 -> k (op v1 v2))))}

    let true_ = ret X.true_

    let false_ = ret X.false_

    let int n = ret (X.int n)

    let float x = ret (X.float x)

    let ( + ) = lift_binop X.( + )

    let ( - ) = lift_binop X.( - )

    let ( * ) = lift_binop X.( * )

    let ( / ) = lift_binop X.( / )

    let max = lift_binop X.max

    let min = lift_binop X.min

    let shift_left t bits = lift_unop (fun v -> X.shift_left v bits) t

    let shift_right t bits = lift_unop (fun v -> X.shift_right v bits) t

    let log2 = lift_unop X.log2

    let sqrt = lift_unop X.sqrt

    let free ~name = ret (X.free ~name)

    let lt = lift_binop X.lt

    let eq = lift_binop X.eq

    let lam ~name (body : 'a repr -> 'b repr) =
      {cont = (fun k -> k (X.lam ~name (fun v -> prj (body (ret v)))))}

    let app fn arg = {cont = (fun k -> k (X.app (prj fn) (prj arg)))}

    (* The continuation [k] is placed inside the scope of the binding, which
       is what lifts the [let] above the surrounding expression. *)
    let let_ ~name (bound : 'a repr) (body : 'a repr -> 'b repr) : 'b repr =
      {
        cont =
          (fun k ->
            X.let_ ~name (prj bound) (fun v -> k (prj (body (ret v)))));
      }

    let if_ cond ift iff =
      lift_unop (fun c -> X.if_ c (prj ift) (prj iff)) cond
  end

(* Instantiate model over partially evaluating & hash-consing cost
   function DSL: [Eval_to_vector] evaluates terms to affine combinations
   ([Eval_linear_combination_impl]). Identical size sub-terms are shared
   ([Hash_cons]), and applications of lambdas are reduced statically
   ([Beta_normalize]). *)
module Hash_cons_vector = Hash_cons (Eval_linear_combination_impl)
module Eval_to_vector = Beta_normalize (Hash_cons_vector)

(* [Fold_constants] partially evaluates terms built from literals.
   Arithmetic and comparisons whose operands are all [int]/[float]/[bool]
   literals are computed in OCaml, with mixed [int]/[float] operands
   promoted to [float]. The identities [0 + t = t + 0 = t], [t - 0 = t],
   [0 * t = t * 0 = 0], [1 * t = t * 1 = t] and [t / 1 = t] are applied.
   Division of two constants (other than by 1), shifts, [log2] and [sqrt]
   are not folded: their constant operands are re-injected into [X].
   Everything else is forwarded to [X]. *)
module Fold_constants (X : S) = struct
  type size = X.size

  (* A literal that is still known statically, or an [X] term. *)
  type 'a maybe_const =
    | Int : int -> size maybe_const
    | Float : float -> size maybe_const
    | Bool : bool -> bool maybe_const
    | Not_const : 'a X.repr -> 'a maybe_const

  type 'a repr = 'a maybe_const

  (* Convert back to an [X] term, emitting literals as [X] constants. *)
  let prj : type a. a maybe_const -> a X.repr = function
    | Int i -> X.int i
    | Float f -> X.float f
    | Bool false -> X.false_
    | Bool true -> X.true_
    | Not_const term -> term

  let inj x = Not_const x

  let false_ = Bool false

  let true_ = Bool true

  let float f = Float f

  let int i = Int i

  (* [arith_op op_i op_f op_x] folds with [op_i] on two ints, with [op_f] when
     at least one constant is a float (the int being converted), and builds
     an [X] term with [op_x] as soon as one operand is not constant.
     NOTE(review): folding on [Int] uses native [int] arithmetic, which wraps
     silently on overflow. [Bool] is not expected as an arithmetic operand. *)
  let arith_op op_i op_f op_x x y =
    match (x, y) with
    | Int i, Int j -> Int (op_i i j)
    | Float i, Float j -> Float (op_f i j)
    | Int i, Float j -> Float (op_f (float_of_int i) j)
    | Float i, Int j -> Float (op_f i (float_of_int j))
    | Not_const term, Int i -> Not_const (op_x term (X.int i))
    | Int i, Not_const term -> Not_const (op_x (X.int i) term)
    | Not_const term, Float i -> Not_const (op_x term (X.float i))
    | Float i, Not_const term -> Not_const (op_x (X.float i) term)
    | Not_const x, Not_const y -> Not_const (op_x x y)
    | Bool _, _ | _, Bool _ -> assert false

  let ( + ) x y =
    match (x, y) with
    | Int 0, term | Float 0.0, term | term, Int 0 | term, Float 0.0 -> term
    | _ -> arith_op ( + ) ( +. ) X.( + ) x y

  (* A zero factor yields [Int 0], discarding the other operand even when it
     is not constant, and even when the zero is a [Float]. *)
  let ( * ) x y =
    match (x, y) with
    | Int 0, _ | Float 0.0, _ | _, Int 0 | _, Float 0.0 -> Int 0
    | Int 1, term | Float 1.0, term | term, Int 1 | term, Float 1.0 -> term
    | _ -> arith_op ( * ) ( *. ) X.( * ) x y

  let ( - ) x y =
    match (x, y) with
    | term, Int 0 | term, Float 0.0 -> term
    | _ -> arith_op ( - ) ( -. ) X.( - ) x y

  let ( / ) x y =
    match (x, y) with
    | term, Int 1 -> term
    | term, Float 1.0 -> term
    (* The next cases are here to avoid introducing floating point constants from the division *)
    | Int i, Int j -> Not_const X.(int i / int j)
    | Float i, Float j -> Not_const X.(float i / float j)
    | Int i, Float j -> Not_const X.(int i / float j)
    | Float i, Int j -> Not_const X.(float i / int j)
    | _ -> arith_op ( / ) ( /. ) X.( / ) x y

  let max = arith_op max max X.max

  let min = arith_op min min X.min

  (* Shifts, [log2] and [sqrt] are never folded; constants are passed to
     [X] as literals. *)
  let shift_left x s =
    inj
    @@
    match x with
    | Int i -> X.(shift_left (int i) s)
    | Float f -> X.(shift_left (float f) s)
    | Not_const term -> X.(shift_left term s)
    | Bool _ -> assert false

  let shift_right x s =
    inj
    @@
    match x with
    | Int i -> X.(shift_right (int i) s)
    | Float f -> X.(shift_right (float f) s)
    | Not_const term -> X.(shift_right term s)
    | Bool _ -> assert false

  let log2 x =
    inj
    @@
    match x with
    | Int i -> X.(log2 (int i))
    | Float f -> X.(log2 (float f))
    | Not_const term -> X.(log2 term)
    | Bool _ -> assert false

  let sqrt x =
    inj
    @@
    match x with
    | Int i -> X.(sqrt (int i))
    | Float f -> X.(sqrt (float f))
    | Not_const term -> X.(sqrt term)
    | Bool _ -> assert false

  let free ~name = Not_const (X.free ~name)

  (* Comparisons of two constants are folded to [Bool]. *)
  let lt x y =
    match (x, y) with
    | Int i, Int j -> Bool (i < j)
    | Float i, Float j -> Bool (i < j)
    | Float i, Int j -> Bool (i < float_of_int j)
    | Int i, Float j -> Bool (float_of_int i < j)
    | Not_const term, Int i -> Not_const X.(lt term (int i))
    | Int i, Not_const term -> Not_const X.(lt (int i) term)
    | Not_const term, Float i -> Not_const X.(lt term (float i))
    | Float i, Not_const term -> Not_const X.(lt (float i) term)
    | Not_const x, Not_const y -> Not_const X.(lt x y)
    | Bool _, _ | _, Bool _ -> assert false

  let eq x y =
    match (x, y) with
    | Int i, Int j -> Bool (i = j)
    | Float i, Float j -> Bool (i = j)
    | Float i, Int j -> Bool (i = float_of_int j)
    | Int i, Float j -> Bool (float_of_int i = j)
    | Not_const term, Int i -> Not_const X.(eq term (int i))
    | Int i, Not_const term -> Not_const X.(eq (int i) term)
    | Not_const term, Float i -> Not_const X.(eq term (float i))
    | Float i, Not_const term -> Not_const X.(eq (float i) term)
    | Not_const x, Not_const y -> Not_const X.(eq x y)
    | Bool _, _ | _, Bool _ -> assert false

  (* Binders are forwarded to [X]; bound variables are never constants. *)
  let lam ~name (f : 'a repr -> 'b repr) =
    Not_const (X.lam ~name (fun x -> prj (f (inj x))))

  let app f arg = Not_const (X.app (prj f) (prj arg))

  let let_ ~name (m : 'a repr) (f : 'a repr -> 'b repr) : 'b repr =
    Not_const (X.let_ ~name (prj m) (fun x -> prj (f (inj x))))

  (* A constant condition selects a branch statically. *)
  let if_ cond ift iff =
    match cond with
    | Bool true -> ift
    | Bool false -> iff
    | Not_const term -> Not_const (X.if_ term (prj ift) (prj iff))
    | Int _ | Float _ -> assert false
end
back to top