https://github.com/pymc-devs/pymc3
Raw File
Tip revision: 118be0f23782945dc03c5fb36d58d6ce4a1f619f authored by Ricardo Vieira on 07 November 2023, 11:46:49 UTC
Add test for Blockwise logp regression
Tip revision: 118be0f
testing.py
#   Copyright 2023 The PyMC Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
import functools as ft
import itertools as it

from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import pytensor
import pytensor.tensor as pt
import pytest

from numpy import random as nr
from numpy import testing as npt
from pytensor.compile.mode import Mode
from pytensor.graph.basic import Variable
from pytensor.graph.rewriting.basic import in2out
from pytensor.tensor import TensorVariable
from scipy import special as sp
from scipy import stats as st

import pymc as pm

from pymc.distributions.distribution import Distribution
from pymc.distributions.shape_utils import change_dist_size
from pymc.initial_point import make_initial_point_fn
from pymc.logprob.basic import icdf, logcdf, logp, transformed_conditional_logp
from pymc.logprob.utils import (
    ParameterValueError,
    local_check_parameter_to_ninf_switch,
    rvs_in_graph,
)
from pymc.pytensorf import compile_pymc, floatX, inputvars, intX

# This mode can be used for tests where model compilations takes the bulk of the runtime
# AND where we don't care about posterior numerical or sampling stability (e.g., when
# all that matters are the shape of the draws or deterministic values of observed data).
# DO NOT USE UNLESS YOU HAVE A GOOD REASON TO!
fast_unstable_sampling_mode = (
    pytensor.compile.mode.FAST_COMPILE
    # Remove slow rewrite phases
    .excluding("canonicalize", "specialize")
    # Include necessary rewrites for proper logp handling
    .including("remove_TransformedVariables").register(
        (in2out(local_check_parameter_to_ninf_switch), -1)
    )
)


def product(domains, n_samples=-1):
    """Get an iterator over a product of domains.

    Args:
        domains: a dictionary of (name, object) pairs, where the objects
                 must be "domain-like", as in, have a `.vals` property
        n_samples: int, maximum samples to return.  -1 to return whole product

    Returns
    -------
        list of the cartesian product of the domains
    """
    try:
        names, domains = zip(*domains.items())
    except ValueError:  # domains.items() is empty
        return [{}]
    all_vals = [zip(names, val) for val in it.product(*(d.vals for d in domains))]
    if n_samples > 0 and len(all_vals) > n_samples:
        return (all_vals[j] for j in nr.choice(len(all_vals), n_samples, replace=False))
    return all_vals


class Domain:
    def __init__(self, vals, dtype=pytensor.config.floatX, edges=None, shape=None):
        # Infinity values must be kept as floats
        vals = [np.array(v, dtype=dtype) if np.all(np.isfinite(v)) else floatX(v) for v in vals]

        if edges is None:
            edges = np.array(vals[0]), np.array(vals[-1])
            vals = vals[1:-1]
        else:
            edges = list(edges)
            if edges[0] is None:
                edges[0] = np.full_like(vals[0], -np.inf)
            if edges[1] is None:
                edges[1] = np.full_like(vals[0], np.inf)
            edges = tuple(edges)

        if not vals:
            raise ValueError(
                f"Domain has no values left after removing edges: {edges}.\n"
                "You can duplicate the edge values or explicitly specify the edges with the edge keyword.\n"
                f"For example: `Domain([{edges[0]}, {edges[0]}, {edges[1]}, {edges[1]}])`"
            )

        if shape is None:
            shape = vals[0].shape

        self.vals = vals
        self.shape = shape
        self.lower, self.upper = edges
        self.dtype = dtype

    def __add__(self, other):
        return Domain(
            [v + other for v in self.vals],
            self.dtype,
            (self.lower + other, self.upper + other),
            self.shape,
        )

    def __mul__(self, other):
        try:
            return Domain(
                [v * other for v in self.vals],
                self.dtype,
                (self.lower * other, self.upper * other),
                self.shape,
            )
        except TypeError:
            return Domain(
                [v * other for v in self.vals],
                self.dtype,
                (self.lower, self.upper),
                self.shape,
            )

    def __neg__(self):
        return Domain([-v for v in self.vals], self.dtype, (-self.lower, -self.upper), self.shape)


class ProductDomain:
    def __init__(self, domains):
        self.vals = list(it.product(*(d.vals for d in domains)))
        self.shape = (len(domains),) + domains[0].shape
        self.lower = [d.lower for d in domains]
        self.upper = [d.upper for d in domains]
        self.dtype = domains[0].dtype


def Vector(D, n):
    return ProductDomain([D] * n)


def SortedVector(n):
    vals = []
    np.random.seed(42)
    for _ in range(10):
        vals.append(np.sort(np.random.randn(n)))
    return Domain(vals, edges=(None, None))


def UnitSortedVector(n):
    vals = []
    np.random.seed(42)
    for _ in range(10):
        vals.append(np.sort(np.random.rand(n)))
    return Domain(vals, edges=(None, None))


def RealMatrix(n, m):
    vals = []
    np.random.seed(42)
    for _ in range(10):
        vals.append(np.random.randn(n, m))
    return Domain(vals, edges=(None, None))


def simplex_values(n):
    if n == 1:
        yield np.array([1.0])
    else:
        for v in Unit.vals:
            for vals in simplex_values(n - 1):
                yield np.concatenate([[v], (1 - v) * vals])


def Simplex(n):
    return Domain(simplex_values(n), shape=(n,), dtype=Unit.dtype, edges=(None, None))


def MultiSimplex(n_dependent, n_independent):
    vals = []
    for simplex_value in it.product(simplex_values(n_dependent), repeat=n_independent):
        vals.append(np.vstack(simplex_value))

    return Domain(vals, dtype=Unit.dtype, shape=(n_independent, n_dependent))


def RandomPdMatrix(n):
    A = np.random.rand(n, n)
    return np.dot(A, A.T) + n * np.identity(n)


R = Domain([-np.inf, -2.1, -1, -0.01, 0.0, 0.01, 1, 2.1, np.inf])
Rplus = Domain([0, 0.01, 0.1, 0.9, 0.99, 1, 1.5, 2, 100, np.inf])
Rplusbig = Domain([0, 0.5, 0.9, 0.99, 1, 1.5, 2, 20, np.inf])
Rminusbig = Domain([-np.inf, -2, -1.5, -1, -0.99, -0.9, -0.5, -0.01, 0])
Unit = Domain([0, 0.001, 0.1, 0.5, 0.75, 0.99, 1])
Circ = Domain([-np.pi, -2.1, -1, -0.01, 0.0, 0.01, 1, 2.1, np.pi])
Runif = Domain([-np.inf, -0.4, 0, 0.4, np.inf])
Rdunif = Domain([-np.inf, -1, 0, 1, np.inf], "int64")
Rplusunif = Domain([0, 0.5, np.inf])
Rplusdunif = Domain([0, 10, np.inf], "int64")
I = Domain([-np.inf, -3, -2, -1, 0, 1, 2, 3, np.inf], "int64")
NatSmall = Domain([0, 3, 4, 5, np.inf], "int64")
Nat = Domain([0, 1, 2, 3, np.inf], "int64")
NatBig = Domain([0, 1, 2, 3, 5000, np.inf], "int64")
PosNat = Domain([1, 2, 3, np.inf], "int64")
Bool = Domain([0, 0, 1, 1], "int64")


def select_by_precision(float64, float32):
    """Helper function to choose reasonable decimal cutoffs for different floatX modes."""
    decimal = float64 if pytensor.config.floatX == "float64" else float32
    return decimal


def build_model(distfam, valuedomain, vardomains, extra_args=None):
    if extra_args is None:
        extra_args = {}

    with pm.Model() as m:
        param_vars = {}
        for v, dom in vardomains.items():
            v_pt = pytensor.shared(np.asarray(dom.vals[0]))
            v_pt.name = v
            param_vars[v] = v_pt
        param_vars.update(extra_args)
        distfam(
            "value",
            **param_vars,
            transform=None,
        )
    return m, param_vars


def create_dist_from_paramdomains(
    pymc_dist: Distribution,
    paramdomains: Dict[str, Domain],
    extra_args: Optional[Dict[str, Any]] = None,
) -> TensorVariable:
    """Create a PyMC distribution from a dictionary of parameter domains.

    Returns
    -------
        PyMC distribution variable: TensorVariable
        Value variable: TensorVariable
    """
    if extra_args is None:
        extra_args = {}

    param_vars = {}
    for param, domain in paramdomains.items():
        param_type = pt.constant(np.asarray(domain.vals[0])).type()
        param_type.name = param
        param_vars[param] = param_type

    return pymc_dist.dist(**param_vars, **extra_args)


def find_invalid_scalar_params(
    paramdomains: Dict["str", Domain]
) -> Dict["str", Tuple[Union[None, float], Union[None, float]]]:
    """Find invalid parameter values from bounded scalar parameter domains.

    For use in `check_logp`-like testing helpers.

    Returns
    -------
    Invalid paramemeter values:
        Dictionary mapping each parameter, to a lower and upper invalid values (out of domain).
        If no lower or upper invalid values exist, None is returned for that entry.
    """
    invalid_params = {}
    for param, paramdomain in paramdomains.items():
        lower_edge, upper_edge = None, None

        if np.ndim(paramdomain.lower) == 0:
            if np.isfinite(paramdomain.lower):
                lower_edge = paramdomain.lower - 1

            if np.isfinite(paramdomain.upper):
                upper_edge = paramdomain.upper + 1

        invalid_params[param] = (lower_edge, upper_edge)
    return invalid_params


def check_logp(
    pymc_dist: Distribution,
    domain: Domain,
    paramdomains: Dict[str, Domain],
    scipy_logp: Callable,
    decimal: Optional[int] = None,
    n_samples: int = 100,
    extra_args: Optional[Dict[str, Any]] = None,
    scipy_args: Optional[Dict[str, Any]] = None,
    skip_paramdomain_outside_edge_test: bool = False,
) -> None:
    """
    Generic test for PyMC logp methods

    Test PyMC logp and equivalent scipy logpmf/logpdf methods give similar
    results for valid values and parameters inside the supported edges.
    Edges are excluded by default, but can be artificially included by
    creating a domain with repeated values (e.g., `Domain([0, 0, .5, 1, 1]`)

    Parameters
    ----------
    pymc_dist: PyMC distribution
    domain : Domain
        Supported domain of distribution values
    paramdomains : Dictionary of Parameter : Domain pairs
        Supported domains of distribution parameters
    scipy_logp : Scipy logpmf/logpdf method
        Scipy logp method of equivalent pymc_dist distribution
    decimal : Int
        Level of precision with which pymc_dist and scipy logp are compared.
        Defaults to 6 for float64 and 3 for float32
    n_samples : Int
        Upper limit on the number of valid domain and value combinations that
        are compared between pymc and scipy methods. If n_samples is below the
        total number of combinations, a random subset is evaluated. Setting
        n_samples = -1, will return all possible combinations. Defaults to 100
    extra_args : Dictionary with extra arguments needed to build pymc model
        Dictionary is passed to helper function `build_model` from which
        the pymc distribution logp is calculated
    scipy_args : Dictionary with extra arguments needed to call scipy logp method
        Usually the same as extra_args
    """
    if decimal is None:
        decimal = select_by_precision(float64=6, float32=3)

    if scipy_args is None:
        scipy_args = {}

    def scipy_logp_with_scipy_args(**args):
        args.update(scipy_args)
        return scipy_logp(**args)

    dist = create_dist_from_paramdomains(pymc_dist, paramdomains, extra_args)
    value = dist.type()
    value.name = "value"
    pymc_dist_logp = logp(dist, value).sum()
    pymc_logp = pytensor.function(list(inputvars(pymc_dist_logp)), pymc_dist_logp)

    # Test supported value and parameters domain matches Scipy
    domains = paramdomains.copy()
    domains["value"] = domain
    for point in product(domains, n_samples=n_samples):
        point = dict(point)
        npt.assert_almost_equal(
            pymc_logp(**point),
            scipy_logp_with_scipy_args(**point),
            decimal=decimal,
            err_msg=str(point),
        )

    valid_value = domain.vals[0]
    valid_params = {param: paramdomain.vals[0] for param, paramdomain in paramdomains.items()}
    valid_params["value"] = valid_value

    # Test pymc distribution raises ParameterValueError for scalar parameters outside
    # the supported domain edges (excluding edges)
    if not skip_paramdomain_outside_edge_test:
        invalid_params = find_invalid_scalar_params(paramdomains)

        for invalid_param, invalid_edges in invalid_params.items():
            for invalid_edge in invalid_edges:
                if invalid_edge is None:
                    continue

                point = valid_params.copy()  # Shallow copy should be okay
                point[invalid_param] = invalid_edge
                with pytest.raises(ParameterValueError):
                    pymc_logp(**point)
                    pytest.fail(f"test_params={point}")

    # Test that values outside of scalar domain support evaluate to -np.inf
    invalid_values = find_invalid_scalar_params({"value": domain})["value"]

    for invalid_value in invalid_values:
        if invalid_value is None:
            continue

        point = valid_params.copy()
        point["value"] = invalid_value
        npt.assert_equal(
            pymc_logp(**point),
            -np.inf,
            err_msg=str(point),
        )


def check_logcdf(
    pymc_dist: Distribution,
    domain: Domain,
    paramdomains: Dict[str, Domain],
    scipy_logcdf: Callable,
    decimal: Optional[int] = None,
    n_samples: int = 100,
    skip_paramdomain_inside_edge_test: bool = False,
    skip_paramdomain_outside_edge_test: bool = False,
) -> None:
    """
    Generic test for PyMC logcdf methods

    The following tests are performed by default:
        1. Test PyMC logcdf and equivalent scipy logcdf methods give similar
        results for valid values and parameters inside the supported edges.
        Edges are excluded by default, but can be artificially included by
        creating a domain with repeated values (e.g., `Domain([0, 0, .5, 1, 1]`)
        Can be skipped via skip_paramdomain_inside_edge_test
        2. Test PyMC logcdf method returns -inf for invalid parameter values
        outside the supported edges. Can be skipped via skip_paramdomain_outside_edge_test
        3. Test PyMC logcdf method returns -inf and 0 for values below and
        above the supported edge, respectively, when using valid parameters.
        4. Test PyMC logcdf methods works with multiple value or returns
        default informative TypeError

    Parameters
    ----------
    pymc_dist: PyMC distribution
    domain : Domain
        Supported domain of distribution values
    paramdomains : Dictionary of Parameter : Domain pairs
        Supported domains of distribution parameters
    scipy_logcdf : Scipy logcdf method
        Scipy logcdf method of equivalent pymc_dist distribution
    decimal : Int
        Level of precision with which pymc_dist and scipy_logcdf are compared.
        Defaults to 6 for float64 and 3 for float32
    n_samples : Int
        Upper limit on the number of valid domain and value combinations that
        are compared between pymc and scipy methods. If n_samples is below the
        total number of combinations, a random subset is evaluated. Setting
        n_samples = -1, will return all possible combinations. Defaults to 100
    skip_paramdomain_inside_edge_test : Bool
        Whether to run test 1., which checks that pymc and scipy distributions
        match for valid values and parameters inside the respective domain edges
    skip_paramdomain_outside_edge_test : Bool
        Whether to run test 2., which checks that pymc distribution logcdf
        returns -inf for invalid parameter values outside the supported domain edge

    """
    if decimal is None:
        decimal = select_by_precision(float64=6, float32=3)

    dist = create_dist_from_paramdomains(pymc_dist, paramdomains)
    value = dist.type()
    value.name = "value"
    dist_logcdf = logcdf(dist, value)
    pymc_logcdf = pytensor.function(list(inputvars(dist_logcdf)), dist_logcdf)

    # Test pymc and scipy distributions match for values and parameters
    # within the supported domain edges (excluding edges)
    if not skip_paramdomain_inside_edge_test:
        domains = paramdomains.copy()
        domains["value"] = domain
        for point in product(domains, n_samples=n_samples):
            point = dict(point)
            npt.assert_almost_equal(
                pymc_logcdf(**point),
                scipy_logcdf(**point),
                decimal=decimal,
                err_msg=str(point),
            )

    valid_value = domain.vals[0]
    valid_params = {param: paramdomain.vals[0] for param, paramdomain in paramdomains.items()}
    valid_params["value"] = valid_value

    # Test pymc distribution raises ParameterValueError for parameters outside the
    # supported domain edges (excluding edges)
    if not skip_paramdomain_outside_edge_test:
        invalid_params = find_invalid_scalar_params(paramdomains)

        for invalid_param, invalid_edges in invalid_params.items():
            for invalid_edge in invalid_edges:
                if invalid_edge is None:
                    continue

                point = valid_params.copy()
                point[invalid_param] = invalid_edge
                with pytest.raises(ParameterValueError):
                    pymc_logcdf(**point)
                    pytest.fail(f"test_params={point}")

    # Test that values below domain edge evaluate to -np.inf, and above evaluates to 0
    invalid_lower, invalid_upper = find_invalid_scalar_params({"value": domain})["value"]
    if invalid_lower is not None:
        point = valid_params.copy()
        point["value"] = invalid_lower
        npt.assert_equal(
            pymc_logcdf(**point),
            -np.inf,
            err_msg=str(point),
        )
    if invalid_upper is not None:
        point = valid_params.copy()
        point["value"] = invalid_upper
        npt.assert_equal(
            pymc_logcdf(**point),
            0,
            err_msg=str(point),
        )


def check_icdf(
    pymc_dist: Distribution,
    paramdomains: Dict[str, Domain],
    scipy_icdf: Callable,
    skip_paramdomain_outside_edge_test=False,
    decimal: Optional[int] = None,
    n_samples: int = 100,
) -> None:
    """
    Generic test for PyMC icdf methods

    The following tests are performed by default:
        1. Test PyMC icdf and equivalent scipy icdf (ppf) methods give similar
        results for parameters inside the supported edges.
        Edges are excluded by default, but can be artificially included by
        creating a domain with repeated values (e.g., `Domain([0, 0, .5, 1, 1]`)
        2. Test PyMC icdf method raises for invalid parameter values
        outside the supported edges.
        3. Test PyMC icdf method returns np.nan for values below 0 or above 1,
         when using valid parameters.

    Parameters
    ----------
    pymc_dist: PyMC distribution
    paramdomains : Dictionary of Parameter : Domain pairs
        Supported domains of distribution parameters
    scipy_icdf : Scipy icdf method
        Scipy icdf (ppf) method of equivalent pymc_dist distribution
    decimal : int, optional
        Level of precision with which pymc_dist and scipy_icdf are compared.
        Defaults to 6 for float64 and 3 for float32
    n_samples : int
        Upper limit on the number of valid domain and value combinations that
        are compared between pymc and scipy methods. If n_samples is below the
        total number of combinations, a random subset is evaluated. Setting
        n_samples = -1, will return all possible combinations. Defaults to 100
    skip_paradomain_outside_edge_test : Bool
        Whether to run test 2., which checks that pymc distribution icdf
        returns nan for invalid parameter values outside the supported domain edge

    """
    if decimal is None:
        decimal = select_by_precision(float64=6, float32=3)

    dist = create_dist_from_paramdomains(pymc_dist, paramdomains)
    q = pt.scalar(dtype="float64", name="q")
    dist_icdf = icdf(dist, q)
    pymc_icdf = pytensor.function(list(inputvars(dist_icdf)), dist_icdf)

    # Test pymc and scipy distributions match for values and parameters
    # within the supported domain edges (excluding edges)
    domains = paramdomains.copy()
    domain = Domain([0, 0.1, 0.5, 0.75, 0.95, 0.99, 1])  # Values we test the icdf at
    domains["q"] = domain

    for point in product(domains, n_samples=n_samples):
        point = dict(point)
        npt.assert_almost_equal(
            pymc_icdf(**point),
            scipy_icdf(**point),
            decimal=decimal,
            err_msg=str(point),
        )

    valid_value = domain.vals[0]
    valid_params = {param: paramdomain.vals[0] for param, paramdomain in paramdomains.items()}
    valid_params["q"] = valid_value

    if not skip_paramdomain_outside_edge_test:
        # Test pymc distribution raises ParameterValueError for parameters outside the
        # supported domain edges (excluding edges)
        invalid_params = find_invalid_scalar_params(paramdomains)
        for invalid_param, invalid_edges in invalid_params.items():
            for invalid_edge in invalid_edges:
                if invalid_edge is None:
                    continue

                point = valid_params.copy()
                point[invalid_param] = invalid_edge
                with pytest.raises(ParameterValueError):
                    pymc_icdf(**point)
                    pytest.fail(f"test_params={point}")

    # Test that values below 0 or above 1 evaluate to nan
    invalid_values = find_invalid_scalar_params({"q": domain})["q"]
    for invalid_value in invalid_values:
        if invalid_value is not None:
            point = valid_params.copy()
            point["q"] = invalid_value
            npt.assert_equal(
                pymc_icdf(**point),
                np.nan,
                err_msg=str(point),
            )


def check_selfconsistency_discrete_logcdf(
    distribution: Distribution,
    domain: Domain,
    paramdomains: Dict[str, Domain],
    decimal: Optional[int] = None,
    n_samples: int = 100,
) -> None:
    """
    Check that logcdf of discrete distributions matches sum of logps up to value.
    """
    if decimal is None:
        decimal = select_by_precision(float64=6, float32=3)

    dist = create_dist_from_paramdomains(distribution, paramdomains)
    value = dist.type()
    value.name = "value"
    dist_logp = logp(dist, value)
    dist_logp_fn = pytensor.function(list(inputvars(dist_logp)), dist_logp)

    dist_logcdf = logcdf(dist, value)
    dist_logcdf_fn = compile_pymc(list(inputvars(dist_logcdf)), dist_logcdf)

    domains = paramdomains.copy()
    domains["value"] = domain

    for point in product(domains, n_samples=n_samples):
        point = dict(point)
        value = point.pop("value")
        values = np.arange(domain.lower, value + 1)

        with pytensor.config.change_flags(mode=Mode("py")):
            npt.assert_almost_equal(
                dist_logcdf_fn(**point, value=value),
                sp.logsumexp([dist_logp_fn(value=value, **point) for value in values]),
                decimal=decimal,
                err_msg=str(point),
            )


def assert_moment_is_expected(model, expected, check_finite_logp=True):
    fn = make_initial_point_fn(
        model=model,
        return_transformed=False,
        default_strategy="moment",
    )
    moment = fn(0)["x"]
    expected = np.asarray(expected)
    try:
        random_draw = model["x"].eval()
    except NotImplementedError:
        random_draw = moment

    assert moment.shape == expected.shape
    assert expected.shape == random_draw.shape
    assert np.allclose(moment, expected)

    if check_finite_logp:
        logp_moment = (
            transformed_conditional_logp(
                (model["x"],),
                rvs_to_values={model["x"]: pt.constant(moment)},
                rvs_to_transforms={},
            )[0]
            .sum()
            .eval()
        )
        assert np.isfinite(logp_moment)


def continuous_random_tester(
    dist,
    paramdomains,
    ref_rand,
    valuedomain=None,
    size=10000,
    alpha=0.05,
    fails=10,
    extra_args=None,
    model_args=None,
):
    if valuedomain is None:
        valuedomain = Domain([0], edges=(None, None))

    if model_args is None:
        model_args = {}

    model, param_vars = build_model(dist, valuedomain, paramdomains, extra_args)
    model_dist = change_dist_size(model.named_vars["value"], size, expand=True)
    pymc_rand = compile_pymc([], model_dist)

    domains = paramdomains.copy()
    for point in product(domains, n_samples=100):
        point = pm.Point(point, model=model)
        point.update(model_args)

        # Update the shared parameter variables in `param_vars`
        for k, v in point.items():
            nv = param_vars.get(k, model.named_vars.get(k))
            if nv.name in param_vars:
                param_vars[nv.name].set_value(v)

        p = alpha
        # Allow KS test to fail (i.e., the samples be different)
        # a certain number of times. Crude, but necessary.
        f = fails
        while p <= alpha and f > 0:
            s0 = pymc_rand()
            s1 = floatX(ref_rand(size=size, **point))
            _, p = st.ks_2samp(np.atleast_1d(s0).flatten(), np.atleast_1d(s1).flatten())
            f -= 1
        assert p > alpha, str(point)


def discrete_random_tester(
    dist,
    paramdomains,
    valuedomain=None,
    ref_rand=None,
    size=100000,
    alpha=0.05,
    fails=20,
):
    if valuedomain is None:
        valuedomain = Domain([0], edges=(None, None))

    model, param_vars = build_model(dist, valuedomain, paramdomains)
    model_dist = change_dist_size(model.named_vars["value"], size, expand=True)
    pymc_rand = compile_pymc([], model_dist)

    domains = paramdomains.copy()
    for point in product(domains, n_samples=100):
        point = pm.Point(point, model=model)
        p = alpha

        # Update the shared parameter variables in `param_vars`
        for k, v in point.items():
            nv = param_vars.get(k, model.named_vars.get(k))
            if nv.name in param_vars:
                param_vars[nv.name].set_value(v)

        # Allow Chisq test to fail (i.e., the samples be different)
        # a certain number of times.
        f = fails
        while p <= alpha and f > 0:
            o = pymc_rand()
            e = intX(ref_rand(size=size, **point))
            o = np.atleast_1d(o).flatten()
            e = np.atleast_1d(e).flatten()
            bins = min(20, max(len(set(e)), len(set(o))))
            range = (min(min(e), min(o)), max(max(e), max(o)))
            observed, _ = np.histogram(o, bins=bins, range=range)
            expected, _ = np.histogram(e, bins=bins, range=range)
            if np.all(observed == expected):
                p = 1.0
            else:
                _, p = st.chisquare(observed + 1, expected + 1)
            f -= 1
        assert p > alpha, str(point)


class BaseTestDistributionRandom:
    """
    Base class for tests that new RandomVariables are correctly
    implemented, and that the mapping of parameters between the PyMC
    Distribution and the respective RandomVariable is correct.

    Three default tests are provided which check:
    1. Expected inputs are passed to the `rv_op` by the `dist` `classmethod`,
    via `check_pymc_params_match_rv_op`
    2. Expected (exact) draws are being returned, via
    `check_pymc_draws_match_reference`
    3. Shape variable inference is correct, via `check_rv_size`

    Each desired test must be referenced by name in `checks_to_run`, when
    subclassing this distribution. Custom tests can be added to each class as
    well. See `TestFlat` for an example.

    Additional tests should be added for each optional parametrization of the
    distribution. In this case it's enough to include the test
    `check_pymc_params_match_rv_op` since only this differs.

    Note on `check_rv_size` test:
        Custom input sizes (and expected output shapes) can be defined for the
        `check_rv_size` test, by adding the optional class attributes
        `sizes_to_check` and `sizes_expected`:

        ```python
        sizes_to_check = [None, (1), (2, 3)]
        sizes_expected = [(3,), (1, 3), (2, 3, 3)]
        checks_to_run = ["check_rv_size"]
        ```

        This is usually needed for Multivariate distributions. You can see an
        example in `TestDirichlet`


    Notes on `check_pymcs_draws_match_reference` test:

        The `check_pymcs_draws_match_reference` is a very simple test for the
        equality of draws from the `RandomVariable` and the exact same python
        function, given the  same inputs and random seed. A small number
        (`size=15`) is checked. This is not supposed to be a test for the
        correctness of the random generator. The latter kind of test
        (if warranted) can be performed with the aid of `pymc_random` and
        `pymc_random_discrete` methods in this file, which will perform an
        expensive statistical comparison between the RandomVariable `rng_fn`
        and a reference Python function. This kind of test only makes sense if
        there is a good independent generator reference (i.e., not just the same
        composition of numpy / scipy python calls that is done inside `rng_fn`).

        Finally, when your `rng_fn` is doing something more than just calling a
        `numpy` or `scipy` method, you will need to setup an equivalent seeded
        function with which to compare for the exact draws (instead of relying on
        `seeded_[scipy|numpy]_distribution_builder`). You can find an example
        in the `TestWeibull`, whose `rng_fn` returns
        `beta * np.random.weibull(alpha, size=size)`.

    """

    pymc_dist: Optional[Callable] = None
    pymc_dist_params: Optional[dict] = None
    reference_dist: Optional[Callable] = None
    reference_dist_params: Optional[dict] = None
    expected_rv_op_params: Optional[dict] = None
    checks_to_run: List[str] = []
    size = 15
    decimal = select_by_precision(float64=6, float32=3)

    sizes_to_check: Optional[List] = None
    sizes_expected: Optional[List] = None
    repeated_params_shape = 5
    random_state = None

    def test_distribution(self):
        self.validate_tests_list()
        if self.pymc_dist == pm.Wishart:
            with pytest.warns(UserWarning, match="can currently not be used for MCMC sampling"):
                self._instantiate_pymc_rv()
        else:
            self._instantiate_pymc_rv()
        if self.reference_dist is not None:
            self.reference_dist_draws = self.reference_dist()(
                size=self.size, **self.reference_dist_params
            )
        for check_name in self.checks_to_run:
            if check_name.startswith("test_"):
                raise ValueError(
                    "Custom check cannot start with `test_` or else it will be executed twice."
                )
            if self.pymc_dist == pm.Wishart and check_name.startswith("check_rv_size"):
                with pytest.warns(UserWarning, match="can currently not be used for MCMC sampling"):
                    getattr(self, check_name)()
            else:
                getattr(self, check_name)()

    def get_random_state(self, reset=False):
        if self.random_state is None or reset:
            self.random_state = nr.RandomState(20160911)
        return self.random_state

    def _instantiate_pymc_rv(self, dist_params=None):
        params = dist_params if dist_params else self.pymc_dist_params
        self.pymc_rv = self.pymc_dist.dist(
            **params, size=self.size, rng=pytensor.shared(self.get_random_state(reset=True))
        )

    def check_pymc_draws_match_reference(self):
        # need to re-instantiate it to make sure that the order of drawings match the reference distribution one
        # self._instantiate_pymc_rv()
        npt.assert_array_almost_equal(
            self.pymc_rv.eval(), self.reference_dist_draws, decimal=self.decimal
        )

    def check_pymc_params_match_rv_op(self):
        pytensor_dist_inputs = self.pymc_rv.get_parents()[0].inputs[3:]
        assert len(self.expected_rv_op_params) == len(pytensor_dist_inputs)
        for (expected_name, expected_value), actual_variable in zip(
            self.expected_rv_op_params.items(), pytensor_dist_inputs
        ):
            # Add additional line to evaluate symbolic inputs to distributions
            if isinstance(expected_value, pytensor.tensor.Variable):
                expected_value = expected_value.eval()

            npt.assert_almost_equal(expected_value, actual_variable.eval(), decimal=self.decimal)

    def check_rv_size(self):
        # test sizes
        sizes_to_check = self.sizes_to_check or [None, (), 1, (1,), 5, (4, 5), (2, 4, 2)]
        sizes_expected = self.sizes_expected or [(), (), (1,), (1,), (5,), (4, 5), (2, 4, 2)]
        for size, expected in zip(sizes_to_check, sizes_expected):
            pymc_rv = self.pymc_dist.dist(**self.pymc_dist_params, size=size)
            expected_symbolic = tuple(pymc_rv.shape.eval())
            actual = pymc_rv.eval().shape
            assert actual == expected_symbolic
            assert expected_symbolic == expected

        # test multi-parameters sampling for univariate distributions (with univariate inputs)
        if (
            self.pymc_dist.rv_op.ndim_supp == 0
            and self.pymc_dist.rv_op.ndims_params
            and sum(self.pymc_dist.rv_op.ndims_params) == 0
        ):
            params = {
                k: p * np.ones(self.repeated_params_shape) for k, p in self.pymc_dist_params.items()
            }
            self._instantiate_pymc_rv(params)
            sizes_to_check = [None, self.repeated_params_shape, (5, self.repeated_params_shape)]
            sizes_expected = [
                (self.repeated_params_shape,),
                (self.repeated_params_shape,),
                (5, self.repeated_params_shape),
            ]
            for size, expected in zip(sizes_to_check, sizes_expected):
                pymc_rv = self.pymc_dist.dist(**params, size=size)
                expected_symbolic = tuple(pymc_rv.shape.eval())
                actual = pymc_rv.eval().shape
                assert actual == expected_symbolic == expected

    def validate_tests_list(self):
        assert len(self.checks_to_run) == len(
            set(self.checks_to_run)
        ), "There are duplicates in the list of checks_to_run"


def seeded_scipy_distribution_builder(dist_name: str) -> Callable:
    return lambda self: ft.partial(getattr(st, dist_name).rvs, random_state=self.get_random_state())


def seeded_numpy_distribution_builder(dist_name: str) -> Callable:
    return lambda self: ft.partial(
        getattr(np.random.RandomState, dist_name), self.get_random_state()
    )


def assert_no_rvs(vars: Sequence[Variable]) -> None:
    """Assert that there are no `MeasurableVariable` nodes in a graph."""

    rvs = rvs_in_graph(vars)
    if rvs:
        raise AssertionError(f"RV found in graph: {rvs}")
back to top