base_samplers.ml
(*****************************************************************************)
(* *)
(* Open Source License *)
(* Copyright (c) 2020 Nomadic Labs <contact@tezos.com> *)
(* *)
(* Permission is hereby granted, free of charge, to any person obtaining a *)
(* copy of this software and associated documentation files (the "Software"),*)
(* to deal in the Software without restriction, including without limitation *)
(* the rights to use, copy, modify, merge, publish, distribute, sublicense, *)
(* and/or sell copies of the Software, and to permit persons to whom the *)
(* Software is furnished to do so, subject to the following conditions: *)
(* *)
(* The above copyright notice and this permission notice shall be included *)
(* in all copies or substantial portions of the Software. *)
(* *)
(* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR*)
(* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *)
(* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *)
(* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER*)
(* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING *)
(* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER *)
(* DEALINGS IN THE SOFTWARE. *)
(* *)
(*****************************************************************************)
(* ------------------------------------------------------------------------- *)
(* Primitives for sampling basic data *)
type 'a sampler = Random.State.t -> 'a
(* range (inclusive) *)
type range = {min : int; max : int}
let range_encoding =
let open Data_encoding in
conv
(fun {min; max} -> (min, max))
(fun (min, max) -> {min; max})
(obj2 (req "min" int31) (req "max" int31))
let sample_in_interval ~range:{min; max} state =
if max - min >= 0 then min + Random.State.int state (max - min + 1)
else invalid_arg "Base_samplers.sample_in_interval"
let sample_float_in_interval ~min ~max state =
let diff = max -. min in
if diff > 0. then min +. Random.State.float state diff
else invalid_arg "Base_samplers.sample_float_in_interval"
let uniform_bool = Random.State.bool
let uniform_byte state = Char.chr (Random.State.int state 256)
let uniform_partial_byte ~nbits state =
if nbits < 1 || nbits > 8 then
invalid_arg "Base_samplers.uniform_partial_byte" ;
let i = Random.State.int state 256 in
Char.chr (i lsr (8 - nbits))
let uniform_string ~nbytes state =
String.init nbytes (fun _ -> uniform_byte state)
let uniform_bytes ~nbytes state =
Bytes.init nbytes (fun _ -> uniform_byte state)
let uniform_nat ~nbytes state = Z.of_bits (uniform_string state ~nbytes)
let uniform_int ~nbytes state =
let n = uniform_nat ~nbytes state in
if Random.State.bool state then Z.neg n else n
let nat ~size state =
let nbytes = sample_in_interval state ~range:size in
uniform_nat state ~nbytes
let int ~size state =
if size.min < 0 then invalid_arg "Base_samplers.int" ;
let nat = nat state ~size in
let s = Random.State.bool state in
if s then nat else Z.neg nat
let uniform_readable_ascii state =
(* Consult the ascii table for the meaning of this. *)
let i = Random.State.int state 96 in
if i = 95 then '\n' else Char.chr (32 + i)
let uniform_readable_ascii_string ~nbytes state =
String.init nbytes (fun _ -> uniform_readable_ascii state)
let readable_ascii_string ~size state =
if size.min < 0 then invalid_arg "Base_samplers.readable_ascii_string" ;
let nbytes = sample_in_interval state ~range:size in
uniform_readable_ascii_string ~nbytes state
let string ~size state =
if size.min < 0 then invalid_arg "Base_samplers.string" ;
let nbytes = sample_in_interval state ~range:size in
uniform_string state ~nbytes
let bytes ~size state =
if size.min < 0 then invalid_arg "Base_samplers.bytes" ;
let nbytes = sample_in_interval state ~range:size in
uniform_bytes state ~nbytes
(* ------------------------------------------------------------------------- *)
(* Sampling of "adversarial" values in the sense that they exhibit the
worst-case performance of COMPARE. *)
module Adversarial = struct
(* random string generator with a good probabiliy that sampling [n] times
will yield distinct results. *)
let salt_string state (n : int) : unit -> string =
if n <= 0 then Stdlib.failwith "salt_string: n <= 0" ;
let salt_length = 2 * Z.log2 (Z.of_int n) in
fun () -> uniform_string state ~nbytes:salt_length
(* random bytes generator with a good probabiliy that sampling [n] times
will yield distinct results. *)
let salt_bytes state (n : int) : unit -> bytes =
if n <= 0 then Stdlib.failwith "salt_bytes: n <= 0" ;
let salt_length = 2 * Z.log2 (Z.of_int n) in
fun () -> uniform_bytes state ~nbytes:salt_length
(* Adversarial Z.t *)
let integers ~prefix_size ~card state =
if card <= 0 then invalid_arg "Base_samplers.Adversarial.integers" ;
if prefix_size.min < 0 then invalid_arg "Base_samplers.Adversarial.integers" ;
let common_prefix = string state ~size:prefix_size in
let rand_suffix = salt_string state card in
let elements =
Stdlib.List.init card (fun _ ->
Z.of_bits (rand_suffix () ^ common_prefix))
in
(Z.of_bits common_prefix, elements)
(* Adversarial strings *)
let strings ~prefix_size ~card state =
if card <= 0 then invalid_arg "Base_samplers.Adversarial.strings" ;
if prefix_size.min < 0 then invalid_arg "Base_samplers.Adversarial.strings" ;
let common_prefix = string state ~size:prefix_size in
let rand_suffix = salt_string state card in
let elements =
List.init ~when_negative_length:() card (fun _ ->
common_prefix ^ rand_suffix ())
|> (* see [invalid_arg] above *)
WithExceptions.Result.get_ok ~loc:__LOC__
in
(common_prefix, elements)
(* Adversarial bytes *)
let bytes ~prefix_size ~card state =
if card <= 0 then invalid_arg "Base_samplers.Adversarial.bytes" ;
if prefix_size.min < 0 then invalid_arg "Base_samplers.Adversarial.bytes" ;
let common_prefix = bytes state ~size:prefix_size in
let rand_suffix = salt_bytes state card in
let elements =
List.init ~when_negative_length:() card (fun _ ->
Bytes.cat common_prefix (rand_suffix ()))
|> (* see [invalid_arg] above *)
WithExceptions.Result.get_ok ~loc:__LOC__
in
(common_prefix, elements)
end