https://github.com/EasyCrypt/easycrypt
Raw File
Tip revision: 30bfa950afa3806948c073d3c9ec4468d33ea940 authored by Pierre-Yves Strub on 11 December 2023, 10:58:49 UTC
New tactic: "proc change"
Tip revision: 30bfa95
ecProcSem.ml
(* -------------------------------------------------------------------- *)
open EcUtils
open EcSymbols
open EcAst
open EcTypes
open EcModules
open EcFol

module BI = EcBigInt

(* -------------------------------------------------------------------- *)
(* Raised by every translation function of this module as soon as the
 * statement, instruction or expression at hand falls outside the
 * supported fragment (globals, modules, unsupported constructs). *)
exception SemNotSupported

(* -------------------------------------------------------------------- *)
(* Translation state threaded through all the [translate_*] functions. *)
type senv = {
  (* EasyCrypt environment, used for variable analyses and lookups *)
  env : EcEnv.env;
  (* program-variable name -> logical identifier currently standing
   * for its value *)
  subst : EcIdent.t Msym.t;
}

(* -------------------------------------------------------------------- *)
module Env = struct
  (* Translation state over [env] in which no program variable is bound. *)
  let empty (env : EcEnv.env) : senv =
    let subst = Msym.empty in
    { env = env; subst = subst; }

  (* [fresh env x] creates a new identifier named after [x] and returns
   * [env] extended with the binding [x -> id], together with [id]. *)
  let fresh (env : senv) (x : symbol) : senv * EcIdent.t =
    let id    = EcIdent.create x in
    let subst = Msym.add x id env.subst in
    ({ env with subst = subst }, id)
end

(* -------------------------------------------------------------------- *)
(* Shape of a translated fragment: [`Det] denotes a plain value,
 * [`Distr] a distribution over values. *)
type mode = [ `Det | `Distr ]

(* -------------------------------------------------------------------- *)
(* FIXME: MOVE ME                                                       *)
(* The operator [dunit] instantiated at type [ty], i.e. of type
 * [ty -> ty distr]. *)
let eop_dunit (ty : ty) =
  let opty = tfun ty (tdistr ty) in
  e_op EcCoreLib.CI_Distr.p_dunit [ty] opty

(* [e_dunit e] builds the expression [dunit e]. *)
let e_dunit (e : expr) =
  let ty = e.e_ty in
  e_app (eop_dunit ty) [e] (tdistr ty)

(* -------------------------------------------------------------------- *)
(* Translate the instruction [i] in continuation-passing style.
 *
 * [cont] receives the environment in force after [i] (where every
 * variable written by [i] is rebound to a fresh identifier) and returns
 * the translation of the code that follows [i].
 *
 * The result mode is [`Distr] as soon as [i] samples (directly, or in a
 * branch of an [if]); otherwise it is the mode returned by [cont].
 *
 * Raises [SemNotSupported] if [i] reads or writes anything other than
 * local program variables, and for while loops, matches, assertions,
 * abstract statements and calls that are not of the shape
 * [lv <@ M.f(args)] with [M] concrete. *)
let rec translate_i (env : senv) (cont : senv -> mode * expr) (i : instr) =
  (* [i] may only read local program variables. *)
  EcPV.PV.iter
    (fun pv _ -> if not (is_loc pv) then raise SemNotSupported)
    (fun _ -> raise SemNotSupported)
    (EcPV.i_read env.env i);

  (* Local variables written by [i], with their types.  Writing a global
   * variable or a module is not supported. *)
  let wr =
    let do1 (pv, ty) =
      match pv with
      | PVglob _ -> raise SemNotSupported
      | PVloc  x -> (x, ty) in

    let wr, mods = EcPV.PV.elements (EcPV.i_write env.env i) in

    if not (List.is_empty mods) then
      raise SemNotSupported;
    List.map do1 wr
  in

  (* [env'] is the environment after [i]: every written variable is
   * rebound to a fresh identifier. *)
  let env', ids =
    List.fold_left_map
      (fun env (x, _) -> Env.fresh env x)
      env wr in

  let ids = List.combine wr ids in

  match i.i_node with
  | Sasgn (lv, e) ->
     (* lv <- e  ~~>  let lv = e in <rest> *)
     let e = translate_e env e in
     let lv = translate_lv env' lv in
     let mode, body = cont env' in
     (mode, (e_let lv e body))

  | Srnd (lv, d) -> begin
     (* lv <$ d  ~~>  dmap d (fun lv => <rest>)  if <rest> is deterministic,
      *              dlet d (fun lv => <rest>)  otherwise *)
     let d = translate_e env d in
     let lv = translate_lv env' lv in
     let mode, body = cont env' in

     let tya = oget (as_tdistr (EcEnv.Ty.hnorm d.e_ty env.env)) in
     let tyb = body.e_ty in

     let aout =
       let d    = form_of_expr mhr d in
       let body = form_of_expr mhr body in
       let body =
         let arg  = EcIdent.create "arg" in
         let body = f_let lv (f_local arg tya) body in
         f_lambda [(arg, GTty tya)] body in

       match mode with
       | `Det   -> f_dmap tya tyb d body
       | `Distr -> f_dlet_simpl tya (oget (as_tdistr tyb)) d body

     in (`Distr, expr_of_form mhr aout)
    end

  | Sif (e, bt, bf) ->
     (* Each branch is translated up to the end of the [if] only and
      * returns the tuple of the variables written by [i].  This
      * continuation must NOT be named [cont]: the outer [cont] is still
      * needed below to translate the code following the [if]. *)
     let cont_branch (fenv : senv) : mode * expr =
       let do1 ((x, ty), _) =
         match Msym.find_opt x fenv.subst with
         | Some id -> e_local id ty
         (* written by the other branch only, and unbound before the [if] *)
         | None -> raise SemNotSupported in
       let vars = List.map do1 ids in
       (`Det, e_tuple vars) in

     let e  = translate_e env e in
     let bt = translate_s env cont_branch bt in
     let bf = translate_s env cont_branch bf in

     (* Bring both branches to a common mode; a deterministic branch is
      * lifted with [dunit] when the other one samples. *)
     let mode, (bt, bf) =
       match bt, bf with
       | (`Det, bt), (`Det, bf) ->
          (`Det, (bt, bf))

       | (`Distr, bt), (`Distr, bf) ->
          (`Distr, (bt, bf))

       | (`Det, bt), (`Distr, bf) ->
          (`Distr, (e_dunit bt, bf))

       | (`Distr, bt), (`Det, bf) ->
          (`Distr, (bt, e_dunit bf)) in

     (* Pattern binding the fresh identifiers of the written variables. *)
     let lv =
       let ids =
         let do1 ((x, ty), _) =
           (Msym.find x env'.subst, ty) in
         List.map do1 ids in

       match ids with
       | [] ->
          LSymbol (EcIdent.create "_", tunit)
       | [x, ty] ->
          LSymbol (x, ty)
       | ids ->
          LTuple ids in

     (* Translation of the code following the [if]. *)
     let cmode, c = cont env' in

     begin
       match mode, cmode with
       | `Det, _ ->
          (cmode, e_let lv (e_if e bt bf) c)

       | `Distr, `Det ->
          let body = form_of_expr mhr (e_if e bt bf) in
          let tya  = oget (as_tdistr body.f_ty) in
          let v    = EcIdent.create "v" in
          let vx   = f_local v tya in
          let aout =
            f_dmap
              tya
              c.e_ty
              body
              (f_lambda
                 [v, GTty tya]
                 (f_let lv vx (form_of_expr mhr c)))

          in (`Distr, expr_of_form mhr aout)

       | `Distr, `Distr ->
          let body = form_of_expr mhr (e_if e bt bf) in
          let tya  = oget (as_tdistr body.f_ty) in
          let tyb  = oget (as_tdistr c.e_ty) in
          let v    = EcIdent.create "v" in
          let vx   = f_local v tya in
          let aout =
            f_dlet_simpl
              tya
              tyb
              body
              (f_lambda
                 [v, GTty tya]
                 (f_let lv vx (form_of_expr mhr c)))

          in (`Distr, expr_of_form mhr aout)

     end

  | Scall (Some lv, ({ x_top = { m_top = `Concrete (p, _) }; x_sub = f } as xp), args)  ->
      (* lv <@ M.f(args)  ~~>  let lv = op args in <rest>, where [op] is
       * the operator named [f] under the parent path of [p].
       * NOTE(review): assumes such an operator has been declared for the
       * procedure elsewhere — confirm with the caller. *)
      let fd   = oget (EcEnv.Fun.by_xpath_opt xp env.env) in
      let args = translate_e env (e_tuple args) in
      let op   = EcPath.pqname (oget (EcPath.prefix p)) f in
      let op   = e_op op [] (tfun fd.f_sig.fs_arg fd.f_sig.fs_ret) in
      let op   = e_app op [args] fd.f_sig.fs_ret in
      let lv   = translate_lv env' lv  in

      let cmode, c = cont env' in

      (cmode, e_let lv op c)

  | Swhile    _
  | Smatch    _
  | Sassert   _
  | Sabstract _
  | Scall     _ ->
     raise SemNotSupported

(* -------------------------------------------------------------------- *)
and translate_s (env : senv) (cont : senv -> mode * expr) (s : stmt) =
  (* A recognised for-loop at the head of [s] is translated as a whole
   * (together with the rest of [s]); otherwise the first instruction is
   * translated with a continuation that handles the remaining ones, and
   * the empty statement goes straight to [cont]. *)
  match translate_forloop env cont s, s.s_node with
  | Some res, _ ->
     res
  | None, [] ->
     cont env
  | None, hd :: tl ->
     let rest (env : senv) = translate_s env cont (stmt tl) in
     translate_i env rest hd

(* -------------------------------------------------------------------- *)
(* Recognise, at the head of [s], a bounded loop of the shape
 *
 *   x <- 0; while (x < bd) { body; x <- x + inc; } s_tail
 *
 * where [x] is an [int] local, [inc] is a positive integer literal, the
 * body neither writes [x] nor anything read by [bd], and [bd] does not
 * read [x].  The loop is translated to [iteri] (deterministic body) or
 * [dfold] (sampling body) over the number of iterations; [s_tail] is
 * then translated with [x] bound to its exit value.
 *
 * Returns [None] if [s] does not start with an assignment to a local
 * followed by a while loop; raises [SemNotSupported] if it does but the
 * loop is not of the supported shape. *)
and translate_forloop (env : senv) (cont : senv -> mode * expr) (s : stmt) =
  let module ET = EcReduction.EqTest in

  match s.s_node with
  | { i_node = Sasgn (LvVar (PVloc x, xty), e) } :: { i_node = Swhile (c, body) } :: s_tail ->
     (* The counter must be an [int] initialised to [0]. *)
     if not (ET.for_type env.env xty tint) then
       raise SemNotSupported;

     if not (ET.for_expr env.env e (e_int EcBigInt.zero)) then
       raise SemNotSupported;

     (* Split the loop body into the body proper and its last
      * instruction, which must be [x <- x + inc] with [0 < inc]. *)
     let inc, body =
       let last, rest =
         match List.rev body.s_node with
         | last :: rest -> (last, List.rev rest)
         | [] -> raise SemNotSupported in

       match last.i_node with
       | Sasgn (LvVar (PVloc y, _), ic) when x = y -> begin
           match ic.e_node with
           | Eapp ({ e_node = Eop (op, []) }, [{ e_node = Evar (PVloc y') }; { e_node = Eint inc }])
                when    y = y'
                     && EcBigInt.lt EcBigInt.zero inc
                     && EcPath.p_equal op EcCoreLib.CI_Int.p_int_add
             -> (inc, rest)
           | _ -> raise SemNotSupported
         end
       | _ -> raise SemNotSupported in

     (* Reparametrise the body by the iteration index [k]: at iteration
      * [k] the counter holds [inc * k]. *)
     let body =
       if BI.gt inc BI.one then begin
         let mx =
           e_app
             (e_op EcCoreLib.CI_Int.p_int_mul [] (toarrow [tint; tint] tint))
             [e_int inc; e_var (pv_loc x) tint] tint in
         let subst = EcPV.Mpv.add env.env (pv_loc x) mx EcPV.Mpv.empty in
         EcPV.Mpv.issubst env.env subst body
       end else body in

     (* The loop condition must be [x < bd]. *)
     let bd =
       match c.e_node with
       | Eapp ({ e_node = Eop (op, []) }, [{ e_node = Evar (PVloc y) }; bd])
            when    x = y
                 && EcPath.p_equal op EcCoreLib.CI_Int.p_int_lt -> bd
       | _ -> raise SemNotSupported in

     (* The iteration count below assumes [bd] is loop-invariant: it must
      * not read the counter, whose value changes at every iteration. *)
     if EcPV.PV.mem_pv env.env (pv_loc x) (EcPV.e_read env.env bd) then
       raise SemNotSupported;

     let wr = EcPV.s_write env.env (EcModules.stmt body) in

     (* The body may write neither the counter... *)
     if EcPV.PV.mem_pv env.env (pv_loc x) wr then
       raise SemNotSupported;

     (* ... nor anything the bound reads. *)
     if not (EcPV.PV.indep env.env (EcPV.e_read env.env bd) wr) then
       raise SemNotSupported;

     (* The body may only read local variables... *)
     EcPV.PV.iter
       (fun pv _ -> if not (is_loc pv) then raise SemNotSupported)
       (fun _ -> raise SemNotSupported)
       (EcPV.is_read env.env body);

     (* ... and only write local variables: these form the loop state. *)
     let wr =
       let do1 (pv, ty) =
         match pv with
         | PVglob _ -> raise SemNotSupported
         | PVloc  z -> (z, ty) in

       let wr, mods = EcPV.PV.elements (EcPV.is_write env.env body) in

       if not (List.is_empty mods) then
         raise SemNotSupported;
       List.map do1 wr
     in

     let wr = List.filter (fun (z, _) -> z <> x) wr in

     (* Translate the body into [fun k st => st'] (or [fun k st => d] when
      * it samples), where [st] is the tuple of the state variables. *)
     let mode, step =
       let benv, bids =
         List.fold_left_map
           (fun env (z, _) -> Env.fresh env z)
           env wr in

       let bids = List.combine wr bids in

       let benv, idx = Env.fresh benv x in

       let cont_body (fenv : senv) : mode * expr =
         let do1 ((z, ty), _) =
           e_local (Msym.find z fenv.subst) ty in
         (`Det, e_tuple (List.map do1 bids)) in

       let bmode, tbody = translate_s benv cont_body (stmt body) in

       let fn =
         match bids with
         | [] ->
            e_lam [(EcIdent.create "_", tunit)] tbody
         | [((_, ty), z)] ->
            e_lam [(z, ty)] tbody
         | bids ->
            let arg = EcIdent.create "arg" in
            let aty = ttuple (List.map (fun ((_, ty), _) -> ty) bids) in
            let lv  = LTuple (List.map (fun ((_, ty), z) -> (z, ty)) bids) in
            e_lam
              [(arg, aty)]
              (e_let lv (e_local arg aty) tbody) in

       (bmode, e_lam [(idx, tint)] fn) in

     (* Environment after the loop: state variables and counter rebound
      * to fresh identifiers. *)
     let env', ids =
       List.fold_left_map
         (fun env (z, _) -> Env.fresh env z)
         env wr in

     let ids = List.combine wr ids in
     let aty = ttuple (List.map (fun ((_, ty), _) -> ty) ids) in

     let env', xid = Env.fresh env' x in

     let lv =
       let ids =
         let do1 ((z, ty), _) =
           (Msym.find z env'.subst, ty) in
         List.map do1 ids in

       match ids with
       | [] ->
          LSymbol (EcIdent.create "_", tunit)
       | [z, ty] ->
          LSymbol (z, ty)
       | ids ->
          LTuple ids in

     (* Number of iterations: for [bd > 0] the loop runs [ceil(bd / inc)]
      * times, i.e. [(bd + inc - 1) edivz inc]; for [bd <= 0] this
      * quotient is [<= 0] and the loop does not run.
      * NOTE(review): relies on [iteri] / [dfold] performing no iteration
      * for a non-positive count. *)
     let bound = form_of_expr mhr (translate_e env bd) in
     let niter =
       let num = f_int_add_simpl bound (f_int (BI.sub inc BI.one)) in
       f_proj_simpl (f_int_edivz_simpl num (f_int inc)) 0 tint in

     (* Exit value of the counter: [niter * inc] if the loop ran, [0]
      * (its initial value) otherwise. *)
     let outv =
       f_if_simpl
         (f_int_lt_simpl (f_int BI.zero) niter)
         (f_int_mul_simpl niter (f_int inc))
         (f_int BI.zero) in

     let niter = expr_of_form mhr niter in
     let outv  = expr_of_form mhr outv in

     (* Initial loop state: current values of the state variables, or
      * [witness] for those that are not bound yet. *)
     let args =
       let init (z, zty) =
         match Msym.find_opt z env.subst with
         | None    -> e_op EcCoreLib.CI_Witness.p_witness [zty] zty
         | Some id -> e_local id zty in
       e_tuple (List.map init wr) in

     (* Translation of the code following the loop. *)
     let cmode, c = translate_s env' cont (stmt s_tail) in

     let mode, aout =
       match mode with
       | `Det ->
          let it = e_op EcCoreLib.CI_Int.p_iteri [aty] in
          let it = it (toarrow [tint; (toarrow [tint; aty] aty); aty] aty) in
          let it = e_app it [niter; step; args] aty in
          (cmode, e_let lv it c)

       | `Distr ->
          let df = e_op EcCoreLib.CI_Distr.p_dfold [aty] in
          let df = df (toarrow [toarrow [tint; aty] (tdistr aty); aty; tint] (tdistr aty)) in
          let df = e_app df [step; args; niter] (tdistr aty) in
          let df = form_of_expr mhr df in

          let arg = EcIdent.create "arg" in
          let fc  =
            f_lambda
              [(arg, GTty aty)]
              (f_let lv (f_local arg aty) (form_of_expr mhr c)) in

          (* [dlet] expects the element type of the continuation's
           * distribution, as in the [Srnd] and [Sif] cases. *)
          let aout =
            match cmode with
            | `Det   -> f_dmap aty c.e_ty df fc
            | `Distr -> f_dlet_simpl aty (oget (as_tdistr c.e_ty)) df fc in

          (`Distr, expr_of_form mhr aout)

     in Some (mode, e_let (LSymbol (xid, tint)) outv aout)

  | _ ->
     None

(* -------------------------------------------------------------------- *)
and translate_e (env : senv) (e : expr) =
  (* Local program variables are replaced by the identifiers bound to
   * them in [env]; global variables are rejected; any other node is
   * rebuilt with its sub-expressions translated recursively. *)
  match e.e_node with
  | Evar pv -> begin
      match pv with
      | PVloc x  -> e_local (oget (Msym.find_opt x env.subst)) e.e_ty
      | PVglob _ -> raise SemNotSupported
    end
  | _ ->
     let same v = v in
     e_map same (translate_e env) e

(* -------------------------------------------------------------------- *)
and translate_lv (env : senv) (lv : lvalue) : lpattern =
  (* Every assigned program variable becomes the identifier bound to it
   * in [env]; its type is kept unchanged. *)
  let bind (pv, ty) = (translate_pv env pv, ty) in
  match lv with
  | LvTuple pvs ->
     LTuple (List.map bind pvs)
  | LvVar (pv, ty) ->
     let id, ty = bind (pv, ty) in
     LSymbol (id, ty)

(* -------------------------------------------------------------------- *)
and translate_pv (env : senv) (pv : prog_var) =
  (* Only local program variables have a logical counterpart. *)
  let name =
    match pv with
    | PVloc x  -> x
    | PVglob _ -> raise SemNotSupported in
  oget (Msym.find_opt name env.subst)
back to top