https://github.com/pymc-devs/pymc3
Tip revision: 5c600c7117c98d296244c16f27b2fcdb0d37cd70 authored by Ricardo Vieira on 26 July 2023, 09:31:30 UTC
Use graph_replace instead of clone_replace in VI
Use graph_replace instead of clone_replace in VI
Tip revision: 5c600c7
testing.py
# Copyright 2023 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools as ft
import itertools as it
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import pytensor
import pytensor.tensor as pt
import pytest
from numpy import random as nr
from numpy import testing as npt
from pytensor.compile.mode import Mode
from pytensor.graph.basic import Variable
from pytensor.graph.rewriting.basic import in2out
from pytensor.tensor import TensorVariable
from scipy import special as sp
from scipy import stats as st
import pymc as pm
from pymc.distributions.distribution import Distribution
from pymc.distributions.shape_utils import change_dist_size
from pymc.initial_point import make_initial_point_fn
from pymc.logprob.basic import icdf, logcdf, logp, transformed_conditional_logp
from pymc.logprob.utils import ParameterValueError, find_rvs_in_graph
from pymc.pytensorf import (
compile_pymc,
floatX,
inputvars,
intX,
local_check_parameter_to_ninf_switch,
)
# This mode can be used for tests where model compilations takes the bulk of the runtime
# AND where we don't care about posterior numerical or sampling stability (e.g., when
# all that matters are the shape of the draws or deterministic values of observed data).
# DO NOT USE UNLESS YOU HAVE A GOOD REASON TO!
fast_unstable_sampling_mode = (
pytensor.compile.mode.FAST_COMPILE
# Remove slow rewrite phases
.excluding("canonicalize", "specialize")
# Include necessary rewrites for proper logp handling
.including("remove_TransformedVariables").register(
(in2out(local_check_parameter_to_ninf_switch), -1)
)
)
def product(domains, n_samples=-1):
"""Get an iterator over a product of domains.
Args:
domains: a dictionary of (name, object) pairs, where the objects
must be "domain-like", as in, have a `.vals` property
n_samples: int, maximum samples to return. -1 to return whole product
Returns
-------
list of the cartesian product of the domains
"""
try:
names, domains = zip(*domains.items())
except ValueError: # domains.items() is empty
return [{}]
all_vals = [zip(names, val) for val in it.product(*(d.vals for d in domains))]
if n_samples > 0 and len(all_vals) > n_samples:
return (all_vals[j] for j in nr.choice(len(all_vals), n_samples, replace=False))
return all_vals
class Domain:
def __init__(self, vals, dtype=pytensor.config.floatX, edges=None, shape=None):
# Infinity values must be kept as floats
vals = [np.array(v, dtype=dtype) if np.all(np.isfinite(v)) else floatX(v) for v in vals]
if edges is None:
edges = np.array(vals[0]), np.array(vals[-1])
vals = vals[1:-1]
else:
edges = list(edges)
if edges[0] is None:
edges[0] = np.full_like(vals[0], -np.inf)
if edges[1] is None:
edges[1] = np.full_like(vals[0], np.inf)
edges = tuple(edges)
if not vals:
raise ValueError(
f"Domain has no values left after removing edges: {edges}.\n"
"You can duplicate the edge values or explicitly specify the edges with the edge keyword.\n"
f"For example: `Domain([{edges[0]}, {edges[0]}, {edges[1]}, {edges[1]}])`"
)
if shape is None:
shape = vals[0].shape
self.vals = vals
self.shape = shape
self.lower, self.upper = edges
self.dtype = dtype
def __add__(self, other):
return Domain(
[v + other for v in self.vals],
self.dtype,
(self.lower + other, self.upper + other),
self.shape,
)
def __mul__(self, other):
try:
return Domain(
[v * other for v in self.vals],
self.dtype,
(self.lower * other, self.upper * other),
self.shape,
)
except TypeError:
return Domain(
[v * other for v in self.vals],
self.dtype,
(self.lower, self.upper),
self.shape,
)
def __neg__(self):
return Domain([-v for v in self.vals], self.dtype, (-self.lower, -self.upper), self.shape)
class ProductDomain:
def __init__(self, domains):
self.vals = list(it.product(*(d.vals for d in domains)))
self.shape = (len(domains),) + domains[0].shape
self.lower = [d.lower for d in domains]
self.upper = [d.upper for d in domains]
self.dtype = domains[0].dtype
def Vector(D, n):
return ProductDomain([D] * n)
def SortedVector(n):
vals = []
np.random.seed(42)
for _ in range(10):
vals.append(np.sort(np.random.randn(n)))
return Domain(vals, edges=(None, None))
def UnitSortedVector(n):
vals = []
np.random.seed(42)
for _ in range(10):
vals.append(np.sort(np.random.rand(n)))
return Domain(vals, edges=(None, None))
def RealMatrix(n, m):
vals = []
np.random.seed(42)
for _ in range(10):
vals.append(np.random.randn(n, m))
return Domain(vals, edges=(None, None))
def simplex_values(n):
if n == 1:
yield np.array([1.0])
else:
for v in Unit.vals:
for vals in simplex_values(n - 1):
yield np.concatenate([[v], (1 - v) * vals])
def Simplex(n):
return Domain(simplex_values(n), shape=(n,), dtype=Unit.dtype, edges=(None, None))
def MultiSimplex(n_dependent, n_independent):
vals = []
for simplex_value in it.product(simplex_values(n_dependent), repeat=n_independent):
vals.append(np.vstack(simplex_value))
return Domain(vals, dtype=Unit.dtype, shape=(n_independent, n_dependent))
def RandomPdMatrix(n):
A = np.random.rand(n, n)
return np.dot(A, A.T) + n * np.identity(n)
R = Domain([-np.inf, -2.1, -1, -0.01, 0.0, 0.01, 1, 2.1, np.inf])
Rplus = Domain([0, 0.01, 0.1, 0.9, 0.99, 1, 1.5, 2, 100, np.inf])
Rplusbig = Domain([0, 0.5, 0.9, 0.99, 1, 1.5, 2, 20, np.inf])
Rminusbig = Domain([-np.inf, -2, -1.5, -1, -0.99, -0.9, -0.5, -0.01, 0])
Unit = Domain([0, 0.001, 0.1, 0.5, 0.75, 0.99, 1])
Circ = Domain([-np.pi, -2.1, -1, -0.01, 0.0, 0.01, 1, 2.1, np.pi])
Runif = Domain([-np.inf, -0.4, 0, 0.4, np.inf])
Rdunif = Domain([-np.inf, -1, 0, 1, np.inf], "int64")
Rplusunif = Domain([0, 0.5, np.inf])
Rplusdunif = Domain([0, 10, np.inf], "int64")
I = Domain([-np.inf, -3, -2, -1, 0, 1, 2, 3, np.inf], "int64")
NatSmall = Domain([0, 3, 4, 5, np.inf], "int64")
Nat = Domain([0, 1, 2, 3, np.inf], "int64")
NatBig = Domain([0, 1, 2, 3, 5000, np.inf], "int64")
PosNat = Domain([1, 2, 3, np.inf], "int64")
Bool = Domain([0, 0, 1, 1], "int64")
def select_by_precision(float64, float32):
"""Helper function to choose reasonable decimal cutoffs for different floatX modes."""
decimal = float64 if pytensor.config.floatX == "float64" else float32
return decimal
def build_model(distfam, valuedomain, vardomains, extra_args=None):
if extra_args is None:
extra_args = {}
with pm.Model() as m:
param_vars = {}
for v, dom in vardomains.items():
v_pt = pytensor.shared(np.asarray(dom.vals[0]))
v_pt.name = v
param_vars[v] = v_pt
param_vars.update(extra_args)
distfam(
"value",
**param_vars,
transform=None,
)
return m, param_vars
def create_dist_from_paramdomains(
pymc_dist: Distribution,
paramdomains: Dict[str, Domain],
extra_args: Optional[Dict[str, Any]] = None,
) -> TensorVariable:
"""Create a PyMC distribution from a dictionary of parameter domains.
Returns
-------
PyMC distribution variable: TensorVariable
Value variable: TensorVariable
"""
if extra_args is None:
extra_args = {}
param_vars = {}
for param, domain in paramdomains.items():
param_type = pt.constant(np.asarray(domain.vals[0])).type()
param_type.name = param
param_vars[param] = param_type
return pymc_dist.dist(**param_vars, **extra_args)
def find_invalid_scalar_params(
paramdomains: Dict["str", Domain]
) -> Dict["str", Tuple[Union[None, float], Union[None, float]]]:
"""Find invalid parameter values from bounded scalar parameter domains.
For use in `check_logp`-like testing helpers.
Returns
-------
Invalid paramemeter values:
Dictionary mapping each parameter, to a lower and upper invalid values (out of domain).
If no lower or upper invalid values exist, None is returned for that entry.
"""
invalid_params = {}
for param, paramdomain in paramdomains.items():
lower_edge, upper_edge = None, None
if np.ndim(paramdomain.lower) == 0:
if np.isfinite(paramdomain.lower):
lower_edge = paramdomain.lower - 1
if np.isfinite(paramdomain.upper):
upper_edge = paramdomain.upper + 1
invalid_params[param] = (lower_edge, upper_edge)
return invalid_params
def check_logp(
pymc_dist: Distribution,
domain: Domain,
paramdomains: Dict[str, Domain],
scipy_logp: Callable,
decimal: Optional[int] = None,
n_samples: int = 100,
extra_args: Optional[Dict[str, Any]] = None,
scipy_args: Optional[Dict[str, Any]] = None,
skip_paramdomain_outside_edge_test: bool = False,
) -> None:
"""
Generic test for PyMC logp methods
Test PyMC logp and equivalent scipy logpmf/logpdf methods give similar
results for valid values and parameters inside the supported edges.
Edges are excluded by default, but can be artificially included by
creating a domain with repeated values (e.g., `Domain([0, 0, .5, 1, 1]`)
Parameters
----------
pymc_dist: PyMC distribution
domain : Domain
Supported domain of distribution values
paramdomains : Dictionary of Parameter : Domain pairs
Supported domains of distribution parameters
scipy_logp : Scipy logpmf/logpdf method
Scipy logp method of equivalent pymc_dist distribution
decimal : Int
Level of precision with which pymc_dist and scipy logp are compared.
Defaults to 6 for float64 and 3 for float32
n_samples : Int
Upper limit on the number of valid domain and value combinations that
are compared between pymc and scipy methods. If n_samples is below the
total number of combinations, a random subset is evaluated. Setting
n_samples = -1, will return all possible combinations. Defaults to 100
extra_args : Dictionary with extra arguments needed to build pymc model
Dictionary is passed to helper function `build_model` from which
the pymc distribution logp is calculated
scipy_args : Dictionary with extra arguments needed to call scipy logp method
Usually the same as extra_args
"""
if decimal is None:
decimal = select_by_precision(float64=6, float32=3)
if scipy_args is None:
scipy_args = {}
def scipy_logp_with_scipy_args(**args):
args.update(scipy_args)
return scipy_logp(**args)
dist = create_dist_from_paramdomains(pymc_dist, paramdomains, extra_args)
value = dist.type()
value.name = "value"
pymc_dist_logp = logp(dist, value).sum()
pymc_logp = pytensor.function(list(inputvars(pymc_dist_logp)), pymc_dist_logp)
# Test supported value and parameters domain matches Scipy
domains = paramdomains.copy()
domains["value"] = domain
for point in product(domains, n_samples=n_samples):
point = dict(point)
npt.assert_almost_equal(
pymc_logp(**point),
scipy_logp_with_scipy_args(**point),
decimal=decimal,
err_msg=str(point),
)
valid_value = domain.vals[0]
valid_params = {param: paramdomain.vals[0] for param, paramdomain in paramdomains.items()}
valid_params["value"] = valid_value
# Test pymc distribution raises ParameterValueError for scalar parameters outside
# the supported domain edges (excluding edges)
if not skip_paramdomain_outside_edge_test:
invalid_params = find_invalid_scalar_params(paramdomains)
for invalid_param, invalid_edges in invalid_params.items():
for invalid_edge in invalid_edges:
if invalid_edge is None:
continue
point = valid_params.copy() # Shallow copy should be okay
point[invalid_param] = invalid_edge
with pytest.raises(ParameterValueError):
pymc_logp(**point)
pytest.fail(f"test_params={point}")
# Test that values outside of scalar domain support evaluate to -np.inf
invalid_values = find_invalid_scalar_params({"value": domain})["value"]
for invalid_value in invalid_values:
if invalid_value is None:
continue
point = valid_params.copy()
point["value"] = invalid_value
npt.assert_equal(
pymc_logp(**point),
-np.inf,
err_msg=str(point),
)
def check_logcdf(
pymc_dist: Distribution,
domain: Domain,
paramdomains: Dict[str, Domain],
scipy_logcdf: Callable,
decimal: Optional[int] = None,
n_samples: int = 100,
skip_paramdomain_inside_edge_test: bool = False,
skip_paramdomain_outside_edge_test: bool = False,
) -> None:
"""
Generic test for PyMC logcdf methods
The following tests are performed by default:
1. Test PyMC logcdf and equivalent scipy logcdf methods give similar
results for valid values and parameters inside the supported edges.
Edges are excluded by default, but can be artificially included by
creating a domain with repeated values (e.g., `Domain([0, 0, .5, 1, 1]`)
Can be skipped via skip_paramdomain_inside_edge_test
2. Test PyMC logcdf method returns -inf for invalid parameter values
outside the supported edges. Can be skipped via skip_paramdomain_outside_edge_test
3. Test PyMC logcdf method returns -inf and 0 for values below and
above the supported edge, respectively, when using valid parameters.
4. Test PyMC logcdf methods works with multiple value or returns
default informative TypeError
Parameters
----------
pymc_dist: PyMC distribution
domain : Domain
Supported domain of distribution values
paramdomains : Dictionary of Parameter : Domain pairs
Supported domains of distribution parameters
scipy_logcdf : Scipy logcdf method
Scipy logcdf method of equivalent pymc_dist distribution
decimal : Int
Level of precision with which pymc_dist and scipy_logcdf are compared.
Defaults to 6 for float64 and 3 for float32
n_samples : Int
Upper limit on the number of valid domain and value combinations that
are compared between pymc and scipy methods. If n_samples is below the
total number of combinations, a random subset is evaluated. Setting
n_samples = -1, will return all possible combinations. Defaults to 100
skip_paramdomain_inside_edge_test : Bool
Whether to run test 1., which checks that pymc and scipy distributions
match for valid values and parameters inside the respective domain edges
skip_paramdomain_outside_edge_test : Bool
Whether to run test 2., which checks that pymc distribution logcdf
returns -inf for invalid parameter values outside the supported domain edge
"""
if decimal is None:
decimal = select_by_precision(float64=6, float32=3)
dist = create_dist_from_paramdomains(pymc_dist, paramdomains)
value = dist.type()
value.name = "value"
dist_logcdf = logcdf(dist, value)
pymc_logcdf = pytensor.function(list(inputvars(dist_logcdf)), dist_logcdf)
# Test pymc and scipy distributions match for values and parameters
# within the supported domain edges (excluding edges)
if not skip_paramdomain_inside_edge_test:
domains = paramdomains.copy()
domains["value"] = domain
for point in product(domains, n_samples=n_samples):
point = dict(point)
npt.assert_almost_equal(
pymc_logcdf(**point),
scipy_logcdf(**point),
decimal=decimal,
err_msg=str(point),
)
valid_value = domain.vals[0]
valid_params = {param: paramdomain.vals[0] for param, paramdomain in paramdomains.items()}
valid_params["value"] = valid_value
# Test pymc distribution raises ParameterValueError for parameters outside the
# supported domain edges (excluding edges)
if not skip_paramdomain_outside_edge_test:
invalid_params = find_invalid_scalar_params(paramdomains)
for invalid_param, invalid_edges in invalid_params.items():
for invalid_edge in invalid_edges:
if invalid_edge is None:
continue
point = valid_params.copy()
point[invalid_param] = invalid_edge
with pytest.raises(ParameterValueError):
pymc_logcdf(**point)
pytest.fail(f"test_params={point}")
# Test that values below domain edge evaluate to -np.inf, and above evaluates to 0
invalid_lower, invalid_upper = find_invalid_scalar_params({"value": domain})["value"]
if invalid_lower is not None:
point = valid_params.copy()
point["value"] = invalid_lower
npt.assert_equal(
pymc_logcdf(**point),
-np.inf,
err_msg=str(point),
)
if invalid_upper is not None:
point = valid_params.copy()
point["value"] = invalid_upper
npt.assert_equal(
pymc_logcdf(**point),
0,
err_msg=str(point),
)
def check_icdf(
pymc_dist: Distribution,
paramdomains: Dict[str, Domain],
scipy_icdf: Callable,
skip_paramdomain_outside_edge_test=False,
decimal: Optional[int] = None,
n_samples: int = 100,
) -> None:
"""
Generic test for PyMC icdf methods
The following tests are performed by default:
1. Test PyMC icdf and equivalent scipy icdf (ppf) methods give similar
results for parameters inside the supported edges.
Edges are excluded by default, but can be artificially included by
creating a domain with repeated values (e.g., `Domain([0, 0, .5, 1, 1]`)
2. Test PyMC icdf method raises for invalid parameter values
outside the supported edges.
3. Test PyMC icdf method returns np.nan for values below 0 or above 1,
when using valid parameters.
Parameters
----------
pymc_dist: PyMC distribution
paramdomains : Dictionary of Parameter : Domain pairs
Supported domains of distribution parameters
scipy_icdf : Scipy icdf method
Scipy icdf (ppf) method of equivalent pymc_dist distribution
decimal : int, optional
Level of precision with which pymc_dist and scipy_icdf are compared.
Defaults to 6 for float64 and 3 for float32
n_samples : int
Upper limit on the number of valid domain and value combinations that
are compared between pymc and scipy methods. If n_samples is below the
total number of combinations, a random subset is evaluated. Setting
n_samples = -1, will return all possible combinations. Defaults to 100
skip_paradomain_outside_edge_test : Bool
Whether to run test 2., which checks that pymc distribution icdf
returns nan for invalid parameter values outside the supported domain edge
"""
if decimal is None:
decimal = select_by_precision(float64=6, float32=3)
dist = create_dist_from_paramdomains(pymc_dist, paramdomains)
q = pt.scalar(dtype="float64", name="q")
dist_icdf = icdf(dist, q)
pymc_icdf = pytensor.function(list(inputvars(dist_icdf)), dist_icdf)
# Test pymc and scipy distributions match for values and parameters
# within the supported domain edges (excluding edges)
domains = paramdomains.copy()
domain = Domain([0, 0.1, 0.5, 0.75, 0.95, 0.99, 1]) # Values we test the icdf at
domains["q"] = domain
for point in product(domains, n_samples=n_samples):
point = dict(point)
npt.assert_almost_equal(
pymc_icdf(**point),
scipy_icdf(**point),
decimal=decimal,
err_msg=str(point),
)
valid_value = domain.vals[0]
valid_params = {param: paramdomain.vals[0] for param, paramdomain in paramdomains.items()}
valid_params["q"] = valid_value
if not skip_paramdomain_outside_edge_test:
# Test pymc distribution raises ParameterValueError for parameters outside the
# supported domain edges (excluding edges)
invalid_params = find_invalid_scalar_params(paramdomains)
for invalid_param, invalid_edges in invalid_params.items():
for invalid_edge in invalid_edges:
if invalid_edge is None:
continue
point = valid_params.copy()
point[invalid_param] = invalid_edge
with pytest.raises(ParameterValueError):
pymc_icdf(**point)
pytest.fail(f"test_params={point}")
# Test that values below 0 or above 1 evaluate to nan
invalid_values = find_invalid_scalar_params({"q": domain})["q"]
for invalid_value in invalid_values:
if invalid_value is not None:
point = valid_params.copy()
point["q"] = invalid_value
npt.assert_equal(
pymc_icdf(**point),
np.nan,
err_msg=str(point),
)
def check_selfconsistency_discrete_logcdf(
distribution: Distribution,
domain: Domain,
paramdomains: Dict[str, Domain],
decimal: Optional[int] = None,
n_samples: int = 100,
) -> None:
"""
Check that logcdf of discrete distributions matches sum of logps up to value.
"""
if decimal is None:
decimal = select_by_precision(float64=6, float32=3)
dist = create_dist_from_paramdomains(distribution, paramdomains)
value = dist.type()
value.name = "value"
dist_logp = logp(dist, value)
dist_logp_fn = pytensor.function(list(inputvars(dist_logp)), dist_logp)
dist_logcdf = logcdf(dist, value)
dist_logcdf_fn = compile_pymc(list(inputvars(dist_logcdf)), dist_logcdf)
domains = paramdomains.copy()
domains["value"] = domain
for point in product(domains, n_samples=n_samples):
point = dict(point)
value = point.pop("value")
values = np.arange(domain.lower, value + 1)
with pytensor.config.change_flags(mode=Mode("py")):
npt.assert_almost_equal(
dist_logcdf_fn(**point, value=value),
sp.logsumexp([dist_logp_fn(value=value, **point) for value in values]),
decimal=decimal,
err_msg=str(point),
)
def assert_moment_is_expected(model, expected, check_finite_logp=True):
fn = make_initial_point_fn(
model=model,
return_transformed=False,
default_strategy="moment",
)
moment = fn(0)["x"]
expected = np.asarray(expected)
try:
random_draw = model["x"].eval()
except NotImplementedError:
random_draw = moment
assert moment.shape == expected.shape
assert expected.shape == random_draw.shape
assert np.allclose(moment, expected)
if check_finite_logp:
logp_moment = (
transformed_conditional_logp(
(model["x"],),
rvs_to_values={model["x"]: pt.constant(moment)},
rvs_to_transforms={},
)[0]
.sum()
.eval()
)
assert np.isfinite(logp_moment)
def continuous_random_tester(
dist,
paramdomains,
ref_rand,
valuedomain=None,
size=10000,
alpha=0.05,
fails=10,
extra_args=None,
model_args=None,
):
if valuedomain is None:
valuedomain = Domain([0], edges=(None, None))
if model_args is None:
model_args = {}
model, param_vars = build_model(dist, valuedomain, paramdomains, extra_args)
model_dist = change_dist_size(model.named_vars["value"], size, expand=True)
pymc_rand = compile_pymc([], model_dist)
domains = paramdomains.copy()
for point in product(domains, n_samples=100):
point = pm.Point(point, model=model)
point.update(model_args)
# Update the shared parameter variables in `param_vars`
for k, v in point.items():
nv = param_vars.get(k, model.named_vars.get(k))
if nv.name in param_vars:
param_vars[nv.name].set_value(v)
p = alpha
# Allow KS test to fail (i.e., the samples be different)
# a certain number of times. Crude, but necessary.
f = fails
while p <= alpha and f > 0:
s0 = pymc_rand()
s1 = floatX(ref_rand(size=size, **point))
_, p = st.ks_2samp(np.atleast_1d(s0).flatten(), np.atleast_1d(s1).flatten())
f -= 1
assert p > alpha, str(point)
def discrete_random_tester(
dist,
paramdomains,
valuedomain=None,
ref_rand=None,
size=100000,
alpha=0.05,
fails=20,
):
if valuedomain is None:
valuedomain = Domain([0], edges=(None, None))
model, param_vars = build_model(dist, valuedomain, paramdomains)
model_dist = change_dist_size(model.named_vars["value"], size, expand=True)
pymc_rand = compile_pymc([], model_dist)
domains = paramdomains.copy()
for point in product(domains, n_samples=100):
point = pm.Point(point, model=model)
p = alpha
# Update the shared parameter variables in `param_vars`
for k, v in point.items():
nv = param_vars.get(k, model.named_vars.get(k))
if nv.name in param_vars:
param_vars[nv.name].set_value(v)
# Allow Chisq test to fail (i.e., the samples be different)
# a certain number of times.
f = fails
while p <= alpha and f > 0:
o = pymc_rand()
e = intX(ref_rand(size=size, **point))
o = np.atleast_1d(o).flatten()
e = np.atleast_1d(e).flatten()
bins = min(20, max(len(set(e)), len(set(o))))
range = (min(min(e), min(o)), max(max(e), max(o)))
observed, _ = np.histogram(o, bins=bins, range=range)
expected, _ = np.histogram(e, bins=bins, range=range)
if np.all(observed == expected):
p = 1.0
else:
_, p = st.chisquare(observed + 1, expected + 1)
f -= 1
assert p > alpha, str(point)
class SeededTest:
random_seed = 20160911
random_state = None
@classmethod
def setup_class(cls):
nr.seed(cls.random_seed)
def setup_method(self):
nr.seed(self.random_seed)
def get_random_state(self, reset=False):
if self.random_state is None or reset:
self.random_state = nr.RandomState(self.random_seed)
return self.random_state
class BaseTestDistributionRandom(SeededTest):
"""
Base class for tests that new RandomVariables are correctly
implemented, and that the mapping of parameters between the PyMC
Distribution and the respective RandomVariable is correct.
Three default tests are provided which check:
1. Expected inputs are passed to the `rv_op` by the `dist` `classmethod`,
via `check_pymc_params_match_rv_op`
2. Expected (exact) draws are being returned, via
`check_pymc_draws_match_reference`
3. Shape variable inference is correct, via `check_rv_size`
Each desired test must be referenced by name in `checks_to_run`, when
subclassing this distribution. Custom tests can be added to each class as
well. See `TestFlat` for an example.
Additional tests should be added for each optional parametrization of the
distribution. In this case it's enough to include the test
`check_pymc_params_match_rv_op` since only this differs.
Note on `check_rv_size` test:
Custom input sizes (and expected output shapes) can be defined for the
`check_rv_size` test, by adding the optional class attributes
`sizes_to_check` and `sizes_expected`:
```python
sizes_to_check = [None, (1), (2, 3)]
sizes_expected = [(3,), (1, 3), (2, 3, 3)]
checks_to_run = ["check_rv_size"]
```
This is usually needed for Multivariate distributions. You can see an
example in `TestDirichlet`
Notes on `check_pymcs_draws_match_reference` test:
The `check_pymcs_draws_match_reference` is a very simple test for the
equality of draws from the `RandomVariable` and the exact same python
function, given the same inputs and random seed. A small number
(`size=15`) is checked. This is not supposed to be a test for the
correctness of the random generator. The latter kind of test
(if warranted) can be performed with the aid of `pymc_random` and
`pymc_random_discrete` methods in this file, which will perform an
expensive statistical comparison between the RandomVariable `rng_fn`
and a reference Python function. This kind of test only makes sense if
there is a good independent generator reference (i.e., not just the same
composition of numpy / scipy python calls that is done inside `rng_fn`).
Finally, when your `rng_fn` is doing something more than just calling a
`numpy` or `scipy` method, you will need to setup an equivalent seeded
function with which to compare for the exact draws (instead of relying on
`seeded_[scipy|numpy]_distribution_builder`). You can find an example
in the `TestWeibull`, whose `rng_fn` returns
`beta * np.random.weibull(alpha, size=size)`.
"""
pymc_dist: Optional[Callable] = None
pymc_dist_params: Optional[dict] = None
reference_dist: Optional[Callable] = None
reference_dist_params: Optional[dict] = None
expected_rv_op_params: Optional[dict] = None
checks_to_run: List[str] = []
size = 15
decimal = select_by_precision(float64=6, float32=3)
sizes_to_check: Optional[List] = None
sizes_expected: Optional[List] = None
repeated_params_shape = 5
def test_distribution(self):
self.validate_tests_list()
if self.pymc_dist == pm.Wishart:
with pytest.warns(UserWarning, match="can currently not be used for MCMC sampling"):
self._instantiate_pymc_rv()
else:
self._instantiate_pymc_rv()
if self.reference_dist is not None:
self.reference_dist_draws = self.reference_dist()(
size=self.size, **self.reference_dist_params
)
for check_name in self.checks_to_run:
if check_name.startswith("test_"):
raise ValueError(
"Custom check cannot start with `test_` or else it will be executed twice."
)
if self.pymc_dist == pm.Wishart and check_name.startswith("check_rv_size"):
with pytest.warns(UserWarning, match="can currently not be used for MCMC sampling"):
getattr(self, check_name)()
else:
getattr(self, check_name)()
def _instantiate_pymc_rv(self, dist_params=None):
params = dist_params if dist_params else self.pymc_dist_params
self.pymc_rv = self.pymc_dist.dist(
**params, size=self.size, rng=pytensor.shared(self.get_random_state(reset=True))
)
def check_pymc_draws_match_reference(self):
# need to re-instantiate it to make sure that the order of drawings match the reference distribution one
# self._instantiate_pymc_rv()
npt.assert_array_almost_equal(
self.pymc_rv.eval(), self.reference_dist_draws, decimal=self.decimal
)
def check_pymc_params_match_rv_op(self):
pytensor_dist_inputs = self.pymc_rv.get_parents()[0].inputs[3:]
assert len(self.expected_rv_op_params) == len(pytensor_dist_inputs)
for (expected_name, expected_value), actual_variable in zip(
self.expected_rv_op_params.items(), pytensor_dist_inputs
):
# Add additional line to evaluate symbolic inputs to distributions
if isinstance(expected_value, pytensor.tensor.Variable):
expected_value = expected_value.eval()
npt.assert_almost_equal(expected_value, actual_variable.eval(), decimal=self.decimal)
def check_rv_size(self):
# test sizes
sizes_to_check = self.sizes_to_check or [None, (), 1, (1,), 5, (4, 5), (2, 4, 2)]
sizes_expected = self.sizes_expected or [(), (), (1,), (1,), (5,), (4, 5), (2, 4, 2)]
for size, expected in zip(sizes_to_check, sizes_expected):
pymc_rv = self.pymc_dist.dist(**self.pymc_dist_params, size=size)
expected_symbolic = tuple(pymc_rv.shape.eval())
actual = pymc_rv.eval().shape
assert actual == expected_symbolic
assert expected_symbolic == expected
# test multi-parameters sampling for univariate distributions (with univariate inputs)
if (
self.pymc_dist.rv_op.ndim_supp == 0
and self.pymc_dist.rv_op.ndims_params
and sum(self.pymc_dist.rv_op.ndims_params) == 0
):
params = {
k: p * np.ones(self.repeated_params_shape) for k, p in self.pymc_dist_params.items()
}
self._instantiate_pymc_rv(params)
sizes_to_check = [None, self.repeated_params_shape, (5, self.repeated_params_shape)]
sizes_expected = [
(self.repeated_params_shape,),
(self.repeated_params_shape,),
(5, self.repeated_params_shape),
]
for size, expected in zip(sizes_to_check, sizes_expected):
pymc_rv = self.pymc_dist.dist(**params, size=size)
expected_symbolic = tuple(pymc_rv.shape.eval())
actual = pymc_rv.eval().shape
assert actual == expected_symbolic == expected
def validate_tests_list(self):
assert len(self.checks_to_run) == len(
set(self.checks_to_run)
), "There are duplicates in the list of checks_to_run"
def seeded_scipy_distribution_builder(dist_name: str) -> Callable:
return lambda self: ft.partial(getattr(st, dist_name).rvs, random_state=self.get_random_state())
def seeded_numpy_distribution_builder(dist_name: str) -> Callable:
return lambda self: ft.partial(
getattr(np.random.RandomState, dist_name), self.get_random_state()
)
def assert_no_rvs(vars: Sequence[Variable]) -> None:
"""Assert that there are no `MeasurableVariable` nodes in a graph."""
rvs = find_rvs_in_graph(vars)
if rvs:
raise AssertionError(f"RV found in graph: {rvs}")