(* example_transformer.v *)
(* monae: Monadic equational reasoning in Coq *)
(* Copyright (C) 2020 monae authors, license: LGPL-2.1-or-later *)
From mathcomp Require Import all_ssreflect.
Require Import monae_lib hierarchy monad_lib fail_lib state_lib.
Require Import monad_transformer.
(******************************************************************************)
(* Examples of programs using monad transformers *)
(******************************************************************************)
Set Implicit Arguments.
Unset Strict Implicit.
Unset Printing Implicit Defensive.
Local Open Scope monae_scope.
(******************************************************************************)
(* reference: *)
(* - R. Affeldt, D. Nowak, Extending Equational Monadic Reasoning with Monad *)
(* Transformers, https://arxiv.org/abs/2011.03463 *)
(******************************************************************************)
(* evalStateT m s: run the computation m with RunStateT from the initial
   state s, then keep only the first component (x.1) of the value that
   RunStateT delivers in the underlying monad N.
   NOTE(review): fastProductCorrect unfolds this definition and rewrites with
   bindretf against its exact shape, so the body is deliberately left as is. *)
Definition evalStateT (N : monad) (S : UU0) (M : stateRunMonad S N)
{A : UU0} (m : M A) (s : S) : N A :=
RunStateT m s >>= fun x => Ret x.1.
Section FastProduct.
(* M is the monad in which the programs of this section use Get, Put (on a
   state of type nat), Fail and Catch; N is the monad in which evalStateT
   delivers its result. *)
Variables (N : exceptMonad) (M : exceptStateRunMonad nat N).
(* fastProductRec l: traverse l while multiplying the state by each element.
   - empty list: return tt;
   - head equal to 0: Fail immediately, without looking at the tail;
   - head n.+1: read the state m, write m * n.+1, then recurse on the tail.
   NOTE(review): the proof of fastProductCorrect rewrites against this exact
   shape (putget, putput), so the body is deliberately left as is. *)
Fixpoint fastProductRec l : M unit :=
match l with
| [::] => Ret tt
| 0 :: _ => Fail
| n.+1 :: l' => Get >>= fun m => Put (m * n.+1) >> fastProductRec l'
end.
(* fastProduct l: set the state to 1, run fastProductRec l, and return the
   final state with Get; the whole computation is wrapped in Catch with the
   handler Ret 0, so the Fail raised on a 0 element is turned into the
   result 0. *)
Definition fastProduct l : M _ :=
Catch (Put 1 >> fastProductRec l >> Get) (Ret 0 : M _).
(* Correctness of fastProduct: evaluating it from any initial state n returns
   product l in N.
   NOTE(review): product is not defined in this file; presumably it is the
   product of the elements of l, coming from an imported library -- confirm. *)
Lemma fastProductCorrect l n :
evalStateT (fastProduct l) n = Ret (product l).
Proof.
(* Rewrite product l as 1 * product l, then generalize the literal 1 so that
   the induction below is stated for an arbitrary accumulator m. *)
rewrite /fastProduct -(mul1n (product _)); move: 1.
(* Induction on l; a non-empty list is further split on whether its head is
   0 or a successor x.+1. *)
elim: l => [ | [ | x] l ih] m.
(* Empty list: Put m followed by Get returns m, and nothing is caught. *)
- rewrite muln1 bindA bindretf putget.
rewrite /evalStateT RunStateTCatch RunStateTBind RunStateTPut bindretf.
by rewrite RunStateTRet RunStateTRet catchret bindretf.
(* Head is 0: the right-hand side collapses with muln0; on the left, the
   Fail raised by fastProductRec reaches Catch, whose handler returns 0. *)
- rewrite muln0.
rewrite /evalStateT RunStateTCatch RunStateTBind RunStateTBind RunStateTPut.
by rewrite bindretf RunStateTFail bindfailf catchfailm RunStateTRet bindretf.
(* Head is x.+1: unfold one step of fastProductRec, fuse the two successive
   Put with putput, then conclude with the induction hypothesis ih and
   associativity of multiplication (mulnA). *)
- rewrite [fastProductRec _]/=.
by rewrite -bindA putget bindA bindA bindretf -bindA -bindA putput ih mulnA.
Qed.
End FastProduct.
(* The following fail-state monad does not backtrack the state: a failing
computation still returns the state it has reached. *)
Module PersistentState.
Section persistentstate.
(* S is the type of states. *)
Variable S : UU0.
(* failPState A: a computation takes a state and returns an optional result
   paired with a state; None is the result used by fail in this section. *)
Definition failPState (A : UU0) : UU0 :=
S -> option A * S.
(* runFailPState m s: run the computation m from the initial state s; this is
   just the application of m to s. *)
Definition runFailPState {A : UU0} (m : failPState A) (s : S) : option A * S :=
m s.
(* ret a: succeed with the value a and leave the state unchanged. *)
Definition ret {A : UU0} (a : A) : failPState A :=
fun s => (Some a, s).
(* bind m f: run m from the incoming state. If m yields no value, the result
   is None together with the state that m reached; otherwise f is applied to
   the value and run from the state that m reached. *)
Definition bind {A B : UU0} (m : failPState A) (f : A -> failPState B) :
  failPState B :=
  fun s0 =>
    let (r, s1) := m s0 in
    match r with
    | Some a => f a s1
    | None => (None, s1)
    end.
(* Left unit law: binding ret a to f is the same as applying f to a. *)
Lemma bindretf (A B : UU0) (a : A) (f : A -> failPState B) :
bind (ret a) f = f a.
Proof.
(* Both sides are equal by computation. *)
reflexivity.
Qed.
(* Right unit law: binding m to ret gives back m. *)
Lemma bindmret (A : UU0) (m : failPState A) :
bind m ret = m.
Proof.
(* Compare the two functions pointwise, at an arbitrary state s. *)
rewrite boolp.funeqE => s.
unfold bind.
(* Split m s into its optional value (Some or None) and its state; each case
   holds by computation. *)
destruct (m s) as [[|]]; reflexivity.
Qed.
(* Associativity law of bind. *)
Lemma bindA
(A B C : UU0) (m : failPState A)
(f : A -> failPState B) (g : B -> failPState C) :
bind (bind m f) g = bind m (fun x => bind (f x) g).
Proof.
(* Compare the two functions pointwise, at an arbitrary state s. *)
rewrite boolp.funeqE => s.
unfold bind.
(* Split m s into its optional value (Some or None) and its state; each case
   holds by computation. *)
destruct (m s) as [[|]]; reflexivity.
Qed.
(* fail: yield no value (None) and leave the state unchanged. *)
Definition fail {A : UU0} : failPState A := fun s => (None, s).
(* catch m1 m2: run m1 from the incoming state. If m1 yields a value, return
   it with the state that m1 reached; if m1 yields None, run m2 from the
   state that m1 reached (not from the incoming state). *)
Definition catch {A : UU0} (m1 m2 : failPState A) :=
  fun s0 =>
    let (r, s1) := m1 s0 in
    match r with
    | Some a => (Some a, s1)
    | None => m2 s1
    end.
(* get: return the current state as the value; the state is unchanged. *)
Definition get : failPState S := fun s => (Some s, s).
(* put s: ignore the current state, install s instead, and return tt. *)
Definition put (s : S) : failPState unit := fun _ => (Some tt, s).
End persistentstate.
(* Implicit-argument declarations for the operations, stated after the
   section has been closed: each {_} marks an argument as implicit and each
   bare _ leaves it explicit. *)
Arguments runFailPState {_ _} _ _.
Arguments ret {_ _} _.
Arguments bind {_ _ _} _ _.
Arguments fail {_ _}.
Arguments get {_}.
Arguments put {_} _.
(* m >>> f: run m, discard its value, then run f (bind with a continuation
   that ignores its argument). *)
Local Notation "m >>> f" := (bind m (fun _ => f)) (at level 49).
(* The following example illustrates that the state is NOT backtracked when a
failure is caught. *)
(* Starting from state 0: write 1, then catch the failure of (write 2, then
   fail) with the handler get. The handler observes 2, the state written just
   before the failure, and not 1, the state in force when catch was entered. *)
Goal
runFailPState (put 1 >>> catch (put 2 >>> fail) get) 0 =
(Some 2, 2).
Proof.
(* Holds by computation. *)
reflexivity.
Qed.
End PersistentState.
Section incr_fail_incr.
(* Experiment: the monad is given abstractly by the failStateReifyMonad
   interface, with a state of type nat. *)
Section with_failStateReifyMonad.
Variables M : failStateReifyMonad nat.
(* incr: read the state and write back its successor. *)
Let incr : M unit := Get >>= (Put \o succn).
(* prog B: increment, fail (at type B), then increment again. *)
Let prog (B : UU0) : M unit := incr >> @Fail _ B >> incr.
(* The whole program equals Fail. The script ends with Abort, so the statement
   is not recorded as a lemma. *)
Goal forall T, prog T = @Fail _ _.
Proof.
move=> T; rewrite /prog.
(* Reassociate to incr >> (Fail >> incr). *)
rewrite bindA.
(* Fail >> incr collapses to Fail. *)
rewrite bindfailf.
(* incr >> Fail collapses to Fail. *)
by rewrite bindmfail.
Abort.
End with_failStateReifyMonad.
(* Experiment: the monad is stateT nat N for an arbitrary failMonad N, and
   the failure is the Fail of N lifted through stateT nat. *)
Section with_stateT_of_failMonad.
Variable N : failMonad.
Let M : monad := stateT nat N.
(* incr: read the state and write back its successor. *)
Let incr : M unit := Get >>= (Put \o succn).
(* prog T: increment, perform the lifted failure (at type T), then increment
   again. *)
Let prog T : M unit := incr >> Lift (stateT nat) N T Fail >> incr.
(* Attempt to show that the program equals the lifted failure. The script is
   abandoned with Abort after the two rewrites below. *)
Goal forall T, prog T = Lift (stateT nat) N unit Fail.
Proof.
move=> T; rewrite /prog.
(* Reassociate to incr >> (lifted Fail >> incr). *)
rewrite bindA.
rewrite bindLfailf. (* fail laws are not sufficient *)
Abort.
End with_stateT_of_failMonad.
(* Experiment: the monad is exceptT unit N for an arbitrary stateMonad N on
   nat; here the state operations are the lifted ones. *)
Section with_exceptT_of_stateMonad.
(* LGet M: the Get of M lifted through exceptT unit. *)
Definition LGet S (M : stateMonad S) := Lift (exceptT unit) M S (@Get S M).
(* LPut M s: the Put of M applied to s, lifted through exceptT unit. *)
Definition LPut S (M : stateMonad S) := Lift (exceptT unit) M unit \o (@Put S M).
Variable N : stateMonad nat.
Let M : monad := exceptT unit N.
(* incr: read the state with LGet and write back its successor with LPut. *)
Let incr : M unit := LGet N >>= (LPut N \o succn).
(* prog T: increment, fail (at type T), then increment again. *)
Let prog T : M unit := incr >> (Fail : _ T) >> incr.
(* Statement that the program equals Fail. The script only unfolds prog and
   is then abandoned with Abort. *)
Goal forall T, prog T = @Fail _ unit.
Proof.
move=> T; rewrite /prog.
Abort.
End with_exceptT_of_stateMonad.
End incr_fail_incr.
Require Import monad_model.
Section incr_fail_incr_model.
(* In the option model (M := ModelMonad.option_monad): sequencing any
   stateT S M computation m before the lifted failure gives the lifted failure
   itself. FAIL abbreviates @ExceptOps.throw unit T tt. *)
Lemma bindLmfail (M := ModelMonad.option_monad) S T U (m : stateT S M U)
(FAIL := @ExceptOps.throw unit T tt) :
m >> Lift (stateT S) M T FAIL = Lift (stateT S) M T FAIL.
Proof.
(* Expose the lifting as liftS and compare both sides pointwise, at an
   arbitrary state s. *)
rewrite -!liftSE /liftS boolp.funeqE => s.
(* Unfold, layer by layer, the bind and map of the state transformer down to
   the bind of the model (ModelMonad.Except.bind). The order of these steps
   matters: each one exposes the constant unfolded by the next. *)
rewrite /Bind /=.
rewrite /bindS /=.
rewrite /stateTmonad /=.
rewrite /Monad_of_ret_bind /=.
rewrite /Actm /=.
rewrite /Monad_of_ret_bind.Map /=.
rewrite /bindS /retS /=.
rewrite /Bind /=.
rewrite /ModelMonad.Except.bind /= /Actm /=.
rewrite /Monad_of_ret_bind.Map /=.
rewrite /ModelMonad.Except.bind /=.
(* Case analysis on the result of m s and on its content; every case is
   closed by computation. *)
by case: (m s); case.
Qed.
(* Experiment: same program as in with_stateT_of_failMonad, but with N fixed
   to the concrete model ModelFail.option, so that bindLmfail applies. *)
Section fail_model_sufficient.
Let N : failMonad := ModelFail.option.
Let M : monad := stateT nat N.
(* FAIL T: the failure of the model, written as @ExceptOps.throw unit T tt. *)
Let FAIL T := @ExceptOps.throw unit T tt.
(* incr: read the state and write back its successor. *)
Let incr : M unit := Get >>= (Put \o succn).
(* prog T: increment, perform the lifted failure (at type T), then increment
   again. *)
Let prog T : M unit := incr >> Lift (stateT nat) N T (@FAIL T) >> incr.
(* The program equals the lifted failure. The script ends with Abort, so the
   statement is not recorded as a lemma. *)
Goal forall T, prog T = Lift (stateT nat) N unit (@FAIL unit).
Proof.
move=> T; rewrite /prog.
(* incr >> lifted FAIL collapses to lifted FAIL (model-specific lemma). *)
rewrite bindLmfail.
(* lifted FAIL >> incr collapses to lifted FAIL. *)
by rewrite bindLfailf.
Abort.
End fail_model_sufficient.
End incr_fail_incr_model.