Raw File
example_relabeling.v
(* monae: Monadic equational reasoning in Coq                                 *)
(* Copyright (C) 2020 monae authors, license: LGPL-2.1-or-later               *)
Require Import ZArith.
From mathcomp Require Import all_ssreflect.
From mathcomp Require boolp.
From infotheo Require Import ssrZ.
Require Import monae_lib hierarchy monad_lib fail_lib state_lib.

(******************************************************************************)
(*                           Tree relabeling                                  *)
(*                                                                            *)
(* see Sect. 9 of J. Gibbons, R. Hinze, Just do it: simple monadic equational *)
(* reasoning, ICFP 2011                                                       *)
(******************************************************************************)

Set Implicit Arguments.
Unset Strict Implicit.
Unset Printing Implicit Defensive.

Local Open Scope monae_scope.

Section Tree.
Variable A : Type.

Inductive Tree := Tip (a : A) | Bin of Tree & Tree.

Fixpoint foldt B (f : A -> B) (g : B * B -> B) (t : Tree) : B :=
  match t with
  | Tip a => f a
  | Bin t u => g (foldt f g t, foldt f g u)
  end.

(* TODO: move? *)
Section foldt_universal.
Variables (B : Type) (h : Tree -> B) (f : A -> B) (g : B * B -> B).
Hypothesis H1 : h \o Tip = f.
Hypothesis H2 : h \o uncurry Bin = g \o (fun x => (h x.1, h x.2)).
Lemma foldt_universal : h = foldt f g.
Proof.
rewrite boolp.funeqE; elim => [a|]; first by rewrite -H1.
by move=> t1 IH1 t2 IH2 /=; rewrite -IH1 -IH2 -(uncurryE Bin) -compE H2.
Qed.
End foldt_universal.

Definition size_Tree (t : Tree) := foldt (const 1) uaddn t.

Lemma size_Tree_Bin :
  size_Tree \o uncurry Bin = uaddn \o size_Tree^`2.
Proof. by rewrite boolp.funeqE; case. Qed.

Fixpoint labels (t : Tree) : seq A :=
  match t with
  | Tip a => [:: a]
  | Bin t u => labels t ++ labels u
  end.

End Tree.
Arguments Tip {A}.
Arguments Bin {A}.

Section tree_relabelling.

Variable Symbol : eqType. (* TODO: ideally, we would like a generic type here with a succ function *)
Variable M : failFreshMonad Symbol.
Variable q : pred (seq Symbol * seq Symbol).
Hypothesis promotable_q : promotable (Distinct M) q.

Local Open Scope mprog.

Definition relabel : Tree Symbol -> M (Tree Symbol) :=
  foldt (M # Tip \o const Fresh) (fmap (uncurry Bin) \o mpair).

Let drTip {A} : A -> M _ := (M # wrap) \o const Fresh.
Let drBin {N : failMonad} : (N _ * N _ -> N _) :=
  (fmap ucat) \o bassert q \o mpair.

(* extracting the distinct symbol list *)
Definition dlabels {N : failMonad} : Tree Symbol -> N (seq Symbol) :=
  foldt (Ret \o wrap) drBin.

Lemma dlabelsC t u (m : _ -> _ -> M (seq Symbol * seq Symbol)%type) :
  (do x <- dlabels t; do x0 <- relabel u; m x0 x =
   do x0 <- relabel u; do x <- dlabels t; m x0 x)%Do.
Proof.
elim: t u m => [a u /= m|t1 H1 t2 H2 u m].
  rewrite /dlabels /= bindretf; bind_ext => u'.
  by rewrite bindretf.
rewrite (_ : dlabels _ = drBin (dlabels t1, dlabels t2)) //.
rewrite [in RHS]/drBin [in RHS]/bassert 2!compE ![in RHS]bindA.
transitivity (do x0 <- relabel u;
  (do x <- dlabels t1;
   do x <- (do x1 <- (do y <- dlabels t2; Ret (x, y));
            (do x <- assert q x1; (Ret \o ucat) x));
   m x0 x))%Do; last first.
  bind_ext => u'; rewrite bind_fmap bindA; bind_ext => sS.
  rewrite 4!bindA; bind_ext => x; rewrite 2!bindretf !bindA.
  by rewrite_ bindretf.
rewrite -H1.
rewrite [in LHS]/drBin bind_fmap [in LHS]/bassert /= ![in LHS]bindA.
bind_ext => s.
rewrite !bindA.
transitivity (do x0 <- relabel u;
  (do x <- dlabels t2; (do x <-
    (do x1 <- Ret (s, x); (do x3 <- assert q x1; Ret (ucat x3)));
    m x0 x)))%Do; last by bind_ext => y2; rewrite bindA.
rewrite -H2.
bind_ext => s'.
rewrite !bindretf.
rewrite assertE.
rewrite bindA.
transitivity (guard (q (s, s')) >>
  (do x1 <- (Ret \o ucat) (s, s'); relabel u >>= (m^~ x1)))%Do.
  bind_ext; case; by rewrite 2!bindretf.
rewrite guardsC; last exact: failfresh_bindmfail.
rewrite !bindA !bindretf !bindA.
bind_ext => u'.
rewrite bindA guardsC; last exact: failfresh_bindmfail.
by rewrite bindA bindretf.
Qed.

(* see gibbons2011icfp Sect. 9.3 *)
Lemma join_and_pairs :
  (Join \o (M #(*TODO: fmap*) mpair) \o mpair) \o ((fmap dlabels) \o relabel)^`2 =
  (mpair \o Join^`2) \o            ((fmap dlabels) \o relabel)^`2.
Proof.
rewrite boolp.funeqE => -[x1 x2].
rewrite 3!compE.
rewrite joinE.
rewrite fmapE.
rewrite 2![in RHS]compE.
rewrite [in RHS]/mpair.
rewrite [in LHS]/mpair.
move H : (fmap dlabels) => h.
rewrite /=.
rewrite 2![in RHS]joinE.
rewrite 3!bindA.
rewrite -H.
rewrite !fmapE.
rewrite 3!bindA.
bind_ext => {}x1.
rewrite 2!bindretf 2!bindA.
do 3 rewrite_ bindretf.
rewrite -dlabelsC.
bind_ext => ?.
rewrite /=.
rewrite bindA.
bind_ext => ?.
by rewrite bindretf.
Qed.

(* first property of Sect. 9.3 *)
Lemma dlabels_relabel_is_fold :
  relabel >=> dlabels = foldt drTip drBin.
Proof.
apply foldt_universal.
  (* relabel >=> dlabels \o Tip = drTip *)
  rewrite /kleisli (* TODO(rei): don't unfold *) -(compA (Join \o _)) -(compA Join).
  rewrite (_ : _ \o Tip = (M # Tip) \o const Fresh) //.
  rewrite (compA (fmap dlabels)) -functor_o (_ : dlabels \o _ = Ret \o wrap) //.
  by rewrite functor_o 2!compA joinMret.
(* relabel >=> dlabels \o Bin = drBin \o _ *)
rewrite /kleisli (* TODO(rei): don't unfold *) -[in LHS](compA (Join \o _)) -[in LHS](compA Join).
rewrite (_ : _ \o _ Bin = (fmap (uncurry Bin)) \o (mpair \o relabel^`2)); last first.
  by rewrite boolp.funeqE; case.
rewrite (compA (fmap dlabels)) -functor_o.
rewrite (_ : _ \o _ Bin = (fmap ucat) \o bassert q \o mpair \o dlabels^`2); last first.
  by rewrite boolp.funeqE; case.
transitivity ((fmap ucat) \o Join \o (fmap (bassert q \o mpair)) \o mpair \o
    (fmap dlabels \o relabel)^`2).
  rewrite -2![in LHS](compA (fmap ucat)) [in LHS]functor_o.
  rewrite -[in LHS](compA (fmap _)) [in LHS](compA (Join \o _) (fmap _)).
  rewrite compfid natural -2![in RHS]compA; congr (_ \o _).
  by rewrite [in LHS]functor_o -[in LHS]compA naturality_mpair.
rewrite functor_o (compA _ (fmap (bassert q))) -(compA _ _ (fmap (bassert q))).
rewrite commutativity_of_assertions. (* first non-trivial step *)
rewrite (compA _ (bassert q)) -(compA _ _ (fmap mpair)) -(compA _ _ mpair) -(compA _ _ (_^`2)).
by rewrite join_and_pairs. (* second non-trivial step *)
Qed.

(* second property of Sect. 9.3 *)
Lemma symbols_size_is_fold :
  Symbols \o (@size_Tree Symbol) = foldt drTip drBin.
Proof.
apply foldt_universal.
  by rewrite -compA (_ : @size_Tree Symbol \o _ = const 1) // Symbols_prop1.
pose p := Distinct M.
transitivity (bassert p \o Symbols \o @size_Tree Symbol \o uncurry Bin
  : (_ -> M _)).
  by rewrite bassert_symbols.
transitivity ((bassert p) \o Symbols \o uaddn \o (@size_Tree Symbol)^`2
  : (_ -> M _)).
  by rewrite -[in LHS]compA -[in RHS]compA size_Tree_Bin.
transitivity (bassert p \o (fmap ucat) \o mpair \o (Symbols \o (@size_Tree Symbol))^`2
  : (_ -> M _)).
  rewrite -2!compA (compA Symbols) Symbols_prop2.
  by rewrite -(compA (_ \o mpair)) (compA (bassert p)).
transitivity ((fmap ucat) \o bassert q \o mpair \o (bassert p \o Symbols \o (@size_Tree Symbol))^`2
  : (_ -> M _)).
  (* assert p distributes over concatenation *)
  by rewrite (promote_assert_sufficient_condition (@failfresh_bindmfail _ M) promotable_q).
transitivity ((fmap ucat) \o bassert q \o mpair \o (Symbols \o (@size_Tree Symbol))^`2
  : (_ -> M _)).
  by rewrite bassert_symbols.
by [].
Qed.

(* main claim *)
Lemma dlabels_relabel_never_fails :
  relabel >=> dlabels = Symbols \o @size_Tree Symbol.
Proof. by rewrite dlabels_relabel_is_fold symbols_size_is_fold. Qed.

End tree_relabelling.
back to top