swh:1:snp:9c27352633c4639a943e316050a7b904f57900e2
Raw File
Tip revision: d517a996606f39a8f226bc0b6697c97a496612bf authored by Pietro on 24 January 2018, 15:46:03 UTC
Test,p2p: Fix #98 : Occasionally test fail with an assert failure
Tip revision: d517a99
error_monad.ml
(**************************************************************************)
(*                                                                        *)
(*    Copyright (c) 2014 - 2017.                                          *)
(*    Dynamic Ledger Solutions, Inc. <contact@tezos.com>                  *)
(*                                                                        *)
(*    All rights reserved. No warranty, explicit or implicit, provided.   *)
(*                                                                        *)
(**************************************************************************)

(* Tezos Protocol Implementation - Error Monad *)

(*-- Error classification ----------------------------------------------------*)

(* The three categories an error can be classified into.  When several
   errors are combined by [classify_errors], [`Permanent] takes precedence
   over [`Branch], which takes precedence over [`Temporary]. *)
type error_category = [ `Branch | `Temporary | `Permanent ]

(* Category attached to a registered error kind: either one of the fixed
   categories above, or [`Wrapped f] where [f] computes the category from
   the payload extracted out of the error (see [classify_error]). *)
type 'err full_error_category =
  [ error_category | `Wrapped of 'err -> error_category ]

(* HACK: forward reference from [Data_encoding_ezjsonm].
   Mutable slot holding the function that [json_pp] uses to render a JSON
   value as a string.  Until it is overwritten, every value is rendered as
   the empty string.  NOTE(review): presumably [Data_encoding_ezjsonm]
   installs the real printer at initialisation time -- confirm. *)
let json_to_string = ref (fun _ -> "")

(* [json_pp id encoding ppf x] prints [x] on [ppf] as the JSON document
   obtained by encoding the pair [(id, x)]: an object carrying an [id]
   string field merged with the fields produced by [encoding].  The JSON
   value is turned into text by the function currently stored in
   [json_to_string]. *)
let json_pp id encoding ppf x =
  let tagged_encoding =
    Data_encoding.(merge_objs (obj1 (req "id" string)) encoding) in
  let json = Data_encoding.Json.construct tagged_encoding (id, x) in
  Format.pp_print_string ppf (!json_to_string json)

module Make() = struct

  (* Extensible type of errors: new constructors are added with
     [type error += ...] and made known to this module through the
     registration functions below. *)
  type error = ..

  (* the toplevel store for error kinds *)
  (* Each registered kind packs, for some hidden payload type ['err]:
     the way to recognise and print errors of that kind, its category and
     its JSON encoding case. *)
  type error_kind =
      Error_kind :
        { id: string ;
          (* Returns [Some payload] when the error belongs to this kind. *)
          from_error: error -> 'err option ;
          (* Fixed category, or a function of the payload. *)
          category: 'err full_error_category ;
          (* Case contributed to the union built by [error_encoding]. *)
          encoding_case: error Data_encoding.case ;
          (* Printer for the payload, used by [pp]. *)
          pp: Format.formatter -> 'err -> unit ; } ->
      error_kind

  (* Registry of all known error kinds.  Registration prepends to this
     list, so the most recently registered kind is examined first by
     [classify_error] and [pp]. *)
  let error_kinds
    : error_kind list ref
    = ref []

  (* Memoised result of building the union encoding over [error_kinds].
     It is reset to [None] by [raw_register_error_kind] so that the
     encoding is rebuilt after a new kind is registered. *)
  let error_encoding_cache = ref None

  (* [string_of_category c] is the lowercase name of category [c], as
     stored in the [kind] field of encoded errors.  Every [`Wrapped _]
     category maps to the same name, whatever function it carries. *)
  let string_of_category category =
    match category with
    | `Wrapped _ -> "wrapped"
    | `Branch -> "branch"
    | `Temporary -> "temporary"
    | `Permanent -> "permanent"
  (* [raw_register_error_kind category ~id ~title ~description ?pp
     encoding from_error to_error] adds a new error kind at the head of
     [error_kinds] and invalidates [error_encoding_cache].

     The JSON case of the new kind is the object described by [encoding]
     merged with two constant fields, [kind] (the category name) and [id].
     When [pp] is omitted, errors of this kind are printed as JSON through
     [json_pp].

     Raises [Invalid_argument] if [id] is already used by a registered
     kind, or if [encoding] is not an object encoding. *)
  let raw_register_error_kind
      category ~id:name ~title ~description ?pp
      encoding from_error to_error =
    let already_registered =
      List.exists
        (function Error_kind { id ; _ } -> id = name)
        !error_kinds in
    if already_registered then
      invalid_arg
        (Printf.sprintf
           "register_error_kind: duplicate error name: %s" name) ;
    if not (Data_encoding.is_obj encoding) then
      invalid_arg
        (Printf.sprintf
           "Specified encoding for \"%s\" is not an object, but error encodings must be objects."
           name) ;
    let encoding_case =
      let open Data_encoding in
      (* Constant header identifying the kind in the JSON output. *)
      let header =
        obj2
          (req "kind" (constant (string_of_category category)))
          (req "id" (constant name)) in
      (* Hide the two unit values produced by the constant fields. *)
      let tagged =
        conv
          (fun payload -> (((), ()), payload))
          (fun (((), ()), payload) -> payload)
          (merge_objs header encoding) in
      case Json_only
        (describe ~title ~description tagged)
        from_error to_error in
    error_encoding_cache := None ;
    let printer =
      match pp with
      | Some custom -> custom
      | None -> json_pp name encoding in
    let kind =
      Error_kind { id = name ;
                   category ;
                   from_error ;
                   encoding_case ;
                   pp = printer } in
    error_kinds := kind :: !error_kinds

  (* Registers an error kind whose category is not fixed: [classify] is
     applied to the payload of each error of this kind to obtain its
     category.  All other arguments are passed unchanged to
     [raw_register_error_kind], including its failure cases. *)
  let register_wrapped_error_kind
      classify ~id ~title ~description ?pp
      payload_encoding project inject =
    raw_register_error_kind (`Wrapped classify)
      ~id ~title ~description ?pp
      payload_encoding project inject

  (* Registers an error kind with a fixed category.  The category is
     widened to a [full_error_category] and all other arguments are
     passed unchanged to [raw_register_error_kind], including its
     failure cases. *)
  let register_error_kind
      category ~id ~title ~description ?pp
      payload_encoding project inject =
    raw_register_error_kind (category :> _ full_error_category)
      ~id ~title ~description ?pp
      payload_encoding project inject

  (* Returns the encoding of errors, building it on first use.

     The JSON side is the union of the cases of every registered kind.
     The binary side goes through JSON: an error is converted to its JSON
     form and that JSON value is what gets serialised.  The result is
     stored in [error_encoding_cache] until the next registration. *)
  let error_encoding () =
    match !error_encoding_cache with
    | Some cached -> cached
    | None ->
        let case_of_kind (Error_kind { encoding_case ; _ }) =
          encoding_case in
        let json_encoding =
          Data_encoding.union (List.map case_of_kind !error_kinds) in
        let binary_encoding =
          Data_encoding.conv
            (Data_encoding.Json.construct json_encoding)
            (Data_encoding.Json.destruct json_encoding)
            Data_encoding.json in
        let encoding =
          Data_encoding.dynamic_size
            (Data_encoding.splitted
               ~json:json_encoding
               ~binary:binary_encoding) in
        error_encoding_cache := Some encoding ;
        encoding

  (* Encoding exposed to the rest of the code: it wraps the builder above
     with [Data_encoding.delayed].  NOTE(review): presumably this defers
     the call to the builder until the encoding is actually used, so that
     kinds registered later are taken into account -- confirm against
     [Data_encoding.delayed]. *)
  let error_encoding = Data_encoding.delayed error_encoding

  (* [json_of_error err] converts [err] to JSON through [error_encoding]. *)
  let json_of_error err =
    err |> Data_encoding.Json.construct error_encoding

  (* [error_of_json value] reads an error back from [value] through
     [error_encoding]. *)
  let error_of_json value =
    value |> Data_encoding.Json.destruct error_encoding

  (* [classify_error error] is the category of the first registered kind
     (most recently registered first) whose [from_error] recognises
     [error].  For a [`Wrapped] kind the category is computed from the
     extracted payload.  If no kind matches, the error is considered
     [`Temporary]. *)
  let classify_error error =
    let rec lookup kinds =
      match kinds with
      | [] -> `Temporary
      | Error_kind { from_error ; category ; _ } :: remaining ->
          begin match from_error error with
            | None -> lookup remaining
            | Some payload ->
                begin match category with
                  | #error_category as fixed -> fixed
                  | `Wrapped classify -> classify payload
                end
          end in
    lookup !error_kinds

  (* [classify_errors errors] is the most severe category found among
     [errors], where [`Permanent] outranks [`Branch], which outranks
     [`Temporary].  The empty list is classified as [`Temporary]. *)
  let classify_errors errors =
    let worst acc err =
      match acc, classify_error err with
      | `Temporary, `Temporary -> `Temporary
      | `Branch, (`Branch | `Temporary)
      | `Temporary, `Branch -> `Branch
      | `Permanent, (`Permanent | `Branch | `Temporary)
      | (`Branch | `Temporary), `Permanent -> `Permanent in
    List.fold_left worst `Temporary errors

  (* [pp ppf error] prints [error] on [ppf] with the printer of the first
     registered kind that recognises it.  The empty-registry case is not
     expected to happen: the kind titled [Generic error], registered
     further down in this module, recognises every error. *)
  let pp ppf error =
    let rec print_with kinds =
      match kinds with
      | Error_kind { from_error ; pp ; _ } :: remaining ->
          begin match from_error error with
            | Some payload -> pp ppf payload
            | None -> print_with remaining
          end
      | [] -> assert false in
    print_with !error_kinds

  (*-- Monad definition --------------------------------------------------------*)

  (* The bind operator of Lwt, made available unqualified for the
     definitions below. *)
  let (>>=) = Lwt.(>>=)

  (* Result of a computation that may fail with a list of errors.
     [record_trace] and [trace] add errors at the head of that list, and
     [pp_print_error] prints it in reverse order. *)
  type 'a tzresult = ('a, error list) result

  (* [result_encoding t_encoding] is an encoding for results whose
     success side is encoded by [t_encoding].  It is a union with a one
     byte tag: tag 0 is an object with a single [result] field, tag 1 is
     an object with a single [error] field holding the list of errors. *)
  let result_encoding t_encoding =
    let open Data_encoding in
    let success_case =
      let encoding =
        describe ~title: "A successful result" @@
        obj1 (req "result" t_encoding) in
      case (Tag 0) encoding
        (function Ok value -> Some value | Error _ -> None)
        (fun value -> Ok value) in
    let failure_case =
      let encoding =
        describe ~title: "An erroneous result" @@
        obj1 (req "error" (list error_encoding)) in
      case (Tag 1) encoding
        (function Error errs -> Some errs | Ok _ -> None)
        (fun errs -> Error errs) in
    union ~tag_size:`Uint8 [ success_case ; failure_case ]

  (* [return v] is the Lwt value holding the successful result [Ok v]. *)
  let return v = Lwt.return (Ok v)

  (* [error s] is the failed result whose error list contains only [s]. *)
  let error s = Error [ s ]

  (* [ok v] is the successful result [Ok v]. *)
  let ok v = Ok v

  (* [fail s] is the Lwt value holding the failed result whose error list
     contains only [s]. *)
  let fail s = Lwt.return (Error [ s ])

  (* Bind for plain results: [v >>? f] is [f x] when [v] is [Ok x], and
     the error of [v] otherwise, in which case [f] is not called. *)
  let (>>?) v f =
    match v with
    | Ok x -> f x
    | Error _ as failure -> failure

  (* Bind for results inside Lwt: [v >>=? f] waits for [v], then continues
     with [f x] when it holds [Ok x]; an error is passed through and [f]
     is not called. *)
  let (>>=?) v f =
    v >>= fun outcome ->
    match outcome with
    | Ok x -> f x
    | Error _ as failure -> Lwt.return failure

  (* Map for results inside Lwt: [v >>|? f] waits for [v] and applies the
     plain function [f] to its value when it is [Ok]; an error is passed
     through and [f] is not called. *)
  let (>>|?) v f =
    v >>= function
    | Error _ as failure -> Lwt.return failure
    | Ok x -> Lwt.return (Ok (f x))
  (* The map operator of Lwt, made available unqualified. *)
  let (>|=) = Lwt.(>|=)

  (* Map for plain results: [v >|? f] is [Ok (f x)] when [v] is [Ok x],
     and the error of [v] otherwise, in which case [f] is not called. *)
  let (>|?) v f =
    match v with
    | Ok x -> Ok (f x)
    | Error _ as failure -> failure

  (* [map_s f l] applies [f] to the elements of [l] from left to right,
     waiting for each result before calling [f] on the next element, and
     returns the list of results in the same order.  The first error
     stops the traversal and is returned. *)
  let rec map_s f = function
    | [] -> return []
    | x :: xs ->
        f x >>=? fun y ->
        map_s f xs >>=? fun ys ->
        return (y :: ys)

  (* Same as [map_s], but [f] also receives the position of the element
     in [l], counted from 0. *)
  let mapi_s f l =
    let rec loop index = function
      | [] -> return []
      | x :: xs ->
          f index x >>=? fun y ->
          loop (index + 1) xs >>=? fun ys ->
          return (y :: ys) in
    loop 0 l

  (* [map_p f l] calls [f] on every element of [l] before waiting for any
     of the results, then collects the results in the order of [l].  If
     some calls fail, the outcome is an error holding the concatenation
     of all their error lists, in the order of [l]. *)
  let rec map_p f l =
    match l with
    | [] -> return []
    | x :: rest ->
        let tx = f x and tl = map_p f rest in
        tx >>= fun head_result ->
        tl >>= fun tail_result ->
        Lwt.return begin
          match head_result, tail_result with
          | Ok y, Ok ys -> Ok (y :: ys)
          | Error errs1, Error errs2 -> Error (errs1 @ errs2)
          | Ok _, Error errs
          | Error errs, Ok _ -> Error errs
        end

  (* Same as [map_p], but [f] also receives the position of the element
     in [l], counted from 0. *)
  let mapi_p f l =
    let rec loop index = function
      | [] -> return []
      | x :: rest ->
          let tx = f index x and tl = loop (index + 1) rest in
          tx >>= fun head_result ->
          tl >>= fun tail_result ->
          Lwt.return begin
            match head_result, tail_result with
            | Ok y, Ok ys -> Ok (y :: ys)
            | Error errs1, Error errs2 -> Error (errs1 @ errs2)
            | Ok _, Error errs
            | Error errs, Ok _ -> Error errs
          end in
    loop 0 l

  (* [map2_s f l1 l2] applies [f] to the elements of [l1] and [l2] taken
     pairwise, from left to right, waiting for each result before the
     next call.  The first error stops the traversal and is returned.

     Raises [Invalid_argument] when one list runs out before the other;
     the exception is raised only once the common prefix has been
     processed without error. *)
  let rec map2_s f l1 l2 =
    match l1, l2 with
    | x1 :: rest1, x2 :: rest2 ->
        f x1 x2 >>=? fun y ->
        map2_s f rest1 rest2 >>=? fun ys ->
        return (y :: ys)
    | [], [] -> return []
    | _ :: _, [] | [], _ :: _ -> invalid_arg "Error_monad.map2_s"

  (* [map2 f l1 l2] is the version of [map2_s] for plain results: [f] is
     applied pairwise from left to right and the first error stops the
     traversal.

     Raises [Invalid_argument] when one list runs out before the other;
     the exception is raised only once the common prefix has been
     processed without error. *)
  let rec map2 f l1 l2 =
    match l1, l2 with
    | x1 :: rest1, x2 :: rest2 ->
        f x1 x2 >>? fun y ->
        map2 f rest1 rest2 >>? fun ys ->
        Ok (y :: ys)
    | [], [] -> Ok []
    | _ :: _, [] | [], _ :: _ -> invalid_arg "Error_monad.map2"

  (* [filter_map_s f l] applies [f] to the elements of [l] from left to
     right, waiting for each result before the next call, and keeps the
     values [y] for which [f] answered [Some y], in the order of [l].
     The first error stops the traversal and is returned. *)
  let rec filter_map_s f = function
    | [] -> return []
    | x :: xs ->
        f x >>=? fun selected ->
        match selected with
        | Some y ->
            filter_map_s f xs >>=? fun ys ->
            return (y :: ys)
        | None -> filter_map_s f xs

  (* [filter_map_p f l] calls [f] on every element of [l] before waiting
     for any of the results, then keeps the values [y] for which [f]
     answered [Some y], in the order of [l].

     The function recurses on itself for the tail.  It previously handed
     the tail to [filter_map_s], so only the head call overlapped with a
     tail that was processed one element at a time, unlike [map_p] and
     [iter_p].

     Error handling is unchanged: the error of the leftmost failing call
     is returned; errors of later elements are not merged into it. *)
  let rec filter_map_p f l =
    match l with
    | [] -> return []
    | h :: t ->
        (* Start the head and the whole tail before waiting on either. *)
        let th = f h
        and tt = filter_map_p f t in
        th >>=? function
        | None -> tt
        | Some rh ->
            tt >>=? fun rt ->
            return (rh :: rt)

  (* [iter_s f l] applies [f] to the elements of [l] from left to right,
     waiting for each call to finish before starting the next one.  The
     first error stops the traversal and is returned. *)
  let rec iter_s f = function
    | [] -> return ()
    | x :: xs ->
        f x >>=? fun () ->
        iter_s f xs

  (* [iter_p f l] calls [f] on every element of [l] before waiting for
     any of the results.  If some calls fail, the outcome is an error
     holding the concatenation of all their error lists, in the order of
     [l]. *)
  let rec iter_p f l =
    match l with
    | [] -> return ()
    | x :: rest ->
        let tx = f x and tl = iter_p f rest in
        tx >>= fun head_result ->
        tl >>= fun tail_result ->
        Lwt.return begin
          match head_result, tail_result with
          | Ok (), Ok () -> Ok ()
          | Error errs1, Error errs2 -> Error (errs1 @ errs2)
          | Ok (), Error errs
          | Error errs, Ok () -> Error errs
        end

  (* [iter2_p f l1 l2] calls [f] on the elements of [l1] and [l2] taken
     pairwise, starting every call before waiting for any result.  If
     some calls fail, the outcome is an error holding the concatenation
     of all their error lists, in list order.

     Raises [Invalid_argument] when one list runs out before the other.
     The exception is raised while the calls are being started, after
     [f] has already been called on the common prefix. *)
  let rec iter2_p f l1 l2 =
    match l1, l2 with
    | [], [] -> return ()
    | [], _ | _, [] -> invalid_arg "Error_monad.iter2_p"
    | x1 :: rest1, x2 :: rest2 ->
        let tx = f x1 x2 and tl = iter2_p f rest1 rest2 in
        tx >>= fun head_result ->
        tl >>= fun tail_result ->
        Lwt.return begin
          match head_result, tail_result with
          | Ok (), Ok () -> Ok ()
          | Error errs1, Error errs2 -> Error (errs1 @ errs2)
          | Ok (), Error errs
          | Error errs, Ok () -> Error errs
        end

  (* Same as [iter2_p], but [f] also receives the position of the pair,
     counted from 0.

     Raises [Invalid_argument] when one list runs out before the other. *)
  let iteri2_p f l1 l2 =
    let rec loop index l1 l2 =
      match l1, l2 with
      | [], [] -> return ()
      | [], _ | _, [] -> invalid_arg "Error_monad.iteri2_p"
      | x1 :: rest1, x2 :: rest2 ->
          let tx = f index x1 x2 and tl = loop (index + 1) rest1 rest2 in
          tx >>= fun head_result ->
          tl >>= fun tail_result ->
          Lwt.return begin
            match head_result, tail_result with
            | Ok (), Ok () -> Ok ()
            | Error errs1, Error errs2 -> Error (errs1 @ errs2)
            | Ok (), Error errs
            | Error errs, Ok () -> Error errs
          end in
    loop 0 l1 l2

  (* [fold_left_s f init l] threads an accumulator through the elements
     of [l] from left to right, starting from [init] and waiting for each
     call of [f] before the next one.  The first error stops the
     traversal and is returned. *)
  let rec fold_left_s f init = function
    | [] -> return init
    | x :: xs ->
        f init x >>=? fun next ->
        fold_left_s f next xs

  (* [fold_right_s f l init] threads an accumulator through the elements
     of [l] from right to left, starting from [init]: [f] is first called
     on the last element.  The first error met stops the computation and
     is returned. *)
  let rec fold_right_s f l init =
    match l with
    | x :: xs ->
        fold_right_s f xs init >>=? fun acc ->
        f x acc
    | [] -> return init

  (* [join ts] waits for every thread of [ts], from left to right, and
     succeeds when all of them succeed.  Threads after a failed one are
     still waited for.  When several threads fail, the error lists are
     not merged: the result is the error of the last failing thread in
     [ts]. *)
  let rec join threads =
    match threads with
    | [] -> return ()
    | first :: others ->
        first >>= fun outcome ->
        match outcome with
        | Ok () -> join others
        | Error _ as failure ->
            join others >>=? fun () ->
            Lwt.return failure

  (* [record_trace err result] adds [err] at the head of the error list
     when [result] is an error, and returns [result] unchanged when it is
     a success. *)
  let record_trace err = function
    | Error errs -> Error (err :: errs)
    | Ok _ as success -> success

  (* [trace err f] waits for [f] and, when it holds an error, adds [err]
     at the head of its error list.  A success is returned unchanged. *)
  let trace err f =
    f >>= fun result ->
    match result with
    | Ok _ -> Lwt.return result
    | Error errs -> Lwt.return (Error (err :: errs))

  (* [fail_unless cond exn] succeeds with unit when [cond] holds and
     fails with the single error [exn] otherwise. *)
  let fail_unless cond exn =
    if not cond then fail exn else return ()

  (* [fail_when cond exn] fails with the single error [exn] when [cond]
     holds and succeeds with unit otherwise. *)
  let fail_when cond exn =
    match cond with
    | true -> fail exn
    | false -> return ()

  (* [unless cond f] runs [f ()] only when [cond] is false; when [cond]
     holds it succeeds with unit without calling [f]. *)
  let unless cond f =
    if not cond then f () else return ()

  (* [_when cond f] runs [f ()] only when [cond] holds; otherwise it
     succeeds with unit without calling [f]. *)
  let _when cond f =
    match cond with
    | true -> f ()
    | false -> return ()

  (* [pp_print_error ppf errors] prints an error list on [ppf] and
     flushes the formatter.  A single error is printed on its own; a
     longer list is printed as a stack, in the reverse of its order in
     the list, one error after the other.  The empty list is reported as
     an unknown error. *)
  let pp_print_error ppf = function
    | [] ->
        Format.fprintf ppf "Unknown error@."
    | [ single ] ->
        Format.fprintf ppf "@[<v 2>Error:@ %a@]@." pp single
    | _ :: _ :: _ as stack ->
        let reversed = List.rev stack in
        Format.fprintf ppf "@[<v 2>Error, dumping error stack:@,%a@]@."
          (Format.pp_print_list pp)
          reversed

  (* Error carrying only a message. *)
  type error += Unclassified of string

  (* Registers the generic, catch-all kind directly in [error_kinds].
     Its [from_error] recognises every error: an [Unclassified] error
     yields its message, any other error yields a message built from the
     name of its constructor.  Kinds registered afterwards are prepended
     to [error_kinds] and therefore take precedence over this one. *)
  let () =
    let from_error = function
      | Unclassified msg -> Some msg
      | unregistered ->
          let constructor =
            Obj.(extension_name (extension_constructor unregistered)) in
          Some ("Unclassified error: " ^ constructor ^ ". Was the error registered?") in
    let to_error msg = Unclassified msg in
    let encoding_case =
      let open Data_encoding in
      let payload =
        obj2
          (req "kind" (constant "generic"))
          (req "error" string) in
      case Json_only
        (describe
           ~title:"Generic error"
           ~description:"An unclassified error"
           (conv (fun msg -> ((), msg)) (fun ((), msg) -> msg) payload))
        from_error to_error in
    let generic_kind =
      Error_kind
        { id = "" ;
          from_error ;
          category = `Temporary ;
          encoding_case ;
          pp = Format.pp_print_string } in
    error_kinds := generic_kind :: !error_kinds

  (* Error raised by [_assert]: a location and a message, which may be
     empty. *)
  type error += Assert_error of string * string

  (* Registers the kind of assertion failures directly in [error_kinds],
     as a [`Permanent] error.  It is encoded as an object with a constant
     [kind] field and the two strings as [location] and [error]. *)
  let () =
    let from_error = function
      | Assert_error (loc, msg) -> Some (loc, msg)
      | _ -> None in
    let to_error (loc, msg) = Assert_error (loc, msg) in
    let encoding_case =
      let open Data_encoding in
      let payload =
        obj3
          (req "kind" (constant "assertion"))
          (req "location" string)
          (req "error" string) in
      case Json_only
        (describe
           ~title:"Assertion error"
           ~description:"An fatal assertion"
           (conv
              (fun (loc, msg) -> ((), loc, msg))
              (fun ((), loc, msg) -> (loc, msg))
              payload))
        from_error to_error in
    let pp ppf (loc, msg) =
      (* An empty message only closes the sentence. *)
      let suffix =
        match msg with
        | "" -> "."
        | _ -> ": " ^ msg in
      Format.fprintf ppf "Assert failure (%s)%s" loc suffix in
    let assertion_kind =
      Error_kind
        { id = "" ;
          from_error ;
          category = `Permanent ;
          encoding_case ;
          pp } in
    error_kinds := assertion_kind :: !error_kinds

  (* [_assert b loc fmt ...] succeeds with unit when [b] holds; the
     arguments of the format are then consumed without producing any
     text.  Otherwise the message is formatted and the result fails with
     [Assert_error (loc, message)]. *)
  let _assert b loc fmt =
    match b with
    | true ->
        Format.ikfprintf (fun _ -> return ()) Format.str_formatter fmt
    | false ->
        Format.kasprintf (fun msg -> fail (Assert_error (loc, msg))) fmt


  (* [protect ~on_error t] waits for [t]; a success is returned as is,
     and an error list is handed to [on_error], whose result becomes the
     result of the whole expression. *)
  let protect ~on_error t =
    t >>= fun outcome ->
    match outcome with
    | Error errs -> on_error errs
    | Ok value -> return value

end

include Make()

(* [generic_error fmt ...] formats a message and returns the failed
   result holding the single error [Unclassified message]. *)
let generic_error fmt =
  Format.kasprintf (fun message -> Error [ Unclassified message ]) fmt

(* [failwith fmt ...] formats a message and returns the Lwt value holding
   the failed result with the single error [Unclassified message].  Note
   that this definition shadows the standard [failwith] and does not
   raise. *)
let failwith fmt =
  Format.kasprintf
    (fun message -> Lwt.return (Error [ Unclassified message ]))
    fmt

(* Error wrapping an OCaml exception. *)
type error += Exn of exn
(* [error s] is the failed result whose error list contains only [s]. *)
let error s = Error [ s ]
(* [error_exn s] is the failed result holding the exception [s] wrapped
   in [Exn]. *)
let error_exn s = Error [ Exn s ]
(* [trace_exn exn f] is [trace] with the exception [exn] wrapped in
   [Exn] as the added error. *)
let trace_exn exn f = trace (Exn exn) f
(* [generic_trace fmt ...] formats a message and returns a function that
   adds [Exn (Failure message)] at the head of the errors of the thread
   it is applied to. *)
let generic_trace fmt =
  Format.kasprintf (fun str -> trace_exn (Failure str)) fmt
(* [record_trace_exn exn f] is [record_trace] with the exception [exn]
   wrapped in [Exn] as the added error. *)
let record_trace_exn exn f = record_trace (Exn exn) f

(* [failure fmt ...] formats a message and returns the error
   [Exn (Failure message)]; the error is only built, not wrapped in a
   result. *)
let failure fmt =
  Format.kasprintf (fun str -> Exn (Failure str)) fmt


(* [protect ?on_error t] runs the thread built by [t ()], turning any
   exception caught by [Lwt.catch] into the single error [Exn exn].  A
   success is returned as is.  An error list is handed to [on_error]
   when it is provided, and returned unchanged otherwise. *)
let protect ?on_error t =
  let guarded = Lwt.catch t (fun exn -> fail (Exn exn)) in
  guarded >>= fun outcome ->
  match outcome, on_error with
  | Ok value, _ -> return value
  | Error errs, Some handler -> handler errs
  | Error errs, None -> Lwt.return (Error errs)

(* [pp_exn ppf exn] prints the exception [exn] on [ppf] the way [pp]
   prints the error [Exn exn]. *)
let pp_exn ppf exn = pp ppf (Exn exn)

(* Registers the kind of errors wrapping an exception, as a [`Temporary]
   error encoded as an object with a single [msg] field.  Only the
   message survives encoding: whatever the wrapped exception was, a
   decoded error is [Exn (Failure msg)]. *)
let () =
  let message_of_error = function
    | Exn (Failure msg) -> Some msg
    | Exn (Unix.Unix_error (code, fn, _)) ->
        Some ("Unix error in " ^ fn ^ ": " ^ Unix.error_message code)
    | Exn other -> Some (Printexc.to_string other)
    | _ -> None in
  let error_of_message msg = Exn (Failure msg) in
  register_error_kind
    `Temporary
    ~id:"failure"
    ~title:"Generic error"
    ~description:"Unclassified error"
    ~pp:Format.pp_print_string
    Data_encoding.(obj1 (req "msg" string))
    message_of_error
    error_of_message
back to top