https://github.com/pymc-devs/pymc3
Tip revision: a24acb9c34ab8cbbfc6f76032d2f666df45250f1 authored by dependabot[bot] on 18 September 2023, 02:45:16 UTC
Bump docker/login-action from 2 to 3
Tip revision: a24acb9
test_util.py
# Copyright 2023 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import arviz
import numpy as np
import pytest
import xarray
from cachetools import cached
import pymc as pm
from pymc.distributions.transforms import RVTransform
from pymc.util import (
UNSET,
_get_seeds_per_chain,
dataset_to_point_list,
drop_warning_stat,
get_value_vars_from_user_vars,
hash_key,
hashable,
locally_cachedmethod,
)
class TestTransformName:
cases = [("var", "var_test__"), ("var_test_", "var_test__test__")]
transform_name = "test"
def test_get_transformed_name(self):
class NewTransform(RVTransform):
name = self.transform_name
def forward(self, value):
return 0
def backward(self, value):
return 0
test_transform = NewTransform()
for name, transformed in self.cases:
assert pm.util.get_transformed_name(name, test_transform) == transformed
def test_is_transformed_name(self):
for name, transformed in self.cases:
assert pm.util.is_transformed_name(transformed)
assert not pm.util.is_transformed_name(name)
def test_get_untransformed_name(self):
for name, transformed in self.cases:
assert pm.util.get_untransformed_name(transformed) == name
with pytest.raises(ValueError):
pm.util.get_untransformed_name(name)
class TestExceptions:
    """The custom ShapeError / DtypeError messages must embed the supplied details."""

    def test_shape_error(self):
        # (message, keyword details, substring expected in the rendered message)
        scenarios = [
            ("Just the message.", {}, "Just"),
            ("With shapes.", {"actual": (2, 3)}, "(2, 3)"),
            ("With shapes.", {"expected": "(2,3) or (5,6)"}, "(5,6)"),
            ("With shapes.", {"actual": (), "expected": "(5,4) or (?,?,6)"}, "(?,?,6)"),
        ]
        for message, details, needle in scenarios:
            with pytest.raises(pm.exceptions.ShapeError) as excinfo:
                raise pm.exceptions.ShapeError(message, **details)
            assert needle in excinfo.value.args[0]

    def test_dtype_error(self):
        # (message, keyword details, substrings that must all appear in the message)
        scenarios = [
            ("Just the message.", {}, ("Just",)),
            ("With types.", {"actual": str}, ("str",)),
            ("With types.", {"expected": float}, ("float",)),
            ("With types.", {"actual": int, "expected": str}, ("int", "str")),
        ]
        for message, details, needles in scenarios:
            with pytest.raises(pm.exceptions.DtypeError) as excinfo:
                raise pm.exceptions.DtypeError(message, **details)
            rendered = excinfo.value.args[0]
            assert all(needle in rendered for needle in needles)
def test_hashing_of_rv_tuples():
    """``hashable`` returns an int for RVs and for dict/list/tuple structures holding them."""
    # Use a local, seeded Generator. The observed values are irrelevant to hashing,
    # but drawing them from the unseeded global RNG made the fixture
    # non-reproducible and advanced NumPy's global random state as a side effect.
    rng = np.random.default_rng(2023)
    obs = rng.normal(-1, 0.1, size=10)
    with pm.Model() as pmodel:
        mu = pm.Normal("mu", 0, 1)
        sigma = pm.Gamma("sigma", 1, 2)
        dd = pm.Normal("dd", observed=obs)
        # Cover the explicitly created RVs plus those the model lists as free RVs.
        for freerv in [mu, sigma, dd] + pmodel.free_RVs:
            for structure in [
                freerv,
                {"alpha": freerv, "omega": None},
                [freerv, []],
                (freerv, []),
            ]:
                assert isinstance(hashable(structure), int)
def test_hash_key():
    """Distinct objects with colliding ``__hash__`` values must not share cache entries."""

    class CollidingA:
        def __hash__(self):
            return 329

    class CollidingB:
        def __hash__(self):
            return 329

    first, second = CollidingA(), CollidingB()
    assert first != second

    # Function-level cache keyed with ``hash_key``.
    @cached({}, key=hash_key)
    def echo(x):
        return x

    assert echo(first) != echo(second)

    # Per-instance method cache.
    class Owner:
        @locally_cachedmethod
        def some_method(self, x):
            return x

    owner = Owner()
    assert owner.some_method(first) != owner.some_method(second)
def test_unset_repr(capsys):
    """A parameter defaulting to ``UNSET`` is displayed as ``a=UNSET`` by ``help()``."""

    def probe(a=UNSET):
        return

    help(probe)
    printed = capsys.readouterr().out
    assert "a=UNSET" in printed
def test_dataset_to_point_list():
    """Flatten a (chain, draw) Dataset into a list of per-sample point dicts."""
    sample_dims = ["chain", "draw"]
    dataset = xarray.Dataset()
    # 2 chains x 3 draws -> 6 points.
    dataset["A"] = xarray.DataArray([[1, 2, 3]] * 2, dims=("chain", "draw"))

    points, _ = dataset_to_point_list(dataset, sample_dims=sample_dims)
    assert isinstance(points, list)
    assert len(points) == 6
    first = points[0]
    assert isinstance(first, dict)
    assert isinstance(first["A"], np.ndarray)

    # A non-string variable name must be rejected.
    dataset[3] = xarray.DataArray([1, 2, 3])
    with pytest.raises(ValueError, match="must be str"):
        dataset_to_point_list(dataset, sample_dims=sample_dims)
def test_drop_warning_stat():
    """``drop_warning_stat`` strips the ``warning`` stat and its dim from both stats groups."""

    def make_stats():
        # Fresh arrays per group, as with separate literal dicts.
        return {
            "a": np.ones((2, 5, 4)),
            "warning": np.ones((2, 5, 3), dtype=object),
        }

    idata = arviz.from_dict(
        sample_stats=make_stats(),
        warmup_sample_stats=make_stats(),
        attrs={"version": "0.1.2"},
        coords={"adim": [0, 1, None, 3], "warning_dim_0": list("ABC")},
        dims={"a": ["adim"], "warning": ["warning_dim_0"]},
        save_warmup=True,
    )

    result = drop_warning_stat(idata)
    # A new object is returned and the top-level attrs survive.
    assert result is not idata
    assert result.attrs.get("version") == "0.1.2"

    for group in ("sample_stats", "warmup_sample_stats"):
        stats_ds = result.get(group)
        assert isinstance(stats_ds, xarray.Dataset), group
        assert "a" in stats_ds
        assert "warning" not in stats_ds
        assert "warning_dim_0" not in stats_ds
def test_get_seeds_per_chain():
    """Exercise each seeding source accepted by ``_get_seeds_per_chain``."""
    # None: one integer seed per requested chain.
    ret = _get_seeds_per_chain(None, chains=1)
    assert len(ret) == 1 and isinstance(ret[0], int)
    ret = _get_seeds_per_chain(None, chains=2)
    assert len(ret) == 2 and isinstance(ret[0], int)

    # An int with a single chain is returned inside a tuple.
    ret = _get_seeds_per_chain(5, chains=1)
    assert ret == (5,)
    # With several chains the int itself must not be reused as a chain seed.
    ret = _get_seeds_per_chain(5, chains=3)
    assert len(ret) == 3 and isinstance(ret[0], int) and not any(r == 5 for r in ret)

    # Generator: seeds match the draws of an identically seeded Generator.
    # np.array_equal checks shape and values, rather than relying on the truth
    # value of an elementwise comparison array.
    rng = np.random.default_rng(123)
    expected_ret = rng.integers(2**30, dtype=np.int64, size=1)
    rng = np.random.default_rng(123)
    ret = _get_seeds_per_chain(rng, chains=1)
    assert np.array_equal(ret, expected_ret)

    # RandomState: same contract via the legacy API.
    rng = np.random.RandomState(456)
    expected_ret = rng.randint(2**30, dtype=np.int64, size=2)
    rng = np.random.RandomState(456)
    ret = _get_seeds_per_chain(rng, chains=2)
    assert np.array_equal(ret, expected_ret)

    # Explicit seed sequences of the right length are returned as the same object;
    # a length mismatch with the chain count is rejected.
    for seeds in ([0, 1, 2], (0, 1, 2, 3), np.arange(5)):
        ret = _get_seeds_per_chain(seeds, chains=len(seeds))
        assert ret is seeds
        with pytest.raises(ValueError, match="does not match the number of chains"):
            _get_seeds_per_chain(seeds, chains=len(seeds) + 1)

    # Anything else (e.g. a dict) is not an accepted seed source.
    with pytest.raises(ValueError, match=re.escape("The `seeds` must be array-like")):
        _get_seeds_per_chain({1: 1, 2: 2}, 2)
def test_get_value_vars_from_user_vars():
    """Resolve user-supplied variables to a model's value variables; reject foreign ones."""
    with pm.Model() as model1:
        x1 = pm.Normal("x1", mu=0, sigma=1)
        y1 = pm.Normal("y1", mu=0, sigma=1)

    lookup = model1.rvs_to_values
    x1_value = lookup[x1]
    y1_value = lookup[y1]
    assert get_value_vars_from_user_vars([x1, y1], model1) == [x1_value, y1_value]
    assert get_value_vars_from_user_vars([x1], model1) == [x1_value]
    # Deliberately pass a bare value variable (neither an RV nor wrapped in a list)
    # to check that the helper also accepts that form.
    assert get_value_vars_from_user_vars(x1_value, model1) == [x1_value]

    with pm.Model() as model2:
        x2 = pm.Normal("x2", mu=0, sigma=1)
        y2 = pm.Normal("y2", mu=0, sigma=1)
        det2 = pm.Deterministic("det2", x2 + y2)

    msg_prefix = "The following variables are not random variables in the model:"
    # (user variables, model queried, regex fragment listing the rejected names)
    rejections = [
        ([x2, y2], model1, r"\['x2', 'y2'\]"),
        ([x2, y1], model1, r"\['x2'\]"),
        ([x2], model1, r"\['x2'\]"),
        ([det2], model2, r"\['det2'\]"),
    ]
    for user_vars, model, rejected in rejections:
        with pytest.raises(ValueError, match=rf"{msg_prefix} {rejected}"):
            get_value_vars_from_user_vars(user_vars, model)