(* Source: https://github.com/EasyCrypt/easycrypt
 * Raw File: ecFol.ml
 * Tip revision: 3f4a0bd5596888cd8d28b97687d477942187aa5f authored by Pierre-Yves Strub on 11 June 2022, 06:10:21 UTC
 * In loop fusion/fission, add more constraints on the epilog
 * Tip revision: 3f4a0bd *)
(* -------------------------------------------------------------------- *)
open EcIdent
open EcUtils
open EcSymbols
open EcTypes
open EcMemory
open EcBigInt.Notations
open EcBaseLogic

module BI = EcBigInt
module CI = EcCoreLib

(* -------------------------------------------------------------------- *)
include EcCoreFol

(* -------------------------------------------------------------------- *)
(* [f_eqparams ty1 vs1 m1 ty2 vs2 m2] relates the parameters [vs1] (in
 * memory [m1]) to the parameters [vs2] (in memory [m2]).  With the same
 * number of parameters on both sides, the result is built by [f_eqs]
 * from the two lists of parameter formulas; otherwise both sides are
 * packed into tuples that are compared with a single [f_eq]. *)
let f_eqparams ty1 vs1 m1 ty2 vs2 m2 =
  (* Formulas denoting each parameter of [vs] in memory [m]: the argument
   * variable itself for a single parameter, its projections otherwise. *)
  let f_pvlocs ty vs m =
    let arg = f_pvarg ty m in
    match vs with
    | [_] -> [arg]
    | _   -> List.mapi (fun i vd -> f_proj arg i vd.ov_type) vs
  in

  let same_arity = (List.length vs1 = List.length vs2) in

  if same_arity then
    f_eqs (f_pvlocs ty1 vs1 m1) (f_pvlocs ty2 vs2 m2)
  else
    f_eq
      (f_tuple (f_pvlocs ty1 vs1 m1))
      (f_tuple (f_pvlocs ty2 vs2 m2))

(* Equality between the result variable [pv_res] read in memory [m1]
 * (at type [ty1]) and in memory [m2] (at type [ty2]). *)
let f_eqres ty1 m1 ty2 m2 =
  let res_in ty m = f_pvar pv_res ty m in
  f_eq (res_in ty1 m1) (res_in ty2 m2)

(* Equality between the globals of module [mp1] in memory [m1] and the
 * globals of module [mp2] in memory [m2]. *)
let f_eqglob mp1 m1 mp2 m2 =
  let glob_in mp m = f_glob mp m in
  f_eq (glob_in mp1 m1) (glob_in mp2 m2)

(* -------------------------------------------------------------------- *)
(* The core-library operator embedding integers into reals. *)
let f_op_real_of_int = (* CORELIB *)
  let ty = tfun tint treal in
  f_op CI.CI_Real.p_real_of_int [] ty

(* [f_real_of_int f] applies the int-to-real embedding to [f]. *)
let f_real_of_int f =
  f_app f_op_real_of_int [f] treal

(* [f_rint n] is the real formula for the integer literal [n]. *)
let f_rint n =
  n |> f_int |> f_real_of_int

(* The real constants 0 and 1, as embedded integer literals. *)
let f_r0 = f_rint BI.zero
let f_r1 = f_rint BI.one

(* [destr_rint f] extracts the integer value of a real constant: either
 * an integer literal under the int-to-real embedding, or one of the
 * core-library constants [p_real0] / [p_real1].  Any other formula makes
 * it fail through [destr_error "destr_rint"]. *)
let destr_rint f =
  let fail () = destr_error "destr_rint" in

  match f.f_node with
  | Fapp (op, [arg]) when f_equal f_op_real_of_int op -> begin
      match destr_int arg with
      | n -> n
      | exception DestrError _ -> fail ()
    end

  | Fop (p, _) ->
      if      EcPath.p_equal p CI.CI_Real.p_real0 then BI.zero
      else if EcPath.p_equal p CI.CI_Real.p_real1 then BI.one
      else fail ()

  | _ -> fail ()


(* -------------------------------------------------------------------- *)
(* Core-library comparison operators over integers. *)
let fop_int_le, fop_int_lt =
  let ty = toarrow [tint; tint] tbool in
  let le = f_op CI.CI_Int.p_int_le [] ty in
  let lt = f_op CI.CI_Int.p_int_lt [] ty in
  (le, lt)

(* Core-library comparison operators over reals. *)
let fop_real_le, fop_real_lt =
  let ty = toarrow [treal; treal] tbool in
  let le = f_op CI.CI_Real.p_real_le [] ty in
  let lt = f_op CI.CI_Real.p_real_lt [] ty in
  (le, lt)

(* Core-library arithmetic operators over reals: binary addition and
 * multiplication, unary opposite, inverse and absolute value. *)
let fop_real_add, fop_real_opp, fop_real_mul, fop_real_inv, fop_real_abs =
  let binop = toarrow [treal; treal] treal in
  let unop  = toarrow [treal] treal in
  let add = f_op CI.CI_Real.p_real_add [] binop in
  let opp = f_op CI.CI_Real.p_real_opp [] unop  in
  let mul = f_op CI.CI_Real.p_real_mul [] binop in
  let inv = f_op CI.CI_Real.p_real_inv [] unop  in
  let abs = f_op CI.CI_Real.p_real_abs [] unop  in
  (add, opp, mul, inv, abs)

(* Integer comparisons [f1 <= f2] and [f1 < f2] as boolean formulas. *)
let f_int_le, f_int_lt =
  let cmp op f1 f2 = f_app op [f1; f2] tbool in
  (cmp fop_int_le, cmp fop_int_lt)

(* -------------------------------------------------------------------- *)
(* Real comparisons [f1 <= f2] and [f1 < f2] as boolean formulas. *)
let f_real_le, f_real_lt =
  let cmp op f1 f2 = f_app op [f1; f2] tbool in
  (cmp fop_real_le, cmp fop_real_lt)

(* Binary real arithmetic: addition and multiplication. *)
let f_real_add, f_real_mul =
  let binop op f1 f2 = f_app op [f1; f2] treal in
  (binop fop_real_add, binop fop_real_mul)

(* Unary real arithmetic: opposite, inverse and absolute value. *)
let f_real_opp, f_real_inv, f_real_abs =
  let unop op f = f_app op [f] treal in
  (unop fop_real_opp, unop fop_real_inv, unop fop_real_abs)

(* Subtraction is encoded as addition of the opposite. *)
let f_real_sub f1 f2 =
  let neg = f_real_opp f2 in
  f_real_add f1 neg

(* Division is encoded as multiplication by the inverse. *)
let f_real_div f1 f2 =
  let inv = f_real_inv f2 in
  f_real_mul f1 inv

(* [f_decimal (n, (l, f))] builds the real formula for a decimal literal
 * with integer part [n] and fractional part [f / 10^l].  The fraction is
 * reduced by its gcd, and a zero integer or fractional part is omitted
 * from the result. *)
let f_decimal (n, (l, f)) =
  let int_part () = f_rint n in

  if BI.equal f BI.zero then int_part ()
  else begin
    let denom = BI.pow (BI.of_int 10) l in
    let g     = BI.gcd f denom in
    let num   = BI.div f g in
    let denom = BI.div denom g in
    let frac  = f_real_div (f_rint num) (f_rint denom) in

    if BI.equal n BI.zero then frac
    else f_real_add (int_part ()) frac
  end

(* -------------------------------------------------------------------- *)
(* Type of core-library maps from [aty] to [bty]. *)
let tmap aty bty =
  tconstr CI.CI_Map.p_map [aty; bty]

(* Constant-map operator, of type [bty -> (aty, bty) map]. *)
let fop_map_cst aty bty =
  let mty = tmap aty bty in
  f_op CI.CI_Map.p_cst [aty; bty] (toarrow [bty] mty)

(* Map lookup operator, of type [(aty, bty) map -> aty -> bty]. *)
let fop_map_get aty bty =
  let mty = tmap aty bty in
  f_op CI.CI_Map.p_get [aty; bty] (toarrow [mty; aty] bty)

(* Map update operator, of type
 * [(aty, bty) map -> aty -> bty -> (aty, bty) map]. *)
let fop_map_set aty bty =
  let mty = tmap aty bty in
  f_op CI.CI_Map.p_set [aty; bty] (toarrow [mty; aty; bty] mty)

(* [f_map_cst aty f]: the map with keys of type [aty] constantly equal
 * to [f]. *)
let f_map_cst aty f =
  let bty = f.f_ty in
  f_app (fop_map_cst aty bty) [f] (tmap aty bty)

(* [f_map_get m x bty]: lookup of key [x] in map [m], at value type
 * [bty]. *)
let f_map_get m x bty =
  let aty = x.f_ty in
  f_app (fop_map_get aty bty) [m; x] bty

(* [f_map_set m x e]: map [m] updated with value [e] at key [x]. *)
let f_map_set m x e =
  let aty = x.f_ty and bty = e.f_ty in
  f_app (fop_map_set aty bty) [m; x; e] (tmap aty bty)

(* -------------------------------------------------------------------- *)
(* The core-library [predT] predicate instantiated at [ty]. *)
let f_predT ty =
  f_op CI.CI_Pred.p_predT [ty] (tcpred ty)

(* The core-library [pred1] operator at [ty], of type [ty -> ty -> bool]. *)
let fop_pred1 ty =
  let sg = toarrow [ty; ty] tbool in
  f_op CI.CI_Pred.p_pred1 [ty] sg

(* Support operator, of type [ty distr -> ty -> bool]. *)
let fop_support ty =
  let sg = toarrow [tdistr ty; ty] tbool in
  f_op CI.CI_Distr.p_support [ty] sg

(* Measure operator, of type [ty distr -> (ty predicate) -> real]. *)
let fop_mu ty =
  let sg = toarrow [tdistr ty; tcpred ty] treal in
  f_op CI.CI_Distr.p_mu [ty] sg

(* Losslessness operator, of type [ty distr -> bool]. *)
let fop_lossless ty =
  let sg = toarrow [tdistr ty] tbool in
  f_op CI.CI_Distr.p_lossless [ty] sg

(* [f_support d x]: [x] is in the support of distribution [d]. *)
let f_support f1 f2 =
  let ty = f2.f_ty in
  f_app (fop_support ty) [f1; f2] tbool

(* Same as [f_support], with the arguments swapped. *)
let f_in_supp f1 f2 = f_support f2 f1

(* [f_pred1 x]: the partial application of [pred1] to [x]. *)
let f_pred1 f1 =
  let ty = f1.f_ty in
  f_app (fop_pred1 ty) [f1] (toarrow [ty] tbool)

(* [f_mu_x d x]: measure, under [d], of the predicate [pred1 x]. *)
let f_mu_x f1 f2 =
  let pred = f_pred1 f2 in
  f_app (fop_mu f2.f_ty) [f1; pred] treal

(* [proj_distr_ty env ty] head-normalizes [ty] in [env] and returns the
 * unique type argument of the resulting type constructor.  Any other
 * shape is treated as an internal error ([assert false]). *)
let proj_distr_ty env ty =
  match (EcEnv.Ty.hnorm ty env).ty_node with
  | Tconstr (_, [aty]) -> aty
  | _ -> assert false

(* [f_mu env d p]: measure of predicate [p] under distribution [d]; the
 * element type is recovered from the type of [d]. *)
let f_mu env f1 f2 =
  let ty = proj_distr_ty env f1.f_ty in
  f_app (fop_mu ty) [f1; f2] treal

(* [f_weight ty d]: measure of [predT] under [d]. *)
let f_weight ty d =
  let everything = f_predT ty in
  f_app (fop_mu ty) [d; everything] treal

(* [f_lossless ty d]: the losslessness predicate applied to [d]. *)
let f_lossless ty d =
  f_app (fop_lossless ty) [d] tbool

(* -------------------------------------------------------------------- *)
(* Losslessness of procedure [f]: a bounded Hoare statement with trivial
 * pre- and postconditions, comparison [FHeq] and bound [f_r1]. *)
let f_losslessF f =
  let pre = f_true and post = f_true in
  f_bdHoareF pre f post FHeq f_r1

(* -------------------------------------------------------------------- *)
(* [f_identity ?name ty] is the identity function over [ty], i.e. the
 * lambda [fun (x : ty) => x] whose bound variable is a fresh identifier
 * created from [name] (default ["x"]). *)
let f_identity ?(name = "x") ty =
  let x = EcIdent.create name in
  let body = f_local x ty in
  f_lambda [(x, GTty ty)] body

(* -------------------------------------------------------------------- *)
(* [f_ty_app env f args] applies [f] to [args], computing the type of
 * the application from the function type of [f]: the domain is split
 * after [List.length args] components, and the remaining components
 * together with the codomain form the resulting type. *)
let f_ty_app (env : EcEnv.env) (f : form) (args : form list) =
  let dom, codom = EcEnv.Ty.decompose_fun f.f_ty env in
  let _consumed, remaining =
    try  List.split_at (List.length args) dom
    with Failure _ -> assert false
  in
  f_app f args (toarrow remaining codom)

(* -------------------------------------------------------------------- *)
(* Destructors for formulas built from ring-like operators.  Each
 * function takes a formula headed by the corresponding operator and
 * returns its argument(s).  NOTE(review): the implementations visible
 * in this file signal a shape mismatch with [DestrError]. *)
module type DestrRing = sig
  val le  : form -> form * form   (* operands of a [<=] *)
  val lt  : form -> form * form   (* operands of a [<]  *)
  val add : form -> form * form   (* operands of an addition *)
  val opp : form -> form          (* operand of an opposite *)
  val sub : form -> form * form   (* operands of a subtraction [a + (-b)] *)
  val mul : form -> form * form   (* operands of a multiplication *)
end

(* -------------------------------------------------------------------- *)
(* Destructors for integer operators. *)
module DestrInt : DestrRing = struct
  let le  = destr_app2_eq ~name:"int_le"  CI.CI_Int.p_int_le
  let lt  = destr_app2_eq ~name:"int_lt"  CI.CI_Int.p_int_lt
  let add = destr_app2_eq ~name:"int_add" CI.CI_Int.p_int_add
  let mul = destr_app2_eq ~name:"int_mul" CI.CI_Int.p_int_mul
  let opp = destr_app1_eq ~name:"int_opp" CI.CI_Int.p_int_opp

  (* A subtraction is an addition whose right operand is an opposite;
   * any failure is reported under the name "int_sub". *)
  let sub f =
    try
      let (lhs, rhs) = add f in
      (lhs, opp rhs)
    with DestrError _ -> raise (DestrError "int_sub")
end

(* -------------------------------------------------------------------- *)
(* Destructors for real operators: the ring destructors of [DestrRing]
 * extended with inverse, division and absolute value. *)
module type DestrReal = sig
  include DestrRing

  val inv : form -> form          (* operand of an inverse *)
  val div : form -> form * form   (* operands of a division [a * inv b] *)
  val abs : form -> form          (* operand of an absolute value *)
end

(* Destructors for real operators. *)
module DestrReal : DestrReal = struct
  let le  = destr_app2_eq ~name:"real_le"  CI.CI_Real.p_real_le
  let lt  = destr_app2_eq ~name:"real_lt"  CI.CI_Real.p_real_lt
  let add = destr_app2_eq ~name:"real_add" CI.CI_Real.p_real_add
  let opp = destr_app1_eq ~name:"real_opp" CI.CI_Real.p_real_opp
  let mul = destr_app2_eq ~name:"real_mul" CI.CI_Real.p_real_mul
  let inv = destr_app1_eq ~name:"real_inv" CI.CI_Real.p_real_inv
  let abs = destr_app1_eq ~name:"real_abs" CI.CI_Real.p_real_abs

  (* A subtraction is an addition whose right operand is an opposite;
   * any failure is reported under the name "real_sub". *)
  let sub f =
    try  snd_map opp (add f)
    with DestrError _ -> raise (DestrError "real_sub")

  (* A division is a multiplication whose right operand is an inverse;
   * any failure is reported under the name "real_div" (this used to be
   * mislabelled "int_sub", a copy-paste slip). *)
  let div f =
    try  snd_map inv (mul f)
    with DestrError _ -> raise (DestrError "real_div")
end

(* -------------------------------------------------------------------- *)
(* Simplifying integer opposite: [-(-g)] becomes [g] and [-0] becomes
 * [0]; otherwise the opposite is built as is. *)
let f_int_opp_simpl f =
  match f.f_node with
  | Fapp (op, [g]) when f_equal op fop_int_opp -> g
  | _ when f_equal f_i0 f -> f_i0
  | _ -> f_int_opp f

(* -------------------------------------------------------------------- *)
(* Simplifying integer addition.  Two literals are added directly and a
 * literal zero is dropped.  Otherwise, the following rewrites are tried
 * in order: cancellation of [x + (-x)] (in both orientations), then
 * merging a literal operand into a literal found directly under an
 * addition on the other side.  When nothing applies, the plain addition
 * is returned. *)
let f_int_add_simpl =
  (* [cancel f1 f2] is [Some 0] when [f2] is the opposite of [f1]. *)
  let cancel f1 f2 =
    match DestrInt.opp f2 with
    | g -> if f_equal f1 g then Some f_i0 else None
    | exception DestrError _ -> None
  in

  (* [merge i f]: when [f] is an addition with a literal operand, add
   * the constant [i] to that literal (left operand checked first). *)
  let merge i f =
    match DestrInt.add f with
    | exception DestrError _ -> None
    | (c1, c2) -> begin
        match destr_int c1 with
        | c -> Some (f_int_add (f_int (c +^ i)) c2)
        | exception DestrError _ -> begin
            match destr_int c2 with
            | c -> Some (f_int_add c1 (f_int (c +^ i)))
            | exception DestrError _ -> None
          end
      end
  in

  let literal f =
    try Some (destr_int f) with DestrError _ -> None in

  fun f1 f2 ->
    let i1 = literal f1 in
    let i2 = literal f2 in

    match i1, i2 with
    | Some i1, Some i2 -> f_int (i1 +^ i2)

    | Some i1, _ when i1 =^ EcBigInt.zero -> f2
    | _, Some i2 when i2 =^ EcBigInt.zero -> f1

    | _, _ ->
        let attempts = [
          (fun () -> cancel f1 f2);
          (fun () -> cancel f2 f1);
          (fun () -> obind (fun i -> merge i f2) i1);
          (fun () -> obind (fun i -> merge i f1) i2);
        ] in

        match List.Exceptionless.find_map (fun k -> k ()) attempts with
        | Some f -> f
        | None   -> f_int_add f1 f2

(* -------------------------------------------------------------------- *)
(* Simplifying integer maximum: computed directly when both operands
 * are literals, built symbolically otherwise. *)
let f_int_max_simpl f1 f2 =
  match destr_int f1, destr_int f2 with
  | i1, i2 -> f_int (EcBigInt.max i1 i2)
  | exception DestrError _ -> f_int_max f1 f2

(* -------------------------------------------------------------------- *)
(* Simplifying integer subtraction, as the simplified sum of [f1] and
 * the simplified opposite of [f2]. *)
let f_int_sub_simpl f1 f2 =
  f2 |> f_int_opp_simpl |> f_int_add_simpl f1

(* -------------------------------------------------------------------- *)
(* Simplifying integer multiplication: two literals are multiplied
 * directly; otherwise [0] is absorbing and [1] is neutral. *)
let f_int_mul_simpl f1 f2 =
  match destr_int f1, destr_int f2 with
  | i1, i2 -> f_int (i1 *^ i2)
  | exception DestrError _ ->
      let is_zero = f_equal f_i0 and is_one = f_equal f_i1 in

      if      is_zero f1 || is_zero f2 then f_i0
      else if is_one f1 then f2
      else if is_one f2 then f1
      else f_int_mul f1 f2

(* -------------------------------------------------------------------- *)
(* Simplifying Euclidean division, returning the pair
 * (quotient, remainder).  Division by the literal [0] yields [(0, f1)];
 * two literals are divided with [BI.ediv]; otherwise a zero dividend
 * and the divisors [1] and [-1] are handled specially. *)
let f_int_edivz_simpl f1 f2 =
  let pair q r = f_tuple [q; r] in

  if f_equal f2 f_i0 then pair f_i0 f1
  else
    match BI.ediv (destr_int f1) (destr_int f2) with
    | (q, r) -> pair (f_int q) (f_int r)
    | exception DestrError _ ->
        if      f_equal f1 f_i0  then pair f_i0 f_i0
        else if f_equal f2 f_i1  then pair f1 f_i0
        else if f_equal f2 f_im1 then pair (f_int_opp_simpl f1) f_i0
        else f_int_edivz f1 f2

(* -------------------------------------------------------------------- *)
(* [destr_rdivint f] decomposes [f] into a pair (numerator, denominator)
 * of integers.  Recognized shapes are [n * inv d], [inv d] and [n], where
 * [n] and [d] are real integer constants, possibly under any number of
 * opposites (each one flips the sign of the numerator).  Any other
 * formula fails through [destr_error "rdivint"]. *)
let destr_rdivint =
  (* Like [destr_rint], but failing under the name of this destructor. *)
  let rint f =
    try  destr_rint f
    with DestrError _ -> destr_error "rdivint"
  in

  let rec aux isneg f =
    let signed (n, d) =
      if isneg then (BI.neg n, d) else (n, d) in

    match f.f_node with
    | Fapp (op, [f1; { f_node = Fapp (subop, [f2]) }])
        when f_equal    op fop_real_mul
          && f_equal subop fop_real_inv ->
        signed (rint f1, rint f2)

    | Fapp (op, [g]) when f_equal op fop_real_inv ->
        signed (BI.one, rint g)

    | Fapp (op, [g]) when f_equal op fop_real_opp ->
        aux (not isneg) g

    | _ ->
        signed (rint f, BI.one)

  in fun f -> aux false f

(* [real_split f] views [f] as a fraction (numerator, denominator):
 * [a * inv d] and [inv d * a] give [(a, d)], [inv d] gives [(1, d)], and
 * anything else gives [(f, 1)].  For a product, the right operand is
 * inspected before the left one. *)
let real_split f =
  (* [Some d] when the formula is exactly [inv d]. *)
  let as_inv g =
    match g.f_node with
    | Fapp (op, [d]) when f_equal op fop_real_inv -> Some d
    | _ -> None
  in

  match f.f_node with
  | Fapp (op, [a; b]) when f_equal op fop_real_mul -> begin
      match as_inv b with
      | Some d -> (a, d)
      | None   ->
          match as_inv a with
          | Some d -> (b, d)
          | None   -> (f, f_r1)
    end

  | Fapp (op, [g]) when f_equal op fop_real_inv ->
      (f_r1, g)

  | _ ->
      (f, f_r1)

(* [real_is_zero f] holds when [f] is a real constant equal to 0. *)
let real_is_zero f =
  match destr_rint f with
  | n -> BI.equal BI.zero n
  | exception DestrError _ -> false

(* [real_is_one f] holds when [f] is a real constant equal to 1. *)
let real_is_one f =
  match destr_rint f with
  | n -> BI.equal BI.one n
  | exception DestrError _ -> false

(* [norm_real_int_div n1 n2] is the normalized real formula for the
 * fraction [n1 / n2]: [0] when either integer is zero, otherwise the
 * fraction reduced by the gcd, with the sign carried by the numerator
 * and a denominator of [1] omitted. *)
let norm_real_int_div n1 n2 =
  let s1 = BI.sign n1 and s2 = BI.sign n2 in

  if s1 = 0 || s2 = 0 then f_r0
  else begin
    let a = BI.abs n1 and b = BI.abs n2 in
    let g = BI.gcd a b in
    let a, b =
      if BI.equal g BI.one then (a, b) else (a /^ g, b /^ g) in
    let num = if (s1 * s2) < 0 then BI.neg a else a in

    if BI.equal b BI.one then f_rint num
    else f_real_div (f_rint num) (f_rint b)
  end

(* Simplifying real addition.  Two integer constants are added directly
 * and a constant zero is dropped.  Otherwise, the following rewrites are
 * tried in order: adding two integer fractions into a normalized one,
 * cancellation of [x + (-x)] (in both orientations), then merging a
 * constant operand into a constant found directly under an addition on
 * the other side.  When nothing applies, the plain addition is
 * returned. *)
let f_real_add_simpl =
  (* [cancel f1 f2] is [Some 0] when [f2] is the opposite of [f1]. *)
  let cancel f1 f2 =
    match DestrReal.opp f2 with
    | g -> if f_equal f1 g then Some f_r0 else None
    | exception DestrError _ -> None
  in

  (* [merge i f]: when [f] is an addition with a constant operand, add
   * the constant [i] to that operand (left operand checked first). *)
  let merge i f =
    match DestrReal.add f with
    | exception DestrError _ -> None
    | (c1, c2) -> begin
        match destr_rint c1 with
        | c -> Some (f_real_add (f_rint (c +^ i)) c2)
        | exception DestrError _ -> begin
            match destr_rint c2 with
            | c -> Some (f_real_add c1 (f_rint (c +^ i)))
            | exception DestrError _ -> None
          end
      end
  in

  (* Sum of two integer fractions, put back in normal form. *)
  let add_fractions f1 f2 =
    match destr_rdivint f1, destr_rdivint f2 with
    | (n1, d1), (n2, d2) ->
        Some (norm_real_int_div ((n1 *^ d2) +^ (n2 *^ d1)) (d1 *^ d2))
    | exception DestrError _ -> None
  in

  let constant f =
    try Some (destr_rint f) with DestrError _ -> None in

  fun f1 f2 ->
    let r1 = constant f1 in
    let r2 = constant f2 in

    match r1, r2 with
    | Some i1, Some i2 -> f_rint (i1 +^ i2)

    | Some i1, _ when i1 =^ EcBigInt.zero -> f2
    | _, Some i2 when i2 =^ EcBigInt.zero -> f1

    | _, _ ->
        let attempts = [
          (fun () -> add_fractions f1 f2);
          (fun () -> cancel f1 f2);
          (fun () -> cancel f2 f1);
          (fun () -> obind (fun i -> merge i f2) r1);
          (fun () -> obind (fun i -> merge i f1) r2);
        ] in

        match List.Exceptionless.find_map (fun k -> k ()) attempts with
        | Some f -> f
        | None   -> f_real_add f1 f2

(* Simplifying real opposite: [-(-g)] becomes [g] and the opposite of a
 * zero constant becomes [0]. *)
let f_real_opp_simpl f =
  match f.f_node with
  | Fapp (op, [g]) when f_equal op fop_real_opp -> g
  | _ when real_is_zero f -> f_r0
  | _ -> f_real_opp f

(* Simplifying real subtraction, as the simplified sum of [f1] and the
 * simplified opposite of [f2]. *)
let f_real_sub_simpl f1 f2 =
  f2 |> f_real_opp_simpl |> f_real_add_simpl f1

(* Simplifying real multiplication: both operands are viewed as
 * fractions (see [real_split]); numerators and denominators are
 * multiplied separately and the resulting fraction is simplified. *)
let rec f_real_mul_simpl f1 f2 =
  let (num1, den1) = real_split f1 in
  let (num2, den2) = real_split f2 in

  f_real_div_simpl_r
    (f_real_mul_simpl_r num1 num2)
    (f_real_mul_simpl_r den1 den2)

(* Simplifying real division: same as above, with the second fraction
 * flipped. *)
and f_real_div_simpl f1 f2 =
  let (num1, den1) = real_split f1 in
  let (num2, den2) = real_split f2 in

  f_real_div_simpl_r
    (f_real_mul_simpl_r num1 den2)
    (f_real_mul_simpl_r den1 num2)

(* Raw product: [0] is absorbing, [1] is neutral, and two integer
 * constants are multiplied directly. *)
and f_real_mul_simpl_r f1 f2 =
  if      real_is_zero f1 || real_is_zero f2 then f_r0
  else if real_is_one f1 then f2
  else if real_is_one f2 then f1
  else
    match destr_rint f1, destr_rint f2 with
    | n1, n2 -> f_rint (n1 *^ n2)
    | exception DestrError _ -> f_real_mul f1 f2

(* Raw quotient: when both sides are integer constants, they are first
 * reduced by their gcd (a [Division_by_zero] raised while doing so
 * yields the fraction 0/1); the result is then the raw product of the
 * numerator with the simplified inverse of the denominator. *)
and f_real_div_simpl_r f1 f2 =
  let (num, den) =
    match
      let n1 = destr_rint f1 in
      let n2 = destr_rint f2 in
      let g  = BI.gcd n1 n2 in
      (f_rint (BI.div n1 g), f_rint (BI.div n2 g))
    with
    | reduced -> reduced
    | exception DestrError _     -> (f1, f2)
    | exception Division_by_zero -> (f_r0, f_r1)

  in f_real_mul_simpl_r num (f_real_inv_simpl den)

(* Simplifying inverse: [inv (inv g)] becomes [g], and the constants [0]
 * and [1] are their own inverse. *)
and f_real_inv_simpl f =
  match f.f_node with
  | Fapp (op, [g]) when f_equal op fop_real_inv -> g

  | _ ->
      let of_constant n =
        if      BI.equal n BI.zero then f_r0
        else if BI.equal n BI.one  then f_r1
        else destr_error "destr_rint/inv"
      in

      try  of_constant (destr_rint f)
      with DestrError _ -> f_app fop_real_inv [f] treal

(* -------------------------------------------------------------------- *)
(* [f_let_simpl lp f1 f2] builds [let lp = f1 in f2], simplifying it
 * when possible:
 *  - bindings that do not occur in the free variables of [f2] are
 *    dropped;
 *  - a bound variable is substituted away when its recorded count in
 *    [f_fv f2] is [1] or when [can_subst] accepts the bound formula;
 *  - a tuple pattern over a non-tuple formula is first bound to a fresh
 *    variable ["tpl"], whose projections are then matched against the
 *    pattern. *)
let rec f_let_simpl lp f1 f2 =
  match lp with
  | LSymbol (id, _) -> begin
      match Mid.find_opt id (f_fv f2) with
      | None -> f2
      | Some n when n = 1 || can_subst f1 ->
          Fsubst.f_subst_local id f1 f2
      | Some _ ->
          f_let lp f1 f2
    end

  | LTuple ids -> begin
      match f1.f_node with
      | Ftuple fs ->
          (* Split the components into those kept as [let]s and those
           * substituted away. *)
          let classify (kept, inlined) (id, ty) def =
            match Mid.find_opt id (f_fv f2) with
            | None -> (kept, inlined)
            | Some n when n = 1 || can_subst def ->
                (kept, Mid.add id def inlined)
            | Some _ ->
                (((id, ty), def) :: kept, inlined)
          in

          let (kept, inlined) =
            List.fold_left2 classify ([], Mid.empty) ids fs in

          List.fold_left
            (fun body (bd, def) -> f_let (LSymbol bd) def body)
            (Fsubst.subst_locals inlined f2) kept

      | _ ->
          let x   = EcIdent.create "tpl" in
          let ty  = ttuple (List.map snd ids) in
          let lpx = LSymbol (x, ty) in
          let fx  = f_local x ty in
          let tu  =
            f_tuple (List.mapi (fun i (_, ty') -> f_proj fx i ty') ids) in
          f_let_simpl lpx f1 (f_let_simpl lp tu f2)
    end

  | LRecord (_, ids) ->
      (* A field is unused when it is anonymous or absent from [f2]. *)
      let unused = function
        | (None   , _) -> true
        | (Some id, _) -> not (Mid.mem id (f_fv f2))
      in

      if List.for_all unused ids then f2 else f_let lp f1 f2

(* Nested simplifying lets: [f_lets_simpl bds f] wraps [f] in the
 * bindings [bds], the first binding being outermost. *)
let f_lets_simpl bds f =
  (* FIXME : optimize this *)
  List.fold_right (fun (lp, def) body -> f_let_simpl lp def body) bds f

(* [f_app_simpl f args ty] builds the application and beta-reduces it
 * at head. *)
let rec f_app_simpl f args ty =
  f_betared (f_app f args ty)

(* [f_betared f] contracts a head beta-redex: when [f] is a lambda
 * applied to arguments, as many binders as possible are instantiated,
 * leaving either a lambda over the remaining binders or an application
 * to the remaining arguments.  Other formulas are returned unchanged. *)
and f_betared f =
  match f.f_node with
  | Fapp ({ f_node = Fquant (Llambda, bds, body)}, args) ->
      (* Post-processing hook handed to the substitution: results that
       * differ from a non-redex original are beta-reduced again. *)
      let tx fo fp =
        if f_equal fo fp || can_betared fo then fp else f_betared fp in

      let (bound, unbound), (used, extra) = List.prefix2 bds args in
      let subst =
        List.fold_left2
          (fun subst (x, _) arg -> Fsubst.f_bind_local subst x arg)
          Fsubst.f_subst_id bound used
      in
      let lam = f_quant Llambda unbound (Fsubst.f_subst ~tx subst body) in
      f_app lam extra f.f_ty

  | _ -> f

(* [can_betared f] holds when [f] is a lambda applied to arguments. *)
and can_betared f =
  match f.f_node with
  | Fapp ({ f_node = Fquant (Llambda, _, _)}, _) -> true
  | _ -> false

(* Universal quantification over [bs], skipping every binder that does
 * not occur free in the (already quantified) body. *)
let f_forall_simpl bs f =
  let quantify (b, ty) body =
    if Mid.mem b (f_fv body) then f_forall [b, ty] body else body in
  List.fold_right quantify bs f

(* Existential quantification over [bs], skipping every binder that
 * does not occur free in the (already quantified) body. *)
let f_exists_simpl bs f =
  let quantify (b, ty) body =
    if Mid.mem b (f_fv body) then f_exists [b, ty] body else body in
  List.fold_right quantify bs f

(* Simplifying negation: a negation is stripped, and the constants
 * [true] / [false] are swapped. *)
let f_not_simpl f =
  if is_not f then destr_not f
  else
    match is_true f, is_false f with
    | true , _     -> f_false
    | false, true  -> f_true
    | false, false -> f_not f

(* Simplifying (symmetric) conjunction: [true] is neutral and [false]
 * is absorbing; the left operand is inspected first. *)
let f_and_simpl f1 f2 =
  let truth f =
    if is_true f then Some true
    else if is_false f then Some false
    else None
  in

  match truth f1 with
  | Some true  -> f2
  | Some false -> f_false
  | None -> begin
      match truth f2 with
      | Some true  -> f1
      | Some false -> f_false
      | None       -> f_and f1 f2
    end

(* [f_ands_simpl fs f]: simplifying conjunction of all of [fs], with
 * [f] as the rightmost conjunct. *)
let f_ands_simpl fs f =
  List.fold_right f_and_simpl fs f

(* Simplifying conjunction of a list; the empty list gives [true]. *)
let f_ands0_simpl fs =
  match List.rev fs with
  | []  -> f_true
  | [x] -> x
  | last :: rev_init -> f_ands_simpl (List.rev rev_init) last

(* Simplifying asymmetric conjunction: [true] is neutral and [false] is
 * absorbing; the left operand is inspected first. *)
let f_anda_simpl f1 f2 =
  let truth f =
    if is_true f then Some true
    else if is_false f then Some false
    else None
  in

  match truth f1 with
  | Some true  -> f2
  | Some false -> f_false
  | None -> begin
      match truth f2 with
      | Some true  -> f1
      | Some false -> f_false
      | None       -> f_anda f1 f2
    end

(* [f_andas_simpl fs f]: simplifying asymmetric conjunction of all of
 * [fs], with [f] as the rightmost conjunct. *)
let f_andas_simpl fs f =
  List.fold_right f_anda_simpl fs f

(* Simplifying (symmetric) disjunction: [true] is absorbing and [false]
 * is neutral; the left operand is inspected first. *)
let f_or_simpl f1 f2 =
  let truth f =
    if is_true f then Some true
    else if is_false f then Some false
    else None
  in

  match truth f1 with
  | Some true  -> f_true
  | Some false -> f2
  | None -> begin
      match truth f2 with
      | Some true  -> f_true
      | Some false -> f1
      | None       -> f_or f1 f2
    end

(* Simplifying asymmetric disjunction, with the same rules as above. *)
let f_ora_simpl f1 f2 =
  let truth f =
    if is_true f then Some true
    else if is_false f then Some false
    else None
  in

  match truth f1 with
  | Some true  -> f_true
  | Some false -> f2
  | None -> begin
      match truth f2 with
      | Some true  -> f_true
      | Some false -> f1
      | None       -> f_ora f1 f2
    end

(* Simplifying implication:
 *   true  => f2    becomes  f2
 *   false => f2    becomes  true
 *   f1    => true  becomes  true
 *   f1    => false becomes  (simplified) not f1
 *   f     => f     becomes  true
 * FIXME : simplify x = f1 => f2 into x = f1 => f2{x<-f1} *)
let f_imp_simpl f1 f2 =
  let truth f =
    if is_true f then Some true
    else if is_false f then Some false
    else None
  in

  match truth f1, truth f2 with
  | Some true , _          -> f2
  | Some false, _          -> f_true
  | None      , Some true  -> f_true
  | None      , Some false -> f_not_simpl f1
  | None      , None       ->
      if f_equal f1 f2 then f_true else f_imp f1 f2

(* [bool_val f] is the boolean denoted by [f] when [f] is one of the
 * constants [true] / [false], and [None] otherwise. *)
let bool_val f =
  match is_true f, is_false f with
  | true , _     -> Some true
  | false, true  -> Some false
  | false, false -> None

(* Simplifying projection: the [i]-th projection of an explicit tuple
 * is its [i]-th component (0-based, via [List.nth]). *)
let f_proj_simpl f i ty =
  match f.f_node with
  | Ftuple components -> List.nth components i
  | _ -> f_proj f i ty

(* Simplifying conditional [if f1 then f2 else f3]: identical branches
 * collapse, a constant condition selects its branch, and a constant
 * branch turns the conditional into an implication or an asymmetric
 * conjunction (the [then] branch is inspected before the [else] one). *)
let f_if_simpl f1 f2 f3 =
  if f_equal f2 f3 then f2
  else
    match bool_val f1 with
    | Some true  -> f2
    | Some false -> f3
    | None -> begin
        match bool_val f2, bool_val f3 with
        | Some true , _          -> f_imp_simpl  (f_not_simpl f1) f3
        | Some false, _          -> f_anda_simpl (f_not_simpl f1) f3
        | None      , Some true  -> f_imp_simpl  f1 f2
        | None      , Some false -> f_anda_simpl f1 f2
        | None      , None       -> f_if f1 f2 f3
      end

(* [f_imps_simpl hyps concl]: simplifying chain of implications from
 * the hypotheses [hyps] (first one outermost) to [concl]. *)
let f_imps_simpl hyps concl =
  List.fold_right f_imp_simpl hyps concl

(* Simplifying equivalence: identical sides give [true], a constant
 * side is eliminated, and negations on both sides are stripped
 * (recursively). *)
let rec f_iff_simpl f1 f2 =
  (* [Some g] when the formula is the negation of [g]. *)
  let negated f =
    match f.f_node with
    | Fapp ({f_node = Fop (op, [])}, [g])
        when EcPath.p_equal op CI.CI_Bool.p_not -> Some g
    | _ -> None
  in

       if f_equal  f1 f2 then f_true
  else if is_true  f1    then f2
  else if is_false f1    then f_not_simpl f2
  else if is_true  f2    then f1
  else if is_false f2    then f_not_simpl f1
  else
    match negated f1, negated f2 with
    | Some g1, Some g2 -> f_iff_simpl g1 g2
    | _ -> f_iff f1 f2

(* [cost_mk_cmp fullcmp xcmp c1 c2] lifts a comparison to cost records:
 * the [c_full] flags are compared with [fullcmp], the [c_self]
 * components with [xcmp], and the per-procedure entries of [c_calls]
 * with [xcmp] again (entries missing on one side are completed by
 * [oget_c_bnd]).  The result is the simplified conjunction of all these
 * comparisons. *)
let cost_mk_cmp
    (fullcmp : bool -> bool -> bool)
    (xcmp    : form -> form -> form)
    (c1      : cost)
    (c2      : cost) : form
  =
  let full_cmp = f_bool (fullcmp c1.c_full c2.c_full) in
  let self_cmp = xcmp c1.c_self c2.c_self in
  let rev_calls_cmp =
    EcPath.Mx.fold2_union (fun _ b1 b2 acc ->
        let b1 = oget_c_bnd b1 c1.c_full
        and b2 = oget_c_bnd b2 c2.c_full in
        xcmp b1 b2 :: acc
      ) c1.c_calls c2.c_calls []
  in
  f_ands0_simpl (full_cmp :: self_cmp :: (List.rev rev_calls_cmp))

(* Simplifying equality.  Identical sides give [true]; syntactically
 * distinct integer literals (possibly under an opposite or under the
 * int-to-real embedding) and the pair [true] / [false] are decided;
 * tuples of the same length are compared component-wise; cost records
 * and module costs are compared field by field (module costs only when
 * both sides define the same procedures). *)
let rec f_eq_simpl f1 f2 =
  if f_equal f1 f2 then f_true
  else
    let is_opp op   = f_equal op fop_int_opp in
    let is_embed op = f_equal op f_op_real_of_int in

    match f1.f_node, f2.f_node with
    | Fint _ , Fint _ -> f_false

    | Fapp (op, [{f_node = Fint i1}]), Fint i2 when is_opp op ->
        f_bool (EcBigInt.equal (EcBigInt.neg i1) i2)

    | Fint i1, Fapp (op, [{f_node = Fint i2}]) when is_opp op ->
        f_bool (EcBigInt.equal i1 (EcBigInt.neg i2))

    | Fapp (op1, [{f_node = Fint _}]), Fapp (op2, [{f_node = Fint _}])
        when is_embed op1 && is_embed op2
      -> f_false

    | Fop (op1, []), Fop (op2, []) when
           (EcPath.p_equal op1 CI.CI_Bool.p_true  &&
            EcPath.p_equal op2 CI.CI_Bool.p_false  )
        || (EcPath.p_equal op2 CI.CI_Bool.p_true  &&
            EcPath.p_equal op1 CI.CI_Bool.p_false  )
      -> f_false

    | Ftuple fs1, Ftuple fs2 when List.length fs1 = List.length fs2 ->
        f_ands_simpl (List.map2 f_eq_simpl fs1 fs2) f_true

    | Fcost c1, Fcost c2 ->
        cost_mk_cmp (=) f_eq_simpl c1 c2

    | Fmodcost mc1, Fmodcost mc2 ->
        (* Both sides must bind exactly the same procedure names. *)
        let same_procs =
          Msym.fold2_union (fun _ pc1 pc2 ok ->
              match pc1, pc2 with
              | Some _, Some _ -> ok
              | _ -> false
            ) mc1 mc2 true
        in

        if same_procs then
          Msym.fold2_union (fun _ pc1 pc2 cond ->
              let pc1, pc2 = oget pc1, oget pc2 in
              f_and_simpl (cost_mk_cmp (=) f_eq_simpl pc1 pc2) cond
            ) mc1 mc2 f_true
        else f_eq f1 f2

    | _ -> f_eq f1 f2

(* -------------------------------------------------------------------- *)
(* Classification of the core-library operators that this module knows
 * about; the association from operator paths to kinds is given by the
 * [operators] table. *)
type op_kind = [
  (* boolean constants and connectives *)
  | `True
  | `False
  | `Not
  | `And   of [`Asym | `Sym]
  | `Or    of [`Asym | `Sym]
  | `Imp
  | `Iff
  | `Eq
  (* integer and real comparisons, integer arithmetic *)
  | `Int_le
  | `Int_lt
  | `Real_le
  | `Real_lt
  | `Int_add
  | `Int_mul
  | `Int_max
  | `Int_pow
  | `Int_opp
  | `Int_edivz

  (* cost operators *)
  | `Cost_add
  | `Cost_opp
  | `Cost_scale
  | `Cost_xscale
  | `Cost_le
  | `Cost_lt
  | `Cost_big
  | `Cost_is_int

  (* real arithmetic and map operators *)
  | `Real_add
  | `Real_opp
  | `Real_mul
  | `Real_inv
  | `Map_get
  | `Map_set
  | `Map_cst
]

(* Hash table associating each known core-library operator path with
 * its kind.  It is filled once, when this definition is evaluated. *)
let operators =
  let entries =
    [CI.CI_Bool.p_true    , `True     ;
     CI.CI_Bool.p_false   , `False    ;
     CI.CI_Bool.p_not     , `Not      ;
     CI.CI_Bool.p_anda    , `And `Asym;
     CI.CI_Bool.p_and     , `And `Sym ;
     CI.CI_Bool.p_ora     , `Or  `Asym;
     CI.CI_Bool.p_or      , `Or  `Sym ;
     CI.CI_Bool.p_imp     , `Imp      ;
     CI.CI_Bool.p_iff     , `Iff      ;
     CI.CI_Bool.p_eq      , `Eq       ;

     CI.CI_Int.p_int_le   , `Int_le   ;
     CI.CI_Int.p_int_lt   , `Int_lt   ;
     CI.CI_Int.p_int_add  , `Int_add  ;
     CI.CI_Int.p_int_opp  , `Int_opp  ;
     CI.CI_Int.p_int_mul  , `Int_mul  ;
     CI.CI_Int.p_int_max  , `Int_max  ;
     CI.CI_Int.p_int_pow  , `Int_pow  ;
     CI.CI_Int.p_int_edivz, `Int_edivz;

     CI.CI_Cost.p_cost_le    , `Cost_le    ;
     CI.CI_Cost.p_cost_lt    , `Cost_lt    ;
     CI.CI_Cost.p_cost_add   , `Cost_add   ;
     CI.CI_Cost.p_cost_opp   , `Cost_opp   ;
     CI.CI_Cost.p_cost_scale , `Cost_scale ;
     CI.CI_Cost.p_cost_xscale, `Cost_xscale;
     CI.CI_Cost.p_cost_is_int, `Cost_is_int;
     CI.CI_Xint.p_bigcost    , `Cost_big   ;

     CI.CI_Real.p_real_add, `Real_add ;
     CI.CI_Real.p_real_opp, `Real_opp ;
     CI.CI_Real.p_real_mul, `Real_mul ;
     CI.CI_Real.p_real_inv, `Real_inv ;
     CI.CI_Real.p_real_le , `Real_le  ;
     CI.CI_Real.p_real_lt , `Real_lt  ;

     CI.CI_Map.p_get      , `Map_get  ;
     CI.CI_Map.p_set      , `Map_set  ;
     CI.CI_Map.p_cst      , `Map_cst  ;
    ]
  in

  let table = EcPath.Hp.create 11 in
  let register (path, kind) = EcPath.Hp.add table path kind in
  List.iter register entries;
  table

(* -------------------------------------------------------------------- *)
(* [op_kind p] is the kind recorded for the operator path [p] in the
 * [operators] table, if any. *)
let op_kind (p : EcPath.path) : op_kind option =
  let table = operators in
  EcPath.Hp.find_opt table p

(* -------------------------------------------------------------------- *)
(* [is_logical_op op] holds when [op] is one of the listed connectives,
 * comparisons, integer/real arithmetic operators or map operators.
 * Unknown operators, the boolean constants, [`Int_max], [`Int_pow] and
 * all cost operators are not considered logical. *)
let is_logical_op op =
  match op_kind op with
  | None -> false
  | Some kind -> begin
      match kind with
      | `Not | `And _ | `Or _ | `Imp | `Iff | `Eq
      | `Int_le   | `Int_lt   | `Real_le  | `Real_lt
      | `Int_add  | `Int_opp  | `Int_mul  | `Int_edivz
      | `Real_add | `Real_opp | `Real_mul | `Real_inv
      | `Map_get  | `Map_set  | `Map_cst
        -> true

      | _ -> false
    end

(* -------------------------------------------------------------------- *)
(* Structured ("shallow") view of a formula, as computed by
 * [sform_of_form]: the head construct is exposed as a constructor, with
 * the boolean connectives and equality singled out from generic
 * operator applications.  Sub-formulas are left as plain [form]s. *)
type sform =
  | SFint   of BI.zint
  | SFlocal of EcIdent.t
  | SFpvar  of EcTypes.prog_var * memory
  | SFglob  of EcPath.mpath * memory

  | SFif    of form * form * form
  | SFmatch of form * form list * ty
  | SFlet   of lpattern * form * form
  | SFtuple of form list
  | SFproj  of form * int
  (* first binder of a quantifier, with the rest of the quantified
   * formula computed on demand *)
  | SFquant of quantif * (EcIdent.t * gty) * form Lazy.t
  | SFtrue
  | SFfalse
  | SFnot   of form
  | SFand   of [`Asym | `Sym] * (form * form)
  | SFor    of [`Asym | `Sym] * (form * form)
  | SFimp   of form * form
  | SFiff   of form * form
  | SFeq    of form * form
  (* any other operator, with its type arguments and its arguments *)
  | SFop    of (EcPath.path * ty list) * (form list)

  | SFcost of cost
  | SFmodcost of mod_cost

  | SFhoareF  of sHoareF
  | SFhoareS  of sHoareS
  | SFcHoareF  of cHoareF
  | SFcHoareS  of cHoareS
  | SFbdHoareF of bdHoareF
  | SFbdHoareS of bdHoareS
  | SFequivF   of equivF
  | SFequivS   of equivS
  | SFpr       of pr

  (* formulas with no dedicated view *)
  | SFother of form

(* [sform_of_op (op, ty) args] is the structured view of the operator
 * [op] (with type arguments [ty]) applied to [args]: boolean constants,
 * connectives and equality applied to the expected number of arguments
 * get their dedicated constructor, everything else is a generic
 * [SFop]. *)
let sform_of_op (op, ty) args =
  let generic () = SFop ((op, ty), args) in

  match args with
  | [] -> begin
      match op_kind op with
      | Some `True  -> SFtrue
      | Some `False -> SFfalse
      | _ -> generic ()
    end

  | [f] -> begin
      match op_kind op with
      | Some `Not -> SFnot f
      | _ -> generic ()
    end

  | [f1; f2] -> begin
      match op_kind op with
      | Some (`And b) -> SFand (b, (f1, f2))
      | Some (`Or  b) -> SFor  (b, (f1, f2))
      | Some `Imp     -> SFimp (f1, f2)
      | Some `Iff     -> SFiff (f1, f2)
      | Some `Eq      -> SFeq  (f1, f2)
      | _ -> generic ()
    end

  | _ -> generic ()

(* [sform_of_form fp] is the structured view of the head of [fp].
 * Quantifiers are peeled one binder at a time (a quantifier without
 * binders is skipped), and operators, applied or not, go through
 * [sform_of_op]. *)
let rec sform_of_form fp =
  match fp.f_node with
  (* operators, applied or not *)
  | Fop (op, ty) ->
      sform_of_op (op, ty) []

  | Fapp ({ f_node = Fop (op, ty) }, args) ->
      sform_of_op (op, ty) args

  (* quantifiers *)
  | Fquant (_, []     , f) -> sform_of_form f
  | Fquant (q, [b]    , f) -> SFquant (q, b, lazy f)
  | Fquant (q, b :: bs, f) -> SFquant (q, b, lazy (f_quant q bs f))

  (* atoms *)
  | Fint   i      -> SFint   i
  | Flocal x      -> SFlocal x
  | Fpvar (x, me) -> SFpvar  (x, me)
  | Fglob (m, me) -> SFglob  (m, me)

  (* term constructs *)
  | Fif    (c, f1, f2)  -> SFif    (c, f1, f2)
  | Fmatch (b, fs, ty)  -> SFmatch (b, fs, ty)
  | Flet   (lv, f1, f2) -> SFlet   (lv, f1, f2)
  | Ftuple fs           -> SFtuple fs
  | Fproj  (f, i)       -> SFproj  (f, i)

  (* costs *)
  | Fcost    c  -> SFcost    c
  | Fmodcost mc -> SFmodcost mc

  (* program logics *)
  | FhoareF   hf -> SFhoareF   hf
  | FhoareS   hs -> SFhoareS   hs
  | FcHoareF  hf -> SFcHoareF  hf
  | FcHoareS  hs -> SFcHoareS  hs
  | FbdHoareF hf -> SFbdHoareF hf
  | FbdHoareS hs -> SFbdHoareS hs
  | FequivF   ef -> SFequivF   ef
  | FequivS   es -> SFequivS   es
  | Fpr       pr -> SFpr       pr

  | _ -> SFother fp


(* -------------------------------------------------------------------- *)
(* [int_of_form f] evaluates [f] to an integer when it is a closed
 * arithmetic expression built from integer literals with opposite,
 * addition and multiplication; it is [None] otherwise. *)
let rec int_of_form f =
  (* Evaluates both operands and combines them with [op]. *)
  let binary op a1 a2 =
    match int_of_form a1, int_of_form a2 with
    | Some x1, Some x2 -> Some (op x1 x2)
    | _ -> None
  in

  match sform_of_form f with
  | SFint x ->
      Some x

  | SFop ((op, []), [a]) when op_kind op = Some `Int_opp ->
      omap BI.neg (int_of_form a)

  | SFop ((op, []), [a1; a2]) -> begin
      match op_kind op with
      | Some `Int_add -> binary BI.add a1 a2
      | Some `Int_mul -> binary BI.mul a1 a2
      | _ -> None
    end

  | _ -> None

(* [real_of_form f] is the integer value of [f] when [f] is the
 * int-to-real embedding of an integer expression that [int_of_form]
 * can evaluate; it is [None] otherwise. *)
let real_of_form f =
  match sform_of_form f with
  | SFop ((op, []), [a])
      when EcPath.p_equal op CI.CI_Real.p_real_of_int -> int_of_form a
  | _ -> None

(* [decompose_N x], for [x] of type [txint]: when [x] is the operator
 * [p_N] applied to a single argument, returns that argument. *)
let decompose_N x =
  match destr_app x with
  | ({ f_node = Fop (p, _) }, [arg]) ->
      if EcPath.p_equal p EcCoreLib.CI_Xint.p_N then Some arg else None
  | _ -> None


(* -------------------------------------------------------------------- *)
(* Simplifying [f1 <= f2] over integers: identical sides give [true],
 * and constant sides (see [int_of_form]) are compared. *)
let f_int_le_simpl f1 f2 =
  if f_equal f1 f2 then f_true
  else
    match int_of_form f1, int_of_form f2 with
    | Some x1, Some x2 -> f_bool (BI.compare x1 x2 <= 0)
    | _ -> f_int_le f1 f2

(* Simplifying [f1 < f2] over integers: identical sides give [false],
 * and constant sides (see [int_of_form]) are compared. *)
let f_int_lt_simpl f1 f2 =
  if f_equal f1 f2 then f_false
  else
    match int_of_form f1, int_of_form f2 with
    | Some x1, Some x2 -> f_bool (BI.compare x1 x2 < 0)
    | _ -> f_int_lt f1 f2

(* Simplifying [f1 <= f2] over reals: identical sides give [true], and
 * constant sides (see [real_of_form]) are compared. *)
let f_real_le_simpl f1 f2 =
  if f_equal f1 f2 then f_true
  else
    match real_of_form f1, real_of_form f2 with
    | Some x1, Some x2 -> f_bool (BI.compare x1 x2 <= 0)
    | _ -> f_real_le f1 f2

(* Simplifying [f1 < f2] over reals: identical sides give [false], and
 * constant sides (see [real_of_form]) are compared. *)
let f_real_lt_simpl f1 f2 =
  if f_equal f1 f2 then f_false
  else
    match real_of_form f1, real_of_form f2 with
    | Some x1, Some x2 -> f_bool (BI.compare x1 x2 < 0)
    | _ -> f_real_lt f1 f2

(* -------------------------------------------------------------------- *)
(* Simplifying [c1 <= c2] over [txint]: when both sides are of the form
 * [N n], the comparison is done on the underlying integers. *)
let f_xle_simpl (c1 : form) (c2 : form) : form =
  match opair decompose_N c1 c2 with
  | Some (n1, n2) -> f_int_le_simpl n1 n2
  | None          -> f_xle c1 c2

(* Simplifying [c1 < c2] over [txint], in the same way. *)
let f_xlt_simpl (c1 : form) (c2 : form) : form =
  match opair decompose_N c1 c2 with
  | Some (n1, n2) -> f_int_lt_simpl n1 n2
  | None          -> f_xlt c1 c2

(* Simplifying addition over [txint]: when both sides are of the form
 * [N n], the operation is done on the underlying integers and the
 * result is wrapped back with [f_N]. *)
let f_xadd_simpl (c1 : form) (c2 : form) : form =
  match opair decompose_N c1 c2 with
  | Some (n1, n2) -> f_N (f_int_add_simpl n1 n2)
  | None          -> f_xadd c1 c2

(* Simplifying multiplication over [txint], in the same way. *)
let f_xmul_simpl (c1 : form) (c2 : form) : form =
  match opair decompose_N c1 c2 with
  | Some (n1, n2) -> f_N (f_int_mul_simpl n1 n2)
  | None          -> f_xmul c1 c2

(* Simplifying maximum over [txint], in the same way. *)
let f_xmax_simpl (c1 : form) (c2 : form) : form =
  match opair decompose_N c1 c2 with
  | Some (n1, n2) -> f_N (f_int_max_simpl n1 n2)
  | None          -> f_xmax c1 c2

(* Simplifying opposite over [txint]: the opposite of [N n] is [N] of
 * the simplified opposite of [n]. *)
let f_xopp_simpl (c : form) : form =
  match decompose_N c with
  | None   -> f_xopp c
  | Some n -> f_N (f_int_opp_simpl n)

(* [is_inf] test, decided positively when [c] is syntactically
 * infinite (as recognized by [is_inf]). *)
let f_is_inf_simpl (c : form) : form =
  match is_inf c with
  | true  -> f_true
  | false -> f_is_inf c

(* [is_int] test, decided negatively when [c] is syntactically
 * infinite (as recognized by [is_inf]). *)
let f_is_int_simpl (c : form) : form =
  match is_inf c with
  | true  -> f_false
  | false -> f_is_int c

(* -------------------------------------------------------------------- *)

(** Simplification of cost equality and inequality tests using
    module freshness and epochs.

    A comparison between two costs is split into one comparison per
    projection (see [cproj] and [mk_cprojs]); the entry point is [simpl]. *)
module CostCompSimplify = struct
  type cproj =
    | PFresh of EcPath.xpath * Epoch.t
    (* proj. over a procedure of a [Fresh] module with its epoch *)

    | AllExcept of EcPath.xpath list * Epoch.t
    (* procedures and concrete cost except some (already projected) procedures
       of [Fresh] modules, and the minimum epoch of all these [Fresh] modules. *)


  (* Replace arithmetic operations by their counterpart after projection.
     Under [AllExcept] the operator is unchanged. Under [PFresh], [f] must be
     one of the four cost operators of the table below: the lookup fails
     on any other form. *)
  let cproj_op (p : cproj) (f : form) : form =
    match p with
    | AllExcept _ -> f
    | PFresh _ ->
      snd @@
      List.find (f_equal f |- fst)
        [ (fop_cost_add    , f_op_xadd);
          (fop_cost_opp    , f_op_xopp);
          (fop_cost_scale  , f_op_xmuli);
          (fop_cost_xscale , f_op_xmul); ]

  (* Replace comparison operations by their counterpart after projection.
     Under [AllExcept] the operator is unchanged. Under [PFresh], [f] must be
     one of the three comparison operators of the table below: the lookup
     fails on any other form. *)
  let cproj_cmp (p : cproj) (f : form) : form =
    match p with
    | AllExcept _ -> f
    | PFresh _ ->
      snd @@
      List.find (f_equal f |- fst)
        [ (fop_cost_le, f_op_xle );
          (* strict order: keyed on the strict cost comparison, so that
             this entry is reachable and [<] is not projected to [<=] *)
          (fop_cost_lt, f_op_xlt );
          (fop_eq EcTypes.tcost, fop_eq EcTypes.txint); ]

  (* Build the list of projections resulting from some local hyps:
     one [PFresh] per procedure of each [Fresh] module-type hypothesis,
     preceded by a single [AllExcept] covering everything else. *)
  let mk_cprojs (hyps : EcEnv.LDecl.hyps) : cproj list =
    let env = EcEnv.LDecl.toenv hyps in

    let locals = (EcEnv.LDecl.tohyps hyps).h_local in
    (* local hypotheses declaring a [Fresh] module, with their epoch *)
    let fresh_mts =
      List.filter_map (fun { l_id; l_kind; l_epoch = e } ->
          match l_kind with
          | LD_modty (Fresh, mt) -> Some (l_id, mt, e)
          | _ -> None
        ) locals
    in

    (* arbitrary epoch in [fresh_mts]. Can be anything if [fresh_mts] is empty. *)
    let e = match fresh_mts with
      | [] -> Epoch.init
      | (_, _, e) :: _ -> e
    in

    (* in one pass: the minimum epoch over [fresh_mts], and for each fresh
       module the list of its procedures paired with the module epoch *)
    let min_epoch, procs =
      List.fold_left_map (fun e (mid, mt, e') ->
          let procs =
            List.map (fun (EcModules.Tys_function fs) ->
                let xp = EcPath.xpath (EcPath.mident mid) fs.fs_name in
                xp, e'
              )
              (EcEnv.ModTy.sig_of_mt env mt).EcModules.mis_body
          in
          Epoch.min e e', procs
        ) e fresh_mts
    in
    let procs = List.flatten procs in

    let allxp = List.map fst procs in
    let projs = List.map (fun (xp, e) -> PFresh (xp, e)) procs in

    AllExcept (allxp, min_epoch) :: projs


  (* Raised when a projection cannot be simplified; caught in [simpl]. *)
  exception SFail

  (* Check that some formula occurs strictly before some epoch.
     Returns unit on success and raises [SFail] if a local of [f] is
     unknown in [hyps], if its epoch is not strictly below [etop], or if
     [f] contains a node kind that is not handled below. *)
  let check_before (hyps : EcEnv.LDecl.hyps) (etop : Epoch.t) (f : form) : unit =
    let rec check f =
      match f.f_node with
      | Flocal l ->
        begin
          match by_id_opt l (EcEnv.LDecl.tohyps hyps) with
          | None -> raise SFail
          | Some { l_epoch = e } ->
            if not (Epoch.lt e etop) then raise SFail
        end
      | Fapp (f, fs) -> List.iter check (f :: fs)
      | Flet (_, f, f') -> check_l [f; f']
      | Fglob _
      | Fint _
      | Fop _ -> ()
      | Fif (f,f1,f2) -> check_l [f; f1; f2]
      | Ftuple l -> check_l l
      | Fproj (f, _) -> check f
      | _ -> raise SFail

    and check_l l = List.iter check l in

    check f

  (* Try to simplify the projection.
     - [f] has type [tcost]
     - return: type [tcost] if [p = AllExcept _], [txint] otherwise
     - raise [SFail] if the simplification failed *)
  let simpl_cproj
      (hyps : EcEnv.LDecl.hyps) (p : cproj) (f : form)
    : form
    =
    let rec simpl (f : form) : form =
      match f.f_node with
      | Fcost c ->
        begin match p with
          | AllExcept (xps,_) ->
            (* drop the calls to the procedures that are projected apart *)
            let calls = EcPath.Mx.filter (fun xp _ ->
                not (List.exists (EcPath.x_equal xp) xps)
              ) c.c_calls
            in
            f_cost_r (cost_r c.c_self calls c.c_full)

          | PFresh (xp', _) ->
            (* bound recorded for [xp'], or the default given by [c_full] *)
            oget_c_bnd (EcPath.Mx.find_opt xp' c.c_calls) c.c_full
        end

      (* NOTE(review): in the two arithmetic cases below the rebuilt
         application keeps [f.f_ty], also under [PFresh] where the
         projected operands are meant to be of type [txint] -- confirm. *)
      | Fapp (f_op, lf)
        when List.exists (f_equal f_op) [fop_cost_add; fop_cost_opp] ->
        f_app (cproj_op p f_op) (List.map simpl lf) f.f_ty

      | Fapp (f_op, [scale; fc]) when
          List.exists (f_equal f_op) [fop_cost_scale; fop_cost_xscale] ->
        (* only the cost operand is projected, the scaling factor is kept *)
        f_app (cproj_op p f_op) [scale; simpl fc] f.f_ty

      | Flocal l ->
        begin
          match by_id_opt l (EcEnv.LDecl.tohyps hyps) with
          | None -> raise SFail
          | Some { l_epoch = e } ->
            match p with
            | PFresh (_, e') ->
              if Epoch.lt e e' then f_x0 else raise SFail

            | AllExcept (_,e') ->
              if Epoch.lt e e' then f else raise SFail
        end

      | Fop _ ->
        begin match p with
          | PFresh    _ -> f_x0
          | AllExcept _ -> f
        end

      | _ ->
        match p with
        | PFresh _ -> f_x0
        | AllExcept (_, etop) ->
          check_before hyps etop f;
          f
    in
    simpl f

  (* Entry point. If [f] is a comparison ([<=], [<] or [=]) between two
     costs, project both sides along every projection of [mk_cprojs hyps]
     and return the conjunction of the projected comparisons. [f] is
     returned unchanged if it is not such a comparison, if some projection
     fails, or if a projected comparison is [f] itself. *)
  let simpl (hyps : EcEnv.LDecl.hyps) (f : form) : form =
    match f.f_node with
    | Fapp (fop, [f1; f2])
      when List.exists (f_equal fop)
          [fop_cost_le; fop_cost_lt; fop_eq EcTypes.tcost; ] ->
      let cprojs = mk_cprojs hyps in

      begin try
          let forms =
            List.map (fun cp ->
                let f1' = simpl_cproj hyps cp f1 in
                let f2' = simpl_cproj hyps cp f2 in
                let fop' = cproj_cmp cp fop in

                let f' = f_app fop' [f1'; f2'] tbool in

                (* no progress on this projection: give up altogether *)
                if f_equal f f' then raise SFail;

                f'
              ) cprojs
          in
          f_ands0_simpl forms
        with SFail -> f           (* simplification failed, we do nothing *)
      end

    | _ -> f
end

(* -------------------------------------------------------------------- *)
(* Lift a unary function to [tcost].
   On an explicit cost record, [xf] is applied to the self cost and to every
   call bound, and [c_full] is kept as is. On any other form, [costf] is
   used instead. *)
let f_cost_map
    (xf    : form -> form)      (* type [txint -> txint] *)
    (costf : form -> form)      (* type [tcost -> tcost] *)
    (c     : form)              (* type [tcost] *)
  : form                        (* type [tcost] *)
  =
  if is_cost c then begin
    let r = destr_cost c in
    let self  = xf r.c_self in
    let calls = EcPath.Mx.map xf r.c_calls in
    f_cost_r (cost_r self calls r.c_full)
  end else
    costf c

(* Opposite of a cost: [f_xopp_simpl] on each component of an explicit cost
   record (through [f_cost_map]), [f_cost_opp] on any other form. *)
let f_cost_opp_simpl (c : form) : form =
  f_cost_map f_xopp_simpl f_cost_opp c

(* Scale the cost [c] by the integer factor [f].
   - [f] equal to [f_i0] gives [f_cost_zero];
   - [f] equal to [f_i1] gives [c] itself;
   - otherwise each component of an explicit cost record is multiplied by
     [f_N f] (through [f_cost_map]), and [f_cost_scale] is used on any
     other form. *)
let f_cost_scale_simpl (f : form) (c : form) =
  let on_bound x  = f_xmul_simpl (f_N f) x in
  let on_cost  c' = f_cost_scale f c' in

  if f_equal f f_i0 then f_cost_zero
  else if f_equal f f_i1 then c
  else f_cost_map on_bound on_cost c

(* Scale the cost [c] by the extended-integer factor [f].
   - [f] equal to [f_x0] gives [f_cost_zero];
   - [f] equal to [f_x1] gives [c] itself;
   - [f] equal to [f_Inf] gives [f_cost_inf];
   - otherwise each component of an explicit cost record is multiplied by
     [f] (through [f_cost_map]), and [f_cost_xscale] is used on any other
     form. *)
let f_cost_xscale_simpl (f : form) (c : form) =
  let on_bound x  = f_xmul_simpl f x in
  let on_cost  c' = f_cost_xscale f c' in

  if f_equal f f_x0 then f_cost_zero
  else if f_equal f f_x1 then c
  else if f_equal f f_Inf then f_cost_inf
  else f_cost_map on_bound on_cost c

(* -------------------------------------------------------------------- *)
(* Lift a unary function over [args -> txint] to [args -> tcost]
   where [args] is [a_1 -> ... -> a_n].
   I.e. commutes a λ-binding and the cost record: when the body of the
   lambda [c] is an explicit cost record, every component is re-abstracted
   over the same binders and passed to [xf]. Otherwise [costf] is applied
   to [c] as a whole. *)
let f_lam_cost_map
    (xf    : form -> form)      (* type [(args -> txint) -> txint] *)
    (costf : form -> form)      (* type [(args -> tcost) -> tcost] *)
    (c     : form)              (* type [args -> tcost] *)
  : form                        (* type [tcost] *)
  =
  let binders, inner = decompose_lambda c in
  (* re-abstract one component over the binders of [c], then apply [xf] *)
  let lift x = xf (f_lambda binders x) in

  if is_cost inner then begin
    let r = destr_cost inner in
    let self  = lift r.c_self in
    let calls = EcPath.Mx.map lift r.c_calls in
    f_cost_r (cost_r self calls r.c_full)
  end else
    costf c

(* Big operator over costs: pushed component-wise (as [f_bigx]) inside the
   cost record when [cost] is a lambda over an explicit cost record, and
   built as a plain [f_bigcost] otherwise (see [f_lam_cost_map]). *)
let f_bigcost_simpl (pred : form) (cost : form) (l : form) : form =
  let big_x    body = f_bigx    pred body l in
  let big_cost body = f_bigcost pred body l in
  f_lam_cost_map big_x big_cost cost
(* -------------------------------------------------------------------- *)
(* Syntactic zero test: [c] must be an explicit cost record with [c_full]
   set, whose self cost and every call bound are equal to [f_x0]. *)
let cost_is_zero (c : form) : bool =
  is_cost c &&
  begin
    let r = destr_cost c in
    r.c_full
    && f_equal f_x0 r.c_self
    && EcPath.Mx.for_all (fun _ bnd -> f_equal f_x0 bnd) r.c_calls
  end

(* -------------------------------------------------------------------- *)
(* Lift a binary operator over [txint] to [tcost].
   When both operands are explicit cost records, [xop] is applied to the
   self costs and, procedure by procedure, to the call bounds (a bound
   missing on one side is supplied by [oget_c_bnd] from that side's
   [c_full]); the result is full iff both operands are. Otherwise the
   application is built with [costop]. *)
let f_cost_mk_bin_simpl xop costop (c1 : form) (c2 : form) : form =
  if is_cost c1 && is_cost c2 then begin
    let r1 = destr_cost c1 in
    let r2 = destr_cost c2 in

    let self = xop r1.c_self r2.c_self in
    let combine _ b1 b2 =
      let b1 = oget_c_bnd b1 r1.c_full
      and b2 = oget_c_bnd b2 r2.c_full in
      Some (xop b1 b2)
    in
    let calls = EcPath.Mx.merge combine r1.c_calls r2.c_calls in
    f_cost_r (cost_r self calls (r1.c_full && r2.c_full))
  end else
    costop c1 c2

(* Sum of two costs: a syntactically zero operand (see [cost_is_zero]) is
   dropped, otherwise the sum is computed component-wise with
   [f_xadd_simpl], falling back to [f_cost_add]. *)
let f_cost_add_simpl c1 c2 =
  match () with
  | () when cost_is_zero c1 -> c2
  | () when cost_is_zero c2 -> c1
  | ()                      -> f_cost_mk_bin_simpl f_xadd_simpl f_cost_add c1 c2

(* Lift a binary comparison over [txint] to [tcost]: when both operands are
   explicit cost records, they are compared with [cost_mk_cmp fullcmp xcmp];
   otherwise the comparison is built with [costcmp]. *)
let f_cost_mk_cmp fullcmp xcmp costcmp (c1 : form) (c2 : form) : form =
  match is_cost c1 && is_cost c2 with
  | false -> costcmp c1 c2
  | true  -> cost_mk_cmp fullcmp xcmp (destr_cost c1) (destr_cost c2)

(* [f <= f'] on costs.
   - immediately [f_true] when [f'] is [f_cost_inf] or [f_cost_inf0];
   - component-wise comparison with [f_xle_simpl] when both sides are
     explicit cost records;
   - otherwise the [f_cost_le] application, passed through
     [CostCompSimplify.simpl]. *)
let f_cost_le_simpl (hyps : EcEnv.LDecl.hyps) f f' =
  (* test used on the [c_full] flags of the two cost records *)
  let full_le b b' = not b' || b = b' in
  let generic_le c c' = CostCompSimplify.simpl hyps (f_cost_le c c') in

  if List.exists (f_equal f') [f_cost_inf; f_cost_inf0] then f_true
  else f_cost_mk_cmp full_le f_xle_simpl generic_le f f'

(* [f < f'] on costs.
   - component-wise comparison with [f_xlt_simpl] when both sides are
     explicit cost records;
   - otherwise the [f_cost_lt] application, passed through
     [CostCompSimplify.simpl]. *)
let f_cost_lt_simpl (hyps : EcEnv.LDecl.hyps) f f' =
  (* test used on the [c_full] flags of the two cost records.
     NOTE(review): this is [b' || b = b'] whereas [f_cost_le_simpl] uses
     [not b' || b = b'] -- confirm the difference is intended. *)
  let full_lt b b' = b' || b = b' in
  let generic_lt c c' = CostCompSimplify.simpl hyps (f_cost_lt c c') in

  f_cost_mk_cmp full_lt f_xlt_simpl generic_lt f f'

(* "[c] is an integer cost" test.
   - [c] not an explicit cost record: plain [f_cost_is_int] application;
   - record with [c_full] unset: [f_false];
   - otherwise: conjunction of [f_is_int_simpl] over the self cost and
     every call bound. *)
let f_cost_is_int_simpl c =
  match is_cost c with
  | false -> f_cost_is_int c
  | true  ->
    let r = destr_cost c in
    if not r.c_full then f_false
    else begin
      let self  = f_is_int_simpl r.c_self in
      let calls =
        r.c_calls
        |> EcPath.Mx.bindings
        |> List.map (fun (_, bnd) -> f_is_int_simpl bnd)
      in
      f_ands0_simpl (self :: calls)
    end

(* -------------------------------------------------------------------- *)
(* Compute the projection [p] of the module cost [mc].
   - [Intr fname]: the self cost of procedure [fname];
   - [Param _]: the bound recorded, in the cost of procedure [proc], for
     the call whose top module identifier is named [param_m] and whose
     procedure name is [param_p]; when no such call is recorded,
     [oget_c_bnd] supplies the bound from [c_full]. *)
let mod_cost_proj_simpl (mc : mod_cost) (p : cost_proj) : form =
  (* cost of a procedure of [mc]; the lookup cannot fail *)
  let cost_of fname = Msym.find fname mc in

  match p with
  | Intr fname ->
    (cost_of fname).c_self

  | Param { proc; param_m; param_p } ->
    let pcost = cost_of proc in
    let bnd =
      EcPath.Mx.find_fun_opt (fun xp _ ->
          EcIdent.name (EcPath.mget_ident xp.x_top) = param_m
          && xp.x_sub = param_p
        ) pcost.c_calls
    in
    oget_c_bnd bnd pcost.c_full

(* Projection [p] of [f]: computed with [mod_cost_proj_simpl] when [f] is
   an explicit module cost ([Fmodcost]), built with [f_cost_proj_r]
   otherwise. *)
let f_cost_proj_simpl (f : form) (p : cost_proj) : form =
  match f with
  | { f_node = Fmodcost mc; _ } -> mod_cost_proj_simpl mc p
  | _                           -> f_cost_proj_r f p

(* -------------------------------------------------------------------- *)
(* destr_exists_prenex recursively pulls existential quantifiers out of a
 * formula whenever possible, returning the hoisted binders together with
 * the quantifier-free remainder.
 * For instance:
 * - E x p1 /\ E y p2 -> [x,y] (p1 /\ p2)
 * - E x p1 /\ E x p2 -> [] (E x p1 /\ E x p2)
 * - p1 => E x p2 -> [x] (p1 => p2)
 * - E x p1 => p2 -> [] (E x p1 => p2)
 * When no binder can be hoisted, the function fails with
 * [destr_error "exists"].
 *)
let destr_exists_prenex f =
  (* no identifier is bound in both lists *)
  let disjoint bds1 bds2 =
    let fresh_in_bds2 (id1, _) =
      not (List.exists (fun (id2, _) -> id1 = id2) bds2)
    in
    List.for_all fresh_in_bds2 bds1
  in

  let rec prenex_exists bds p =
    (* Two-sided case: hoist from [f1] and [f2] independently and rebuild
       with [mk]; leave [p] untouched if the hoisted binders clash. *)
    let hoist_both mk f1 f2 =
      let (bds1, f1) = prenex_exists [] f1 in
      let (bds2, f2) = prenex_exists [] f2 in
      if   disjoint bds1 bds2
      then (bds1 @ bds2 @ bds, mk f1 f2)
      else (bds, p)
    in

    match sform_of_form p with
    | SFand (`Sym, (f1, f2)) -> hoist_both f_and f1 f2
    | SFor  (`Sym, (f1, f2)) -> hoist_both f_or  f1 f2

    (* only the branches of a conditional are inspected, not its test *)
    | SFif (f, ft, fe) -> hoist_both (f_if f) ft fe

    (* only the conclusion of an implication is inspected *)
    | SFimp (f1, f2) ->
        let (bds2, f2) = prenex_exists bds f2 in
        (bds2 @ bds, f_imp f1 f2)

    | SFquant (Lexists, bd, lazy p) ->
        let (bds, p) = prenex_exists bds p in
        (bd :: bds, p)

    | _ -> (bds, p)
  in

  (* Make it fail as with destr_exists *)
  match prenex_exists [] f with
  | ([], _) -> destr_error "exists"
  | hoisted -> hoisted

(* -------------------------------------------------------------------- *)
(* Flatten a conjunction into the list of its conjuncts. The right operand
   of every conjunction is always flattened; the left operand is flattened
   only when [deep] is set. A formula on which [destr_and] fails is
   returned as a singleton list. *)
let destr_ands ~deep =
  let rec flatten f =
    match destr_and f with
    | exception DestrError _ -> [f]
    | (lhs, rhs) ->
      let lefts = if deep then flatten lhs else [lhs] in
      lefts @ flatten rhs
  in
  fun f -> flatten f
back to top