ecReduction.ml
(* --------------------------------------------------------------------
* Copyright (c) - 2012--2016 - IMDEA Software Institute
* Copyright (c) - 2012--2018 - Inria
* Copyright (c) - 2012--2018 - Ecole Polytechnique
*
* Distributed under the terms of the CeCILL-C-V1 license
* -------------------------------------------------------------------- *)
(* -------------------------------------------------------------------- *)
open EcUtils
open EcIdent
open EcPath
open EcTypes
open EcModules
open EcFol
open EcEnv
module BI = EcBigInt
(* -------------------------------------------------------------------- *)
(** Raised (by [EqTest.for_type_exn]) when two types are not equal;
    carries the environment and the offending pair. *)
exception IncompatibleType of env * (ty * ty)
(** Raised (by [check_alpha_equal]) when two formulas cannot be matched;
    carries the environment and the offending pair. *)
exception IncompatibleForm of env * (form * form)
(** Raised (by [check_alpha_equal_e]) when two expressions cannot be
    matched; carries the environment and the offending pair. *)
exception IncompatibleExpr of env * (expr * expr)
(* -------------------------------------------------------------------- *)
(** An equality test between two values of type ['a] in an environment. *)
type 'a eqtest = env -> 'a -> 'a -> bool
(** An equality test with an optional [norm] flag; in [EqTest] this flag
    enables normalisation of program variables / paths before comparing. *)
type 'a eqntest = env -> ?norm:bool -> 'a -> 'a -> bool
module EqTest = struct
  (* ------------------------------------------------------------------ *)
  (* [for_type env t1 t2]: equality of types modulo unfolding of type
     definitions ([Ty.unfold]) and normalisation of reducible [Tglob]
     types. [ty_equal] is tried first as a fast path. *)
  let rec for_type env t1 t2 =
    ty_equal t1 t2 || for_type_r env t1 t2

  and for_type_r env t1 t2 =
    match t1.ty_node, t2.ty_node with
    | Tunivar uid1, Tunivar uid2 -> EcUid.uid_equal uid1 uid2

    | Tvar i1, Tvar i2 -> i1 = i2

    | Ttuple lt1, Ttuple lt2 ->
        List.length lt1 = List.length lt2
        && List.all2 (for_type env) lt1 lt2

    | Tfun (t1, t2), Tfun (t1', t2') ->
        for_type env t1 t1' && for_type env t2 t2'

    | Tglob mp, _ when EcEnv.NormMp.tglob_reducible env mp ->
        for_type env (EcEnv.NormMp.norm_tglob env mp) t2

    | _, Tglob mp when EcEnv.NormMp.tglob_reducible env mp ->
        for_type env t1 (EcEnv.NormMp.norm_tglob env mp)

    | Tconstr (p1, lt1), Tconstr (p2, lt2) when EcPath.p_equal p1 p2 ->
        (* Same constructor: compare arguments first, and only unfold the
           definition when the arguments differ. *)
        if
          List.length lt1 = List.length lt2
          && List.all2 (for_type env) lt1 lt2
        then true
        else
          if Ty.defined p1 env
          then for_type env (Ty.unfold p1 lt1 env) (Ty.unfold p2 lt2 env)
          else false

    | Tconstr (p1, lt1), _ when Ty.defined p1 env ->
        for_type env (Ty.unfold p1 lt1 env) t2

    | _, Tconstr (p2, lt2) when Ty.defined p2 env ->
        for_type env t1 (Ty.unfold p2 lt2 env)

    | _, _ -> false

  (* ------------------------------------------------------------------ *)
  let is_unit env ty = for_type env tunit ty
  let is_bool env ty = for_type env tbool ty
  let is_int  env ty = for_type env tint  ty

  (* ------------------------------------------------------------------ *)
  (* Same as [for_type], but raises [IncompatibleType] instead of
     returning [false]. *)
  let for_type_exn env t1 t2 =
    if not (for_type env t1 t2) then
      raise (IncompatibleType (env, (t1, t2)))

  (* ------------------------------------------------------------------ *)
  (* Program variables: syntactic equality, or (when [norm] is set and
     both variables have the same kind) equality after normalisation. *)
  let for_pv env ~norm p1 p2 =
    pv_equal p1 p2 || (norm && (pv_kind p1 = pv_kind p2) &&
      let p1 = NormMp.norm_pvar env p1 in
      let p2 = NormMp.norm_pvar env p2 in
      pv_equal p1 p2)

  (* ------------------------------------------------------------------ *)
  (* Procedure paths: syntactic equality, or equality after
     normalisation when [norm] is set. *)
  let for_xp env ~norm p1 p2 =
    EcPath.x_equal p1 p2 || (norm &&
      let p1 = NormMp.norm_xfun env p1 in
      let p2 = NormMp.norm_xfun env p2 in
      EcPath.x_equal p1 p2)

  (* ------------------------------------------------------------------ *)
  (* Module paths: syntactic equality, or equality after normalisation
     when [norm] is set. *)
  let for_mp env ~norm p1 p2 =
    EcPath.m_equal p1 p2 || (norm &&
      let p1 = NormMp.norm_mpath env p1 in
      let p2 = NormMp.norm_mpath env p2 in
      EcPath.m_equal p1 p2)

  (* ------------------------------------------------------------------ *)
  (* [for_expr env ~norm alpha e1 e2]: equality of expressions up to the
     renaming [alpha] of bound locals (left identifier -> right
     identifier), which is extended when traversing binders. *)
  let for_expr env ~norm =
    let module E = struct exception NotConv end in

    let find alpha id = odfl id (Mid.find_opt id alpha) in

    let noconv (f : expr -> expr -> bool) e1 e2 =
      try f e1 e2 with E.NotConv -> false in

    let check_binding env alpha (id1, ty1) (id2, ty2) =
      if not (for_type env ty1 ty2) then
        raise E.NotConv;
      Mid.add id1 id2 alpha in

    let check_bindings env alpha b1 b2 =
      if List.length b1 <> List.length b2 then
        raise E.NotConv;
      List.fold_left2 (check_binding env) alpha b1 b2 in

    let check_lpattern alpha lp1 lp2 =
      match lp1, lp2 with
      | LSymbol (id1,_), LSymbol (id2,_) ->
          Mid.add id1 id2 alpha

      | LTuple lid1, LTuple lid2 when List.length lid1 = List.length lid2 ->
          List.fold_left2
            (fun alpha (id1,_) (id2,_) -> Mid.add id1 id2 alpha)
            alpha lid1 lid2

      | _, _ -> raise E.NotConv in

    let rec aux alpha e1 e2 =
      e_equal e1 e2 || aux_r alpha e1 e2

    and aux_r alpha e1 e2 =
      match e1.e_node, e2.e_node with
      | Eint i1, Eint i2 ->
          BI.equal i1 i2

      | Elocal id1, Elocal id2 ->
          EcIdent.id_equal (find alpha id1) id2

      | Evar p1, Evar p2 ->
          for_pv env ~norm p1 p2

      | Eop(o1,ty1), Eop(o2,ty2) ->
          p_equal o1 o2 && List.all2 (for_type env) ty1 ty2

      | Equant(q1,b1,e1), Equant(q2,b2,e2) when qt_equal q1 q2 ->
          (* [check_bindings] may raise [E.NotConv] here, outside of the
             [noconv] guard below. *)
          let alpha = check_bindings env alpha b1 b2 in
          noconv (aux alpha) e1 e2

      | Eapp (f1, args1), Eapp (f2, args2) ->
          aux alpha f1 f2 && List.all2 (aux alpha) args1 args2

      | Elet (p1, f1', g1), Elet (p2, f2', g2) ->
          (* [check_lpattern] is evaluated before [noconv] is entered and
             may therefore raise [E.NotConv] unguarded. *)
          aux alpha f1' f2'
          && noconv (aux (check_lpattern alpha p1 p2)) g1 g2

      | Etuple args1, Etuple args2 -> List.all2 (aux alpha) args1 args2

      | Eif (a1,b1,c1), Eif(a2,b2,c2) ->
          aux alpha a1 a2 && aux alpha b1 b2 && aux alpha c1 c2

      | Ematch (e1,es1,ty1), Ematch(e2,es2,ty2) ->
          for_type env ty1 ty2
          && List.all2 (aux alpha) (e1::es1) (e2::es2)

      | _, _ -> false

    (* [E.NotConv] is local to this function: callers can neither name
       nor catch it, so it must never escape. A mismatch it signals
       always means "not equal", hence the toplevel [noconv]. *)
    in fun alpha e1 e2 -> noconv (aux alpha) e1 e2

  (* ------------------------------------------------------------------ *)
  (* Left-values: compared variable-wise with [for_pv]; [_alpha] is
     accepted for uniformity with the other tests but unused. *)
  let for_lv env _alpha ~norm lv1 lv2 =
    match lv1, lv2 with
    | LvVar(p1, _), LvVar(p2, _) ->
        for_pv env ~norm p1 p2

    | LvTuple p1, LvTuple p2 ->
        List.all2
          (fun (p1, _) (p2, _) -> for_pv env ~norm p1 p2)
          p1 p2

    | _, _ -> false

  (* ------------------------------------------------------------------ *)
  (* Statements: physical/hash-consed equality first, then
     instruction-wise comparison. *)
  let rec for_stmt env alpha ~norm s1 s2 =
       s_equal s1 s2
    || List.all2 (for_instr env alpha ~norm) s1.s_node s2.s_node

  (* ------------------------------------------------------------------ *)
  and for_instr env alpha ~norm i1 i2 =
    i_equal i1 i2 || for_instr_r env alpha ~norm i1 i2

  and for_instr_r env alpha ~norm i1 i2 =
    match i1.i_node, i2.i_node with
    | Sasgn (lv1, e1), Sasgn (lv2, e2) ->
           for_lv env alpha ~norm lv1 lv2
        && for_expr env alpha ~norm e1 e2

    | Srnd (lv1, e1), Srnd (lv2, e2) ->
           for_lv env alpha ~norm lv1 lv2
        && for_expr env alpha ~norm e1 e2

    | Scall (lv1, f1, e1), Scall (lv2, f2, e2) ->
        oall2 (for_lv env alpha ~norm) lv1 lv2
        && for_xp env ~norm f1 f2
        && List.all2 (for_expr env alpha ~norm) e1 e2

    | Sif (a1, b1, c1), Sif(a2, b2, c2) ->
        for_expr env alpha ~norm a1 a2
        && for_stmt env alpha ~norm b1 b2
        && for_stmt env alpha ~norm c1 c2

    | Swhile(a1,b1), Swhile(a2,b2) ->
        for_expr env alpha ~norm a1 a2
        && for_stmt env alpha ~norm b1 b2

    | Smatch(e1,bs1), Smatch(e2,bs2)
        when List.length bs1 = List.length bs2
      -> begin
        let module E = struct exception NotConv end in

        (* Each branch binds the constructor arguments: check their
           types and extend [alpha] before comparing the bodies. *)
        let check_branch (xs1, s1) (xs2, s2) =
          if List.length xs1 <> List.length xs2 then
            raise E.NotConv;
          let alpha =
            let do1 alpha (id1, ty1) (id2, ty2) =
              if not (for_type env ty1 ty2) then
                raise E.NotConv;
              Mid.add id1 id2 alpha in
            List.fold_left2 do1 alpha xs1 xs2
          in for_stmt env alpha ~norm s1 s2 in

        try
          for_expr env alpha ~norm e1 e2
          && List.all2 (check_branch) bs1 bs2
        with E.NotConv -> false
      end

    | Sassert a1, Sassert a2 ->
        for_expr env alpha ~norm a1 a2

    | Sabstract id1, Sabstract id2 ->
        EcIdent.id_equal id1 id2

    | _, _ -> false

  (* ------------------------------------------------------------------ *)
  (* Public entry points: normalisation is on by default, and the
     statement/expression tests start from an empty renaming. *)
  let for_pv    = fun env ?(norm = true) -> for_pv    env ~norm
  let for_xp    = fun env ?(norm = true) -> for_xp    env ~norm
  let for_mp    = fun env ?(norm = true) -> for_mp    env ~norm
  let for_instr = fun env ?(norm = true) -> for_instr env Mid.empty ~norm
  let for_stmt  = fun env ?(norm = true) -> for_stmt  env Mid.empty ~norm
  let for_expr  = fun env ?(norm = true) -> for_expr  env Mid.empty ~norm
end
(* -------------------------------------------------------------------- *)
(** Which reduction rules [h_red_x] and friends are allowed to use. *)
type reduction_info = {
(* β-reduction of applied λ-abstractions *)
beta : bool;
(* δ-reduction: which global operators may be unfolded *)
delta_p : (path -> bool);
(* δ-reduction: which local definitions (from the hypotheses) may be unfolded *)
delta_h : (ident -> bool);
(* ζ-reduction of [let x = e1 in e2] *)
zeta : bool;
(* ι-reduction: tuple/record lets and projections, if-then-else, match, fixpoints *)
iota : bool;
(* η-reduction of [fun x => f x] *)
eta : bool;
(* simplification of logical / arithmetic operators (see [rlogic_info]) *)
logic : rlogic_info;
(* normalisation of module paths (μ-reduction) *)
modpath : bool;
(* user-declared reduction rules *)
user : bool;
(* reduction of cost expressions ([Fcoe]) *)
cost : bool;
}
(* [`Full] additionally enables simplification of [not]/[=>]/[<=>] and
   the simplifying quantifier constructors; [`ProductCompat] does not;
   [None] disables logical simplification altogether. *)
and rlogic_info = [`Full | `ProductCompat] option
(* -------------------------------------------------------------------- *)
(** Every reduction rule enabled. *)
let full_red = {
  beta    = true;
  zeta    = true;
  iota    = true;
  eta     = true;
  delta_p = EcUtils.predT;
  delta_h = EcUtils.predT;
  logic   = Some `Full;
  modpath = true;
  user    = true;
  cost    = true;
}

(** Every reduction rule disabled. *)
let no_red = {
  beta    = false;
  zeta    = false;
  iota    = false;
  eta     = false;
  delta_p = EcUtils.pred0;
  delta_h = EcUtils.pred0;
  logic   = None;
  modpath = false;
  user    = false;
  cost    = false;
}

(** β-reduction only. *)
let beta_red = { no_red with beta = true }

(** β- and ι-reduction only. *)
let betaiota_red = { beta_red with iota = true }

(** Everything from [full_red] except δ-unfolding: neither global
    operators nor local definitions are unfolded. *)
let nodelta =
  { full_red with delta_p = EcUtils.pred0; delta_h = EcUtils.pred0 }

(** δ-unfolding of global operators only. *)
let delta = { no_red with delta_p = EcUtils.predT }
(** [reduce_local ri hyps x] unfolds the local definition [x] recorded
    in [hyps], provided [ri] allows δ-reduction of [x].
    @raise NotReducible when [ri] forbids it. *)
let reduce_local ri hyps x =
  if not (ri.delta_h x) then raise NotReducible;
  LDecl.unfold x hyps
(** [reduce_op ri env p tys] unfolds the operator [p] instantiated at
    [tys], provided [ri] allows δ-reduction of [p].
    @raise NotReducible when [ri] forbids it. *)
let reduce_op ri env p tys =
  if not (ri.delta_p p) then raise NotReducible;
  Op.reduce env p tys
(** [is_record env f] holds when the head of [f] is a record
    constructor. *)
let is_record env f =
  let head, _ = EcFol.destr_app f in
  match head.f_node with
  | Fop (p, _) -> EcEnv.Op.is_record_ctor env p
  | _ -> false
(* -------------------------------------------------------------------- *)
(** [reduce_match env (f, bs, ty)] reduces [match f with bs] when the
    scrutinee [f] is headed by a datatype constructor: the branch
    selected by the constructor's index is applied to the constructor
    arguments.
    @raise NotReducible when the head of [f] is not a constructor. *)
let reduce_match env (f, bs, ty) =
  let head, args = destr_app f in
  let ctor_index p =
    snd (EcDecl.operator_as_ctor (EcEnv.Op.by_path p env)) in
  match head.f_node with
  | Fop (p, _) when EcEnv.Op.is_dtype_ctor env p ->
      let branch = oget (List.nth_opt bs (ctor_index p)) in
      f_app branch args ty
  | _ -> raise NotReducible
(* -------------------------------------------------------------------- *)
(** Context in which a user-rule pattern is matched: a plain formula,
    the precondition of a cost expression (with its memory), or the cost
    expression itself (with its memory). *)
type mode =
  | UR_Form
  | UR_CostPre of EcMemory.memory
  | UR_CostExpr of EcMemory.memory

let is_UR_CostExpr (m : mode) =
  match m with
  | UR_CostExpr _ -> true
  | UR_Form | UR_CostPre _ -> false

let get_UR_CostExpr (m : mode) =
  match m with
  | UR_CostExpr mem -> mem
  | UR_Form | UR_CostPre _ -> assert false
(* -------------------------------------------------------------------- *)
(* [h_red_x ri env hyps f] performs one head-reduction step on [f],
   using only the rules enabled in [ri]. The dedicated cases below are
   tried in order; when none applies, the generic strategies (logic,
   user rules before δ, δ, user rules after δ, contextual reduction)
   are tried in turn and the first that succeeds wins.
   @raise NotReducible when no enabled rule applies. *)
let rec h_red_x ri env hyps f =
match f.f_node with
(* β-reduction *)
| Fapp ({ f_node = Fquant (Llambda, _, _)}, _) when ri.beta ->
f_betared f
(* δ-reduction of a local definition ([reduce_local] checks
   [ri.delta_h]) *)
| Flocal x -> reduce_local ri hyps x
(* δ-reduction of an applied local definition *)
| Fapp ({ f_node = Flocal x }, args) ->
f_app_simpl (reduce_local ri hyps x) args f.f_ty
(* ζ-reduction *)
| Flet (LSymbol(x,_), e1, e2) when ri.zeta ->
let s = Fsubst.f_bind_local Fsubst.f_subst_id x e1 in
Fsubst.f_subst s e2
(* ι-reduction (let-tuple) *)
| Flet (LTuple ids, { f_node = Ftuple es }, e2) when ri.iota ->
let s =
List.fold_left2
(fun s (x,_) e1 -> Fsubst.f_bind_local s x e1)
Fsubst.f_subst_id ids es
in
Fsubst.f_subst s e2
(* ι-reduction (let-records): the arguments of the record constructor
   are bound to the named fields; anonymous fields ([None]) are skipped *)
| Flet (LRecord (_, ids), f1, f2) when ri.iota && is_record env f1 ->
let args = snd (EcFol.destr_app f1) in
let subst =
List.fold_left2 (fun subst (x, _) e ->
match x with
| None -> subst
| Some x -> Fsubst.f_bind_local subst x e)
Fsubst.f_subst_id ids args
in
Fsubst.f_subst subst f2
(* ι-reduction (records projection) *)
| Fapp ({ f_node = Fop (p, _); } as f1, args)
when ri.iota && EcEnv.Op.is_projection env p -> begin
try
match args with
| mk :: args -> begin
(* The record argument is first reduced by (at most) one step. *)
match (odfl mk (h_red_opt ri env hyps mk)).f_node with
| Fapp ({ f_node = Fop (mkp, _) }, mkargs) ->
if not (EcEnv.Op.is_record_ctor env mkp) then
raise NotReducible;
(* Position of the projected field among the constructor arguments. *)
let v = oget (EcEnv.Op.by_path_opt p env) in
let v = proj3_2 (EcDecl.operator_as_proj v) in
let v = List.nth mkargs v in
f_app (odfl v (h_red_opt ri env hyps v)) args f.f_ty
| _ -> raise NotReducible
end
| _ -> raise NotReducible
with NotReducible ->
(* The projection cannot fire: try reducing the operator itself. *)
f_app (h_red_x ri env hyps f1) args f.f_ty
end
(* ι-reduction (tuples projection); if no progress is made, reduce the
   projected term instead *)
| Fproj(f1, i) when ri.iota ->
let f' = f_proj_simpl f1 i f.f_ty in
if f_equal f f' then f_proj (h_red_x ri env hyps f1) i f.f_ty else f'
(* ι-reduction (if-then-else); if no progress is made, reduce the
   condition instead *)
| Fif (f1, f2, f3) when ri.iota ->
let f' = f_if_simpl f1 f2 f3 in
if f_equal f f' then f_if (h_red_x ri env hyps f1) f2 f3 else f'
(* ι-reduction (match); if no progress is made, reduce the scrutinee *)
| Fmatch (cf, bs, ty) when ri.iota -> begin
try
let f' = reduce_match env (cf, bs, ty) in
if f_equal f f' then raise NotReducible else f'
with NotReducible -> f_match (h_red_x ri env hyps cf) bs ty
end
(* ι-reduction (match-fix): unfold a structurally recursive operator
   once the arguments it recurses on are headed by constructors *)
| Fapp ({ f_node = Fop (p, tys); } as f1, fargs)
when ri.iota && EcEnv.Op.is_fix_def env p -> begin
try
let op = oget (EcEnv.Op.by_path_opt p env) in
let fix = EcDecl.operator_as_fix op in
(* Not enough arguments to perform the reduction. *)
if List.length fargs < snd (fix.EcDecl.opf_struct) then
raise NotReducible;
(* [eargs]: extra arguments beyond the fixpoint's own parameters. *)
let fargs, eargs = List.split_at (snd (fix.EcDecl.opf_struct)) fargs in
let args = Array.of_list fargs in
(* Walk the branch tree following the constructor of each structural
   argument (reduced by at most one step), collecting the constructor
   arguments along the way. *)
let pargs = List.fold_left (fun (opb, acc) v ->
let v = args.(v) in
let v = odfl v (h_red_opt ri env hyps v) in
match fst_map (fun x -> x.f_node) (EcFol.destr_app v) with
| (Fop (p, _), cargs) when EcEnv.Op.is_dtype_ctor env p -> begin
let idx = EcEnv.Op.by_path p env in
let idx = snd (EcDecl.operator_as_ctor idx) in
match opb with
| EcDecl.OPB_Leaf _ -> assert false
| EcDecl.OPB_Branch bs ->
((Parray.get bs idx).EcDecl.opb_sub, cargs :: acc)
end
| _ -> raise NotReducible)
(fix.EcDecl.opf_branches, []) (fst fix.EcDecl.opf_struct)
in
let pargs, (bds, body) =
match pargs with
| EcDecl.OPB_Leaf (bds, body), cargs -> (List.rev cargs, (bds, body))
| _ -> assert false
in
(* Bind the fixpoint parameters, then the constructor arguments. *)
let subst =
List.fold_left2
(fun subst (x, _) fa -> Fsubst.f_bind_local subst x fa)
Fsubst.f_subst_id fix.EcDecl.opf_args fargs in
let subst =
List.fold_left2
(fun subst bds cargs ->
List.fold_left2
(fun subst (x, _) fa -> Fsubst.f_bind_local subst x fa)
subst bds cargs)
subst bds pargs in
let body = EcFol.form_of_expr EcFol.mhr body in
let body =
EcFol.Fsubst.subst_tvar
(EcTypes.Tvar.init (List.map fst op.EcDecl.op_tparams) tys) body in
f_app (Fsubst.f_subst subst body) eargs f.f_ty
with NotReducible ->
f_app (h_red_x ri env hyps f1) fargs f.f_ty
end
(* μ-reduction (normalisation of [glob] module paths) *)
| Fglob (mp, m) when ri.modpath ->
let f' = EcEnv.NormMp.norm_glob env m mp in
if f_equal f f' then raise NotReducible else f'
(* μ-reduction (normalisation of program variables) *)
| Fpvar (pv, m) when ri.modpath ->
let pv' = EcEnv.NormMp.norm_pvar env pv in
if pv_equal pv pv' then raise NotReducible else f_pvar pv' f.f_ty m
(* η-reduction *)
| Fquant (Llambda, [x, GTty _], { f_node = Fapp (fn, args) })
when ri.eta && can_eta x (fn, args)
-> f_app fn (List.take (List.length args - 1) args) f.f_ty
(* Cost of an expression that [EcCHoare.free_expr] deems free.
   NOTE(review): the cost rules in [reduce_logic] build their results
   with [f_x1]/[f_xadd]; confirm that [f_i0] has the expected type here. *)
| Fcoe c when ri.cost && EcCHoare.free_expr c.coe_e -> f_i0
| _ ->
(* Generic strategies, tried in this order. *)
let strategies =
[ reduce_logic;
reduce_user ~mode:`BeforeDelta;
reduce_delta;
reduce_user ~mode:`AfterDelta ;
reduce_context]
in
oget ~exn:NotReducible (List.Exceptionless.find_map
(fun strategy ->
try Some (strategy ri env hyps f) with NotReducible -> None)
strategies)
(* [reduce_logic ri env hyps f] simplifies applications of logical and
   arithmetic operators (when [ri.logic] is set) and reduces cost
   expressions of tuples, constructors, conditionals and projections
   (when [ri.cost] is set).
   @raise NotReducible when no such rule applies. *)
and reduce_logic ri env hyps f =
match f.f_node with
| Fapp ({f_node = Fop (p, tys); } as fo, args)
when is_some ri.logic && is_logical_op p
->
(* [not]/[=>]/[<=>] are only simplified under [`Full]. *)
let pcompat =
match oget ri.logic with `Full -> true | `ProductCompat -> false
in
let f' =
match op_kind p, args with
| Some (`Not), [f1] when pcompat -> f_not_simpl f1
| Some (`Imp), [f1;f2] when pcompat -> f_imp_simpl f1 f2
| Some (`Iff), [f1;f2] when pcompat -> f_iff_simpl f1 f2
| Some (`And `Asym), [f1;f2] -> f_anda_simpl f1 f2
| Some (`Or `Asym), [f1;f2] -> f_ora_simpl f1 f2
| Some (`And `Sym ), [f1;f2] -> f_and_simpl f1 f2
| Some (`Or `Sym ), [f1;f2] -> f_or_simpl f1 f2
| Some (`Int_le ), [f1;f2] -> f_int_le_simpl f1 f2
| Some (`Int_lt ), [f1;f2] -> f_int_lt_simpl f1 f2
| Some (`Real_le ), [f1;f2] -> f_real_le_simpl f1 f2
| Some (`Real_lt ), [f1;f2] -> f_real_lt_simpl f1 f2
| Some (`Int_add ), [f1;f2] -> f_int_add_simpl f1 f2
| Some (`Int_opp ), [f] -> f_int_opp_simpl f
| Some (`Int_mul ), [f1;f2] -> f_int_mul_simpl f1 f2
| Some (`Int_edivz), [f1;f2] -> f_int_edivz_simpl f1 f2
| Some (`Real_add ), [f1;f2] -> f_real_add_simpl f1 f2
| Some (`Real_opp ), [f] -> f_real_opp_simpl f
| Some (`Real_mul ), [f1;f2] -> f_real_mul_simpl f1 f2
| Some (`Real_inv ), [f] -> f_real_inv_simpl f
| Some (`Eq ), [f1;f2] -> begin
match fst_map f_node (destr_app f1), fst_map f_node (destr_app f2) with
(* Equality of two constructor applications: distinct constructors
   give [false], identical ones give argument-wise equalities. *)
| (Fop (p1, _), args1), (Fop (p2, _), args2)
when EcEnv.Op.is_dtype_ctor env p1
&& EcEnv.Op.is_dtype_ctor env p2 ->
let idx p =
let idx = EcEnv.Op.by_path p env in
snd (EcDecl.operator_as_ctor idx)
in
if idx p1 <> idx p2
then f_false
else f_ands (List.map2 f_eq args1 args2)
(* Any two unapplied values of type [unit] are equal. *)
| (_, []), (_, [])
when EqTest.for_type env f1.f_ty EcTypes.tunit
&& EqTest.for_type env f2.f_ty EcTypes.tunit ->
f_true
| _ ->
if f_equal f1 f2 || is_alpha_eq hyps f1 f2
then f_true
else f_eq_simpl f1 f2
end
(* Other logical operators: unfold them when δ allows it. *)
| _ when ri.delta_p p ->
let op = reduce_op ri env p tys in
f_app_simpl op args f.f_ty
| _ -> f
in
(* No progress at the head: reduce the first reducible argument. *)
if f_equal f f'
then f_app fo (h_red_args ri env hyps args) f.f_ty
else f'
(* Cost of a tuple: 1 plus the cost of each component. *)
| Fcoe ({ coe_e = { e_node = Etuple es } } as coe) when ri.cost ->
List.fold_left (fun acc e ->
f_xadd acc (EcCHoare.cost_of_expr coe.coe_pre coe.coe_mem e))
f_x1 es
(* Cost of a constant constructor: 1. *)
| Fcoe ({ coe_e = {e_node = Eop (p, _)}} )
when EcEnv.Op.is_dtype_ctor env p && ri.cost ->
(* FIXME: check the number of arguments *)
f_x1
(* Cost of an applied constructor: 1 plus the cost of each argument. *)
| Fcoe ({ coe_e = { e_node = Eapp ({e_node = Eop (p, _); }, es) }} as coe)
when EcEnv.Op.is_dtype_ctor env p && ri.cost ->
(* FIXME: check the number of arguments *)
List.fold_left (fun acc e ->
f_xadd acc (EcCHoare.cost_of_expr coe.coe_pre coe.coe_mem e))
f_x1 es
| Fcoe ({ coe_e = { e_node = Eif (c,l,r) }} as coe) when ri.cost ->
(* Max upper-bounded by the sum. *)
List.fold_left (fun acc e ->
f_xadd acc (EcCHoare.cost_of_expr coe.coe_pre coe.coe_mem e))
f_x1 [c; l; r]
(* Cost of a tuple projection: 1 plus the cost of the projected term. *)
| Fcoe ({ coe_e = { e_node = Eproj (e,_) }} as coe) when ri.cost ->
f_xadd f_x1 (EcCHoare.cost_of_expr coe.coe_pre coe.coe_mem e)
| _ -> raise NotReducible
(* [reduce_delta ri env _hyps f] unfolds a (possibly applied) global
   operator allowed by [ri.delta_p].
   @raise NotReducible otherwise. *)
and reduce_delta ri env _hyps f =
  let unfold p tys = reduce_op ri env p tys in
  match f.f_node with
  | Fop (p, tys) when ri.delta_p p ->
      unfold p tys
  | Fapp ({ f_node = Fop (p, tys) }, args) when ri.delta_p p ->
      f_app_simpl (unfold p tys) args f.f_ty
  | _ -> raise NotReducible
(* [reduce_context ri env hyps f] performs one reduction step inside
   [f]: in the bound term of a [let], in the head of an application, or
   under a universal/existential binder (whose module bindings are added
   to [env]).
   @raise NotReducible when no such step is possible. *)
and reduce_context ri env hyps f =
  match f.f_node with
  | Flet (lp, f1, f2) ->
      let f1' = h_red_x ri env hyps f1 in
      f_let lp f1' f2

  | Fapp (f1, args) ->
      let f1' = h_red_x ri env hyps f1 in
      f_app f1' args f.f_ty

  | Fquant ((Lforall | Lexists) as t, b, f1) -> begin
      (* Under full logical simplification, use the simplifying
         quantifier constructors. *)
      let full = match ri.logic with Some `Full -> true | _ -> false in
      let ctor =
        match t with
        | Lforall -> if full then f_forall_simpl else f_forall
        | Lexists -> if full then f_exists_simpl else f_exists
        | Llambda -> assert false
      in
      try
        let env = Mod.add_mod_binding b env in
        ctor b (h_red_x ri env hyps f1)
      with NotReducible ->
        (* The body is stuck: only report progress if rebuilding the
           quantifier changed something. *)
        let f' = ctor b f1 in
        if f_equal f f' then raise NotReducible else f'
    end

  | _ -> raise NotReducible
(* [reduce_user_gen mode simplify ri env hyps f] rewrites [f] with the
   first user reduction rule (indexed by the head of [f]) whose priority
   class matches [mode], whose pattern matches [f], and whose side
   conditions all simplify (with [simplify]) to [true].
   - [`BeforeDelta] considers rules of negative priority only;
   - [`AfterDelta] considers rules of non-negative priority only;
   - [`All] considers every rule.
   @raise NotReducible when user rules are disabled or none applies. *)
and reduce_user_gen mode simplify ri env hyps f =
  if not ri.user then raise NotReducible;

  (* Key under which candidate rules are indexed: the head operator, a
     tuple, or, for a cost node, the head of the inner expression. *)
  let p =
    match f_node (fst (destr_app f)) with
    | Fop (p, _) -> `Path p
    | Ftuple _ -> `Tuple
    | _ -> match f.f_node with
      | Fcoe coe ->
          let inner =
            match (fst (EcTypes.destr_app coe.coe_e)).e_node with
            | Eop (p, _) -> `Path p
            | Etuple _ -> `Tuple
            | _ -> raise NotReducible in
          `Cost inner
      | _ -> raise NotReducible in

  let rules = EcEnv.Reduction.get p env in

  let module R = EcTheory in

  (* Whether [rule]'s priority puts it in the class selected by [mode]. *)
  let in_mode rule =
    match mode, rule.R.rl_prio with
    | `AfterDelta , n when n < 0 -> false
    | `BeforeDelta, n when n >= 0 -> false
    | ((`All | `BeforeDelta | `AfterDelta), _) -> true in

  oget ~exn:NotReducible (List.Exceptionless.find_map (fun rule ->
    (* A rule of the other priority class is skipped; the remaining rules
       must still be considered (raising here would abort the search). *)
    if not (in_mode rule) then None else

    try
      let ue = EcUnify.UniEnv.create None in
      let tvi = EcUnify.UniEnv.opentvi ue rule.R.rl_tyd None in

      (* for formula variables *)
      let pv = ref (Mid.empty : form Mid.t) in
      let check_pv x f =
        match Mid.find_opt x !pv with
        | None -> pv := Mid.add x f !pv
        | Some f' -> check_alpha_equal ri hyps f f' in

      (* for expression variables in schemata *)
      let e_pv = ref (Mid.empty : expr Mid.t) in
      let check_e_pv mhr x f =
        try
          match Mid.find_opt x !e_pv with
          | None -> e_pv := Mid.add x (expr_of_form mhr f) !e_pv
          (* must use mhr, c.f. caller of check_e_pv *)
          | Some f' -> check_alpha_equal_e hyps (expr_of_form mhr f) f'
        with CannotTranslate ->
          Format.eprintf "[W]%a@."
            (!EcEnv.pp_debug_form env) f;
          raise CannotTranslate in (* idem *)

      (* for memory pred. variables in schemata *)
      let p_pv = ref (Mid.empty : mem_pr Mid.t) in
      let check_p_pv m x f =
        match Mid.find_opt x !p_pv with
        | None -> p_pv := Mid.add x (m,f) !p_pv
        | Some (m',f') ->
            (* We freshen the memory. *)
            (* FIXME: use inner function of check_alpha_equal *)
            let mf = EcIdent.fresh m in
            let fs = Fsubst.f_bind_mem Fsubst.f_subst_id m mf in
            let fs' = Fsubst.f_bind_mem Fsubst.f_subst_id m' mf in
            let f = Fsubst.f_subst fs f
            and f' = Fsubst.f_subst fs' f' in
            check_alpha_equal ri hyps f f' in

      (* inferred memtype, for schema application *)
      let sc_mt = ref None in

      let rec doit (mode : mode) f ptn =
        match destr_app f, ptn with
        | ({ f_node = Fop (p, tys) }, args), R.Rule (`Op (p', tys'), args')
            when EcPath.p_equal p p' && List.length args = List.length args' ->
            let tys' = List.map (EcTypes.Tvar.subst tvi) tys' in
            begin
              try List.iter2 (EcUnify.unify env ue) tys tys'
              with EcUnify.UnificationFailure _ -> raise NotReducible end;
            List.iter2 (doit mode) args args'

        | ({ f_node = Ftuple args} , []), R.Rule (`Tuple, args')
            when List.length args = List.length args' ->
            List.iter2 (doit mode) args args'

        | ({ f_node = Fint i }, []), R.Int j when EcBigInt.equal i j ->
            ()

        | ({ f_node = Fcoe coe} , []), R.Cost (menv, inner_pre, inner_r) ->
            if not ri.cost then
              raise NotReducible;

            (* Check memtype compatibility. *)
            if EcMemory.is_schema (snd menv) then begin
              if !sc_mt = None then
                sc_mt := Some (snd coe.coe_mem)
              else if not (EcMemory.mt_equal (snd coe.coe_mem) (oget !sc_mt))
              then raise NotReducible
              else () end
            else
              begin match
                  EcMemory.mt_equal_gen (fun ty1 ty2 ->
                      let ty2 = EcTypes.Tvar.subst tvi ty2 in
                      EcUnify.unify env ue ty1 ty2; true
                    ) (snd coe.coe_mem) (snd menv)
                with
                | true -> ()
                | false -> assert false
                | exception (EcUnify.UnificationFailure _) -> raise NotReducible
              end;

            doit (UR_CostPre (fst coe.coe_mem)) coe.coe_pre inner_pre;

            (* use mhr, to be consistent with check_e_pv *)
            let mhr = fst coe.coe_mem in
            let e = form_of_expr mhr coe.coe_e in
            doit (UR_CostExpr mhr) e inner_r

        | _, R.Var x when mode = UR_Form ->
            check_pv x f

        | _, R.Var x when is_UR_CostExpr mode ->
            let mhr = get_UR_CostExpr mode in
            check_e_pv mhr x f

        | _, R.Var x ->
            let m = match mode with
              | UR_CostPre m -> m
              | _ -> assert false in
            (* This case is more annoying. *)
            if List.mem_assoc x rule.rl_vars
            then check_pv x f
            else if List.mem_assoc x rule.rl_evars
            then check_e_pv m x f
            else begin
              assert (List.mem x rule.rl_pvars);
              check_p_pv m x f end

        | _ -> raise NotReducible in

      doit UR_Form f rule.R.rl_ptn;

      if not (EcUnify.UniEnv.closed ue) then
        raise NotReducible;

      let subst f =
        let eus = EcUnify.UniEnv.assubst ue in
        let tysubst = { ty_subst_id with ts_u = eus } in

        if (Mid.is_empty !e_pv) && (Mid.is_empty !p_pv)
        then (* axiom case *)
          let subst = Fsubst.f_subst_init ~sty:tysubst () in
          let subst =
            Mid.fold (fun x f s -> Fsubst.f_bind_local s x f) !pv subst in
          Fsubst.f_subst subst (Fsubst.subst_tvar tvi f)

        else (* schema case, which is more complicated *)
          let typ =
            List.map (fun (a, _) -> Mid.find a tvi) rule.R.rl_tyd in
          let typ = List.map (EcTypes.ty_subst tysubst) typ in

          let es =
            List.map (fun (a, _ty) -> Mid.find a !e_pv) rule.R.rl_evars in

          let mt = oget ~exn:NotReducible !sc_mt in

          let ps = List.map (fun id -> Mid.find id !p_pv) rule.R.rl_pvars in

          let f =
            EcDecl.sc_instantiate
              rule.R.rl_tyd rule.R.rl_pvars rule.R.rl_evars
              typ mt ps es f in

          let subst =
            Mid.fold (fun x f s ->
                Fsubst.f_bind_local s x f
              ) !pv (Fsubst.f_subst_init ()) in
          Fsubst.f_subst subst (Fsubst.subst_tvar tvi f) in

      List.iter (fun cond ->
          if not (f_equal (simplify (subst cond)) f_true) then
            raise NotReducible)
        rule.R.rl_cond;

      Some (subst rule.R.rl_tg)

    (* A failed match means this rule does not apply. This includes a
       non-linear pattern variable whose occurrences are not
       alpha-equal, which [check_alpha_equal(_e)] reports with
       [IncompatibleForm]/[IncompatibleExpr]. *)
    with NotReducible | IncompatibleForm _ | IncompatibleExpr _ -> None)

    rules)
(* [reduce_user ~mode ri env hyps f]: user-rule reduction whose side
   conditions are discharged by [simplify] with the same settings. *)
and reduce_user ~mode ri env hyps f =
  let simplify_cond = simplify ri env hyps in
  reduce_user_gen mode simplify_cond ri env hyps f
(* [can_eta x (f, args)] holds when [fun x => f args] is η-reducible: the
   last argument is the local [x], and [x] occurs free neither in [f] nor
   in the remaining arguments. *)
and can_eta x (f, args) =
  let not_free_in v = not (Mid.mem x v.f_fv) in
  match List.rev args with
  | [] -> false
  | last :: others ->
      begin match last.f_node with
      | Flocal y -> id_equal x y && List.for_all not_free_in (f :: others)
      | _ -> false
      end
(* [h_red_args ri env hyps args] reduces the first reducible argument of
   [args] by one step, leaving the others untouched.
   @raise NotReducible when no argument is reducible. *)
and h_red_args ri env hyps args =
  match args with
  | [] -> raise NotReducible
  | hd :: tl ->
      match h_red_opt ri env hyps hd with
      | Some hd' -> hd' :: tl
      | None -> hd :: h_red_args ri env hyps tl
(* Optional variant of [h_red_x]: [None] when [f] is not reducible. *)
and h_red_opt ri env hyps f =
  match h_red_x ri env hyps f with
  | f' -> Some f'
  | exception NotReducible -> None
(* [check_e ensure env s e1 e2] translates the formula substitution [s]
   into an expression substitution and applies it to [e2]. It then
   passes the result of [EqTest.for_expr] on [e1] and the substituted
   [e2] to [ensure]. *)
and check_e ensure env s e1 e2 =
  let esubst =
    e_subst_init
      s.fs_freshen s.fs_sty.ts_p s.fs_ty Mp.empty s.fs_mp s.fs_esloc
  in
  let e2' = EcTypes.e_subst esubst e2 in
  ensure (EqTest.for_expr env e1 e2')
(* [check_alpha_equal_e hyps e1 e2] checks that two expressions are equal
   (via [EqTest.for_expr]).
   @raise IncompatibleExpr when they are not. *)
and check_alpha_equal_e hyps e1 e2 =
  let env = LDecl.toenv hyps in
  let ensure ok =
    if not ok then raise (IncompatibleExpr (env, (e1, e2))) in
  check_e ensure env Fsubst.f_subst_id e1 e2
(* [check_alpha_equal ri hyps f1 f2] checks that [f1] and [f2] are equal
   up to renaming of bound locals, modules and memories. When a
   structural comparison fails, it retries after one step of reduction
   (with [ri]) on either side, or after η-expansion. Returns [()] on
   success.
   @raise IncompatibleForm when the formulas cannot be matched; other
   exceptions (e.g. [Invalid_argument] from [List.iter2]) may also
   escape. *)
and check_alpha_equal ri hyps f1 f2 =
let env = LDecl.toenv hyps in
(* A single exception value is allocated, so that [aux] below can
   recognise failures of this very check by physical equality. *)
let exn = IncompatibleForm (env, (f1, f2)) in
let error () = raise exn in
let ensure t = if not t then error () in
(* Types: the substitution is applied to the right-hand side only. *)
let check_ty env subst ty1 ty2 =
ensure (EqTest.for_type env ty1 (subst.fs_ty ty2)) in
(* Bound local: check types and rename [x2] into [x1] on the right. *)
let add_local (env, subst) (x1,ty1) (x2,ty2) =
check_ty env subst ty1 ty2;
env,
if id_equal x1 x2 then subst
else Fsubst.f_bind_rename subst x2 x1 ty1 in
(* Let-patterns: only symbols and tuples of equal length are handled. *)
let check_lpattern env subst lp1 lp2 =
match lp1, lp2 with
| LSymbol xt1, LSymbol xt2 -> add_local (env, subst) xt1 xt2
| LTuple lid1, LTuple lid2 when List.length lid1 = List.length lid2 ->
List.fold_left2 add_local (env,subst) lid1 lid2
| _, _ -> error() in
let check_memtype env mt1 mt2 =
ensure (EcMemory.mt_equal_gen (EqTest.for_type env) mt1 mt2) in
(* Free local: the right-hand identifier, after renaming, must be the
   left-hand one. *)
let check_local subst id1 f2 id2 =
match (Mid.find_def f2 id2 subst.fs_loc).f_node with
| Flocal id2 -> ensure (EcIdent.id_equal id1 id2)
| _ -> assert false in
let check_mem subst m1 m2 =
let m2 = Mid.find_def m2 m2 subst.fs_mem in
ensure (EcIdent.id_equal m1 m2) in
(* Program variables / module paths / procedure paths: apply the module
   part of the substitution to the right-hand side, then use [EqTest]. *)
let check_pv env subst pv1 pv2 =
let pv2 = pv_subst (EcPath.x_substm subst.fs_sty.ts_p subst.fs_mp) pv2 in
ensure (EqTest.for_pv env pv1 pv2) in
let check_mp env subst mp1 mp2 =
let mp2 = EcPath.m_subst subst.fs_sty.ts_p subst.fs_mp mp2 in
ensure (EqTest.for_mp env mp1 mp2) in
let check_xp env subst xp1 xp2 =
let xp2 = EcPath.x_substm subst.fs_sty.ts_p subst.fs_mp xp2 in
ensure (EqTest.for_xp env xp1 xp2) in
(* Statements: the substitution is translated into an expression
   substitution and applied to the right-hand statement. *)
let check_s env s s1 s2 =
let es = e_subst_init s.fs_freshen s.fs_sty.ts_p
s.fs_ty Mp.empty s.fs_mp s.fs_esloc in
let s2 = EcModules.s_subst es s2 in
ensure (EqTest.for_stmt env s1 s2) in
(* TODO all declaration in env, do it also in add local *)
(* Binders: check the binder kinds agree, extend the environment with
   bound modules, and rename the right-hand binder into the left one. *)
let rec check_binding (env, subst) (x1,gty1) (x2,gty2) =
let gty2 = Fsubst.subst_gty subst gty2 in
match gty1, gty2 with
| GTty ty1, GTty ty2 ->
ensure (EqTest.for_type env ty1 ty2);
env,
if id_equal x1 x2 then subst else
Fsubst.f_bind_rename subst x2 x1 ty1
| GTmodty p1 , GTmodty p2 ->
let test f1 f2 = aux env subst f1 f2; true in
ensure (ModTy.mod_type_equiv test env p1 p2);
Mod.bind_local x1 p1 env,
if id_equal x1 x2 then subst
else Fsubst.f_bind_mod subst x2 (EcPath.mident x1)
| GTmem me1, GTmem me2 ->
check_memtype env me1 me2;
env,
if id_equal x1 x2 then subst
else Fsubst.f_bind_mem subst x2 x1
| _, _ -> error ()
(* Binder lists of different lengths are reported as a mismatch. *)
and check_bindings env subst bd1 bd2 =
try List.fold_left2 check_binding (env,subst) bd1 bd2
with Invalid_argument _ -> error ()
(* Structural comparison, without reduction. Shortcut: identical
   formulas under the identity substitution. *)
and aux1 env subst f1 f2 =
if Fsubst.is_subst_id subst && f_equal f1 f2 then ()
else match f1.f_node, f2.f_node with
| Fquant(q1,bd1,f1'), Fquant(q2,bd2,f2') when
q1 = q2 && List.length bd1 = List.length bd2 ->
let env, subst = check_bindings env subst bd1 bd2 in
aux env subst f1' f2'
| Fif(a1,b1,c1), Fif(a2,b2,c2) ->
aux env subst a1 a2; aux env subst b1 b2; aux env subst c1 c2
| Fmatch(f1,bs1,ty1), Fmatch(f2,bs2,ty2) -> begin
aux env subst f1 f2;
ensure (EqTest.for_type env ty1 ty2);
try List.iter2 (aux env subst) bs1 bs2
with Invalid_argument _ -> error ()
end
| Flet(p1,f1',g1), Flet(p2,f2',g2) ->
aux env subst f1' f2';
let (env,subst) = check_lpattern env subst p1 p2 in
aux env subst g1 g2
| Fint i1, Fint i2 when EcBigInt.equal i1 i2 -> ()
| Flocal id1, Flocal id2 -> check_local subst id1 f2 id2
| Fpvar(p1,m1), Fpvar(p2,m2) ->
check_mem subst m1 m2;
check_pv env subst p1 p2
| Fglob(p1,m1), Fglob(p2,m2) ->
check_mem subst m1 m2;
check_mp env subst p1 p2
| Fop(p1, ty1), Fop(p2, ty2) when EcPath.p_equal p1 p2 ->
List.iter2 (check_ty env subst) ty1 ty2
| Fapp(f1',args1), Fapp(f2',args2) when
List.length args1 = List.length args2 ->
aux env subst f1' f2';
List.iter2 (aux env subst) args1 args2
| Ftuple args1, Ftuple args2 when List.length args1 = List.length args2 ->
List.iter2 (aux env subst) args1 args2
| Fproj(f1,i1), Fproj(f2,i2) when i1 = i2 ->
aux env subst f1 f2
| FhoareF hf1, FhoareF hf2 ->
check_xp env subst hf1.hf_f hf2.hf_f;
aux env subst hf1.hf_pr hf2.hf_pr;
aux env subst hf1.hf_po hf2.hf_po
| FhoareS hs1, FhoareS hs2 ->
check_s env subst hs1.hs_s hs2.hs_s;
(* FIXME should check the memenv *)
aux env subst hs1.hs_pr hs2.hs_pr;
aux env subst hs1.hs_po hs2.hs_po
| FcHoareF chf1, FcHoareF chf2 ->
check_xp env subst chf1.chf_f chf2.chf_f;
aux env subst chf1.chf_pr chf2.chf_pr;
aux env subst chf1.chf_po chf2.chf_po;
aux_cost env subst chf1.chf_co chf2.chf_co
| FcHoareS chs1, FcHoareS chs2 ->
check_s env subst chs1.chs_s chs2.chs_s;
(* FIXME should check the memenv *)
aux env subst chs1.chs_pr chs2.chs_pr;
aux env subst chs1.chs_po chs2.chs_po;
aux_cost env subst chs1.chs_co chs2.chs_co
| FbdHoareF hf1, FbdHoareF hf2 ->
ensure (hf1.bhf_cmp = hf2.bhf_cmp);
check_xp env subst hf1.bhf_f hf2.bhf_f;
aux env subst hf1.bhf_pr hf2.bhf_pr;
aux env subst hf1.bhf_po hf2.bhf_po;
aux env subst hf1.bhf_bd hf2.bhf_bd
| FbdHoareS hs1, FbdHoareS hs2 ->
ensure (hs1.bhs_cmp = hs2.bhs_cmp);
check_s env subst hs1.bhs_s hs2.bhs_s;
(* FIXME should check the memenv *)
aux env subst hs1.bhs_pr hs2.bhs_pr;
aux env subst hs1.bhs_po hs2.bhs_po;
aux env subst hs1.bhs_bd hs2.bhs_bd
| FequivF ef1, FequivF ef2 ->
check_xp env subst ef1.ef_fl ef2.ef_fl;
check_xp env subst ef1.ef_fr ef2.ef_fr;
aux env subst ef1.ef_pr ef2.ef_pr;
aux env subst ef1.ef_po ef2.ef_po
| FequivS es1, FequivS es2 ->
check_s env subst es1.es_sl es2.es_sl;
check_s env subst es1.es_sr es2.es_sr;
(* FIXME should check the memenv *)
aux env subst es1.es_pr es2.es_pr;
aux env subst es1.es_po es2.es_po
| FeagerF eg1, FeagerF eg2 ->
check_xp env subst eg1.eg_fl eg2.eg_fl;
check_xp env subst eg1.eg_fr eg2.eg_fr;
aux env subst eg1.eg_pr eg2.eg_pr;
aux env subst eg1.eg_po eg2.eg_po;
check_s env subst eg1.eg_sl eg2.eg_sl;
check_s env subst eg1.eg_sr eg2.eg_sr
| Fpr pr1, Fpr pr2 ->
check_mem subst pr1.pr_mem pr2.pr_mem;
check_xp env subst pr1.pr_fun pr2.pr_fun;
aux env subst pr1.pr_args pr2.pr_args;
aux env subst pr1.pr_event pr2.pr_event
| Fcoe coe1, Fcoe coe2 ->
(* The expressions are compared first; the memories are then bound
   (as [GTmem] binders) before comparing the preconditions. *)
check_e ensure env subst coe1.coe_e coe2.coe_e;
let bd1 = fst coe1.coe_mem, GTmem (snd coe1.coe_mem) in
let bd2 = fst coe2.coe_mem, GTmem (snd coe2.coe_mem) in
let env, subst = check_bindings env subst [bd1] [bd2] in
aux env subst coe1.coe_pre coe2.coe_pre;
| _, _ -> error ()
(* Comparison with reduction: when the structural check fails with this
   check's own exception, reduce [f1] (or else [f2]) by one step and
   retry; if neither reduces but the types agree and one side is a
   λ-abstraction, apply both sides to a fresh local (η-expansion). *)
and aux env subst f1 f2 =
try aux1 env subst f1 f2
with e when e == exn ->
match h_red_opt ri env hyps f1 with
| Some f1 -> aux env subst f1 f2
| None ->
match h_red_opt ri env hyps f2 with
| Some f2 -> aux env subst f1 f2
| None when EqTest.for_type env f1.f_ty f2.f_ty -> begin
let ty, codom =
match f1.f_node, f2.f_node with
| Fquant (Llambda, (_, GTty ty) :: bd, f1'), _ ->
ty, toarrow (List.map (gty_as_ty |- snd) bd) f1'.f_ty
| _, Fquant(Llambda, (_, GTty ty) :: bd, f2') ->
ty, toarrow (List.map (gty_as_ty |- snd) bd) f2'.f_ty
| _, _ -> raise e
in
let x = f_local (EcIdent.create "_") ty in
let f1 = f_app_simpl f1 [x] codom in
let f2 = f_app_simpl f2 [x] codom in
aux env subst f1 f2
end
| _ -> raise e
(* Costs: procedure keys of both call maps are normalised (the right-hand
   ones after applying the module substitution). Then the self costs and
   the per-procedure entries are compared; a procedure present on one
   side only is a mismatch. *)
and aux_cost env subst co1 co2 =
let calls1 =
EcPath.Mx.fold (fun f c calls ->
let f' = NormMp.norm_xfun env f in
EcPath.Mx.change (fun old -> assert (old = None); Some c) f' calls
) co1.c_calls EcPath.Mx.empty
and calls2 =
EcPath.Mx.fold (fun f c calls ->
let f' = EcPath.x_substm subst.fs_sty.ts_p subst.fs_mp f in
let f' = NormMp.norm_xfun env f' in
EcPath.Mx.change (fun old -> assert (old = None); Some c) f' calls
) co2.c_calls EcPath.Mx.empty in
aux env subst co1.c_self co2.c_self;
EcPath.Mx.fold2_union (fun _ a1 a2 _ -> match a1,a2 with
| None, None -> assert false
| None, Some _ | Some _, None -> error ()
| Some cb1, Some cb2 ->
aux env subst cb1.cb_cost cb2.cb_cost;
aux env subst cb1.cb_called cb2.cb_called
) calls1 calls2 ()
in aux env Fsubst.f_subst_id f1 f2
(* Alpha-equality without any reduction. *)
and check_alpha_eq hyps f1 f2 = check_alpha_equal no_red hyps f1 f2
(* Convertibility: alpha-equality modulo full reduction. *)
and check_conv hyps f1 f2 = check_alpha_equal full_red hyps f1 f2
(* Boolean variant of [check_alpha_eq]: any failure yields [false]. *)
and is_alpha_eq hyps f1 f2 =
  match check_alpha_eq hyps f1 f2 with
  | () -> true
  | exception _ -> false
(* [simplify ri env hyps f] head-reduces [f] while reduction makes
   progress (detected by physical inequality), then simplifies the
   sub-terms with [simplify_rec]. *)
and simplify ri env hyps f =
  match h_red_opt ri env hyps f with
  | Some f' when f' != f -> simplify ri env hyps f'
  | Some _ | None -> simplify_rec ri env hyps f
(* Simplify the direct sub-formulas of [f] (through [simplify]).
   - For an operator application, the arguments are simplified, the
     application is rebuilt with [FSmart.f_app], and one more head
     reduction step is attempted on the result.
   - When [ri.modpath] is set, the procedure paths of [FhoareF],
     [FcHoareF], [FbdHoareF], [FequivF], [FeagerF] and [Fpr] nodes are
     first normalized with [NormMp.norm_xfun] (in the environment of
     [hyps]) before descending. *)
and simplify_rec ri env hyps f =
  (* Simplify every direct sub-formula, leaving types untouched. *)
  let descend g = f_map (fun ty -> ty) (simplify ri env hyps) g in
  (* Normalize a procedure path in the environment of [hyps]. *)
  let norm xp = EcEnv.NormMp.norm_xfun (LDecl.toenv hyps) xp in
  match f.f_node with
  | Fapp ({ f_node = Fop _ } as head, args) ->
      let sargs = List.map (simplify ri env hyps) args in
      let rebuilt =
        EcFol.FSmart.f_app (f, (head, args, f.f_ty)) (head, sargs, f.f_ty) in
      begin
        try h_red_x ri env hyps rebuilt
        with NotReducible -> rebuilt
      end

  | FhoareF hoare when ri.modpath ->
      let hf_f = norm hoare.hf_f in
      descend (f_hoareF_r { hoare with hf_f })

  | FcHoareF choare when ri.modpath ->
      let chf_f = norm choare.chf_f in
      descend (f_cHoareF_r { choare with chf_f })

  | FbdHoareF bdhoare when ri.modpath ->
      let bhf_f = norm bdhoare.bhf_f in
      descend (f_bdHoareF_r { bdhoare with bhf_f })

  | FequivF equiv when ri.modpath ->
      let ef_fl = norm equiv.ef_fl in
      let ef_fr = norm equiv.ef_fr in
      descend (f_equivF_r { equiv with ef_fl; ef_fr; })

  | FeagerF eager when ri.modpath ->
      let eg_fl = norm eager.eg_fl in
      let eg_fr = norm eager.eg_fr in
      descend (f_eagerF_r { eager with eg_fl; eg_fr; })

  | Fpr prob when ri.modpath ->
      let pr_fun = norm prob.pr_fun in
      descend (f_pr_r { prob with pr_fun })

  | _ -> descend f
(* -------------------------------------------------------------------- *)
(* [true] iff [check_conv hyps f1 f2] completes; any exception it raises
   is mapped to [false]. *)
let is_conv hyps f1 f2 =
  match check_conv hyps f1 f2 with
  | () -> true
  | exception _ -> false
(* [h_red_x] with the environment taken from [hyps]. *)
let h_red ri hyps f =
  let env = LDecl.toenv hyps in
  h_red_x ri env hyps f
(* The previous [h_red_opt] with the environment taken from [hyps]
   (this definition is not recursive: it shadows the 4-argument one). *)
let h_red_opt ri hyps f =
  let env = LDecl.toenv hyps in
  h_red_opt ri env hyps f
(* The previous [simplify] with the environment taken from [hyps]
   (non-recursive: it shadows the 4-argument one). *)
let simplify ri hyps f =
  let env = LDecl.toenv hyps in
  simplify ri env hyps f
(* -------------------------------------------------------------------- *)
type xconv = [`Eq | `AlphaEq | `Conv]

(* [xconv mode hyps] selects a boolean form comparison:
   - [`Conv]    : [is_conv hyps]     (check with [full_red]);
   - [`AlphaEq] : [is_alpha_eq hyps] (check with [no_red]);
   - [`Eq]      : [f_equal], which does not use [hyps]. *)
let xconv (mode : xconv) hyps =
  match mode with
  | `Conv    -> is_conv hyps
  | `AlphaEq -> is_alpha_eq hyps
  | `Eq      -> f_equal
(* -------------------------------------------------------------------- *)
module User = struct
  (* Options controlling how a rule is compiled (see [compile ~opts]). *)
  type options = EcTheory.rule_option

  (* Reasons for rejecting an axiom or schema as a user reduction rule. *)
  type error =
    | MissingVarInLhs   of EcIdent.t (* bound variable absent from the lhs *)
    | MissingEVarInLhs  of EcIdent.t (* schema [axs_params] variable absent from the lhs *)
    | MissingTyVarInLhs of EcIdent.t (* type parameter absent from the lhs *)
    | MissingPVarInLhs  of EcIdent.t (* schema [axs_pparams] variable absent from the lhs *)
    | NotAnEq                        (* conclusion is neither an equality nor boolean *)
    | NotFirstOrder                  (* lhs is not a first-order pattern *)
    | RuleDependsOnMemOrModule       (* a universal binder is not a [GTty] binder *)
    | HeadedByVar                    (* lhs is a bare variable *)

  exception InvalidUserRule of error

  module R = EcTheory

  type rule = EcEnv.Reduction.rule

  (* Statement of an axiom ([`Ax]) or schema ([`Sc]). *)
  let get_spec = function
    | `Ax ax -> ax.EcDecl.ax_spec
    | `Sc sc -> sc.EcDecl.axs_spec

  (* Type parameters of an axiom ([`Ax]) or schema ([`Sc]). *)
  let get_typ = function
    | `Ax ax -> ax.EcDecl.ax_tparams
    | `Sc sc -> sc.EcDecl.axs_tparams

  (* Variables occurring in a rule pattern, sorted by the position in
     which they occur. *)
  type compile_st = { cst_ty_vs        : Sid.t;   (* type vars of operator instances *)
                      cst_f_vs         : Sid.t;   (* vars in formula position *)
                      cst_cost_pre_vs  : Sid.t;   (* vars in cost preconditions *)
                      cst_cost_expr_vs : Sid.t; } (* vars in cost expressions *)

  let empty_cst = { cst_ty_vs        = Sid.empty;
                    cst_f_vs         = Sid.empty;
                    cst_cost_pre_vs  = Sid.empty;
                    cst_cost_expr_vs = Sid.empty; }

  (* [compile ~opts ~prio env mode p] turns the axiom ([mode = `Ax]) or
     schema ([mode = `Sc]) found at path [p] in [env] into a user
     reduction rule with priority [prio].

     After stripping its universal binders, the statement must have the
     shape [c1 => ... => cn => concl], where [concl] is an equality
     [lhs = rhs] or a boolean formula [b] (compiled as [b = true]).
     When [opts.ur_delta] is set, one [delta] head-reduction step is
     tried on the statement and at each step of this decomposition.

     @raise InvalidUserRule if a binder is not a plain typed variable,
       the conclusion is neither an equality nor boolean, the lhs is not
       first-order, a bound variable / type parameter / schema parameter
       does not occur in the lhs, or the lhs is a bare variable. *)
  let compile ~opts ~prio (env : EcEnv.env) mode p =
    (* One optional [delta] head-reduction step; identity otherwise. *)
    let simp =
      if opts.EcTheory.ur_delta then
        let hyps = EcEnv.LDecl.init env [] in
        fun f -> odfl f (h_red_opt delta hyps f)
      else fun f -> f in

    let ax_sc =
      match mode with
      | `Ax -> `Ax (EcEnv.Ax.by_path p env)
      | `Sc -> `Sc (EcEnv.Schema.by_path p env) in

    let bds, rl = EcFol.decompose_forall (simp (get_spec ax_sc)) in

    (* Only [GTty] binders are supported. *)
    let bds =
      let filter = function
        | (x, GTty ty) -> (x, ty)
        | _ -> raise (InvalidUserRule RuleDependsOnMemOrModule)
      in List.map filter bds in

    (* Schema-only parameters ([axs_pparams], [axs_params]); axioms
       have none. *)
    let pbds, ebds =
      match ax_sc with
      | `Ax _  -> [], []
      | `Sc sc -> sc.EcDecl.axs_pparams, sc.EcDecl.axs_params in

    (* Split [c1 => ... => cn => concl] into [(lhs, rhs, [c1; ...; cn])].
       A boolean, non-equality conclusion [f] yields [(f, true, ...)];
       note that the un-[simp]ed [f] is kept in that case.

       NOTE(review): since that boolean case already accepts any
       boolean conclusion regardless of [ur_eqtrue], the [ur_eqtrue]
       handler below looks unreachable for well-typed statements —
       confirm intent. *)
    let lhs, rhs, conds =
      try
        let rec split conds f =
          match sform_of_form (simp f) with
          | SFimp (f1, f2) -> split (f1 :: conds) f2
          | SFeq  (f1, f2) -> (f1, f2, List.rev conds)
          | _ when ty_equal tbool (EcEnv.ty_hnorm f.f_ty env) ->
              (f, f_true, List.rev conds)
          | _ -> raise (InvalidUserRule NotAnEq)
        in split [] rl
      with InvalidUserRule NotAnEq
             when opts.EcTheory.ur_eqtrue
               && ty_equal tbool (EcEnv.ty_hnorm rl.f_ty env) ->
        (rl, f_true, [])
    in

    (* Compile the lhs into a first-order rule pattern. *)
    let rule =
      let rec rule (f : form) : EcTheory.rule_pattern =
        match EcFol.destr_app f with
        | { f_node = Fop (p, tys) }, args ->
            R.Rule (`Op (p, tys), List.map rule args)
        | { f_node = Ftuple args }, [] ->
            R.Rule (`Tuple, List.map rule args)
        | { f_node = Fint i }, [] ->
            R.Int i
        | { f_node = Flocal x }, [] ->
            R.Var x
        | { f_node = Fcoe coe }, [] ->
            let inner_e   = e_rule coe.coe_e in
            let inner_pre = rule coe.coe_pre in
            R.Cost (coe.coe_mem, inner_pre, inner_e)
        | _ -> raise (InvalidUserRule NotFirstOrder)

      and e_rule (e : expr) =
        (* The chosen memory does not matter here (we pick [mhr] by default). *)
        rule (form_of_expr mhr e)
      in rule lhs in

    (* Collect the variables of the pattern, sorted by where they occur. *)
    let cst =
      let rec collect ~cmode cst = function
        | R.Var x -> begin
            (* Add the variable to the set matching the current mode. *)
            match cmode with
            | UR_Form ->
                { cst with cst_f_vs = Sid.add x cst.cst_f_vs }
            | UR_CostPre _ ->
                { cst with cst_cost_pre_vs = Sid.add x cst.cst_cost_pre_vs }
            | UR_CostExpr _ ->
                { cst with cst_cost_expr_vs = Sid.add x cst.cst_cost_expr_vs }
          end

        | R.Int _ -> cst

        | R.Rule (op, args) ->
            let ltyvars =
              match op with
              | `Op (_, tys) ->
                  let rec tyvars acc = function
                    | { ty_node = Tvar a } -> Sid.add a acc
                    | ty -> ty_fold tyvars acc ty in
                  List.fold_left tyvars cst.cst_ty_vs tys
              | `Tuple -> cst.cst_ty_vs in
            let cst = { cst with cst_ty_vs = ltyvars } in
            List.fold_left (collect ~cmode) cst args

        | R.Cost (menv, pre, expr) ->
            let mhr = fst menv in
            let cst = collect ~cmode:(UR_CostExpr mhr) cst expr in
            collect ~cmode:(UR_CostPre mhr) cst pre
      in collect ~cmode:UR_Form empty_cst rule in

    let s_bds   = Sid.of_list (List.map fst bds)
    and s_ebds  = Sid.of_list (List.map fst ebds)
    and s_pbds  = Sid.of_list pbds
    and s_tybds = Sid.of_list (List.map fst (get_typ ax_sc)) in

    (* Variables appearing in types, cost expressions and formulas are
       always, respectively, type, expression and formula variables. *)
    let lvars   = cst.cst_f_vs
    and ltyvars = cst.cst_ty_vs
    and levars  = cst.cst_cost_expr_vs
    and lpvars  = Sid.empty in

    (* Variables appearing in cost preconditions can be anything: classify
       them by the binder set they belong to, defaulting to formula
       variables. *)
    let lvars, levars, lpvars =
      Sid.fold (fun id (lvars, levars, lpvars) ->
          if Sid.mem id s_ebds
          then (lvars, Sid.add id levars, lpvars)
          else if Sid.mem id s_pbds
          then (lvars, levars, Sid.add id lpvars)
          else (Sid.add id lvars, levars, lpvars)
        ) cst.cst_cost_pre_vs (lvars, levars, lpvars) in

    (* Sanity check: the four classes are pairwise disjoint. *)
    assert (Sid.disjoint lvars   ltyvars &&
            Sid.disjoint lvars   levars  &&
            Sid.disjoint lvars   lpvars  &&
            Sid.disjoint ltyvars levars  &&
            Sid.disjoint ltyvars lpvars  &&
            Sid.disjoint levars  lpvars);

    (* We check that the bound variables all appear in the lhs.
       This ensures that, when applying the rule, we can infer how to
       instantiate the axiom or schema by matching with the lhs.
       Checks are performed (and reported) in this fixed order. *)
    let check_covered bound seen mk =
      let missing = Sid.diff bound seen in
      if not (Sid.is_empty missing) then
        raise (InvalidUserRule (mk (Sid.choose missing))) in

    check_covered s_bds   lvars   (fun x -> MissingVarInLhs x);
    check_covered s_ebds  levars  (fun x -> MissingEVarInLhs x);
    check_covered s_tybds ltyvars (fun x -> MissingTyVarInLhs x);
    check_covered s_pbds  lpvars  (fun x -> MissingPVarInLhs x);

    (* Reject a lhs that is a bare variable. *)
    (match rule with
     | R.Var _ -> raise (InvalidUserRule HeadedByVar)
     | _ -> ());

    R.{ rl_tyd   = get_typ ax_sc;
        rl_vars  = bds;
        rl_evars = ebds;
        rl_pvars = pbds;
        rl_cond  = conds;
        rl_ptn   = rule;
        rl_tg    = rhs;
        rl_prio  = prio; }
end