https://github.com/pymc-devs/pymc3
Tip revision: 118be0f23782945dc03c5fb36d58d6ce4a1f619f authored by Ricardo Vieira on 07 November 2023, 11:46:49 UTC
Add test for Blockwise logp regression
Add test for Blockwise logp regression
Tip revision: 118be0f
util.py
# Copyright 2023 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import warnings
from typing import Any, Dict, List, NewType, Optional, Sequence, Tuple, Union, cast
import arviz
import cloudpickle
import numpy as np
import xarray
from cachetools import LRUCache, cachedmethod
from pytensor import Variable
from pytensor.compile import SharedVariable
from pytensor.graph.utils import ValidatingScratchpad
from pymc.exceptions import BlockModelAccessError
# Distinct static type for variable names; at runtime a ``VarName`` is a plain ``str``.
VarName = NewType("VarName", str)
class _UnsetType:
    """Type of the ``UNSET`` sentinel, so that it renders nicely in ``help(...)`` output."""

    def __str__(self):
        return "UNSET"

    def __repr__(self):
        # The repr is defined as the string form, so both print ``UNSET``.
        return f"{self!s}"


# Module-level sentinel instance.
UNSET = _UnsetType()
def withparent(meth):
    """Wrap ``meth`` so that each call is also replayed on ``self.parent``.

    The wrapped method first runs on ``self``. If ``self`` has a ``parent``
    attribute that is not ``None``, the parent's attribute named
    ``meth.__name__`` is then called with the same positional and keyword
    arguments. The value returned is the one produced for ``self``; whatever
    the parent call returns is discarded.
    """

    def wrapped(self, *args, **kwargs):
        res = meth(self, *args, **kwargs)
        parent = getattr(self, "parent", None)
        if parent is None:
            return res
        forward = getattr(parent, meth.__name__)
        forward(*args, **kwargs)
        return res

    # functools.wraps is deliberately not used: it was found to fail when
    # decorating built-in methods, so only the name is copied over by hand.
    wrapped.__name__ = meth.__name__
    return wrapped
class treelist(list):
    """A list that passes mutable extending operations used in Model
    to parent list instance.

    Extending treelist you will also extend its parent

    Parameters
    ----------
    iterable : iterable, optional
        Initial items of the list.
    parent : list or None, optional
        List that receives every item added to this one. On construction the
        initial items are appended to it as well.
    """

    def __init__(self, iterable=(), parent=None):
        super().__init__(iterable)
        assert isinstance(parent, list) or parent is None
        self.parent = parent
        if self.parent is not None:
            self.parent.extend(self)

    # typechecking here works bad
    # Each of these runs the plain list operation on ``self`` and then the
    # same-named operation on ``self.parent`` (see ``withparent``).
    append = withparent(list.append)
    __iadd__ = withparent(list.__iadd__)
    extend = withparent(list.extend)

    def tree_contains(self, item):
        """Return whether ``item`` is in this list or in any ancestor list.

        A ``treelist`` parent is queried recursively so that the whole chain of
        parents is searched; a plain ``list`` parent is checked directly.
        """
        # The parent is always a list (or None), so the recursive case must be
        # detected with ``treelist``; checking for ``treedict`` here could
        # never match and left ancestors beyond the direct parent unsearched.
        if isinstance(self.parent, treelist):
            return list.__contains__(self, item) or self.parent.tree_contains(item)
        elif isinstance(self.parent, list):
            return list.__contains__(self, item) or self.parent.__contains__(item)
        else:
            return list.__contains__(self, item)

    def __setitem__(self, key, value):
        raise NotImplementedError(
            "Method is removed as we are not able to determine appropriate logic for it"
        )

    # Added this because mypy didn't like having __imul__ without __mul__
    # This is my best guess about what this should do. I might be happier
    # to kill both of these if they are not used.
    def __mul__(self, other) -> "treelist":
        return cast("treelist", super().__mul__(other))

    def __imul__(self, other) -> "treelist":
        # Remember the old length so only the newly created items are
        # forwarded to the parent.
        t0 = len(self)
        super().__imul__(other)
        if self.parent is not None:
            self.parent.extend(self[t0:])
        return self  # python spec says should return the result.
class treedict(dict):
    """A dict that forwards the extending operations used in Model to a parent dict.

    Whatever is set on, or merged into, a ``treedict`` is set on / merged into
    its parent as well.

    Parameters
    ----------
    iterable : mapping or iterable of pairs, optional
        Initial content, passed on to ``dict``.
    parent : dict or None, optional
        Dict that receives every update made to this one. On construction it is
        updated with the initial content.
    **kwargs
        Additional initial items, passed on to ``dict``.
    """

    def __init__(self, iterable=(), parent=None, **kwargs):
        super().__init__(iterable, **kwargs)
        assert isinstance(parent, dict) or parent is None
        self.parent = parent
        if parent is not None:
            parent.update(self)

    # typechecking here works bad
    # Both operations run on ``self`` first and are then replayed on the parent.
    __setitem__ = withparent(dict.__setitem__)
    update = withparent(dict.update)

    def tree_contains(self, item):
        """Return whether ``item`` is a key of this dict or of a parent dict."""
        # needed for `add_named_variable` method
        if dict.__contains__(self, item):
            return True
        parent = self.parent
        if isinstance(parent, treedict):
            # Keep walking up the chain of treedicts.
            return parent.tree_contains(item)
        if isinstance(parent, dict):
            return parent.__contains__(item)
        return False
def get_transformed_name(name, transform):
    r"""
    Build the name of a transformed variable in a consistent way

    Parameters
    ----------
    name: str
        Name of the untransformed variable
    transform: transforms.Transform
        Transform being applied; only its ``name`` attribute is read

    Returns
    -------
    str
        ``name`` and the transform name joined by an underscore, followed by
        two underscores
    """
    return "{}_{}__".format(name, transform.name)
def is_transformed_name(name):
    r"""
    Cheap check of whether a name has the shape produced by `get_transformed_name`

    Parameters
    ----------
    name: str
        Name to check

    Returns
    -------
    bool
        True when the name ends with two underscores and contains at least
        three underscores in total, i.e. it could have been produced by
        `get_transformed_name`
    """
    if not name.endswith("__"):
        return False
    return name.count("_") >= 3
def get_untransformed_name(name):
    r"""
    Invert `get_transformed_name`

    Parameters
    ----------
    name: str
        Transformed name

    Returns
    -------
    str
        The part of the name that precedes its third-to-last underscore.

    Raises
    ------
    ValueError
        If `is_transformed_name` rejects the name
    """
    if is_transformed_name(name):
        # Cutting at the last three underscores removes the transform name
        # together with the trailing double underscore.
        return name.rsplit("_", 3)[0]
    raise ValueError(f"{name} does not appear to be a transformed name")
def get_default_varnames(var_iterator, include_transformed):
    r"""Helper to extract default varnames from a trace.

    Parameters
    ----------
    var_iterator: iterator
        Variables or names. Each element's name (see `get_var_name`) is used to
        decide whether it is a transformed variable.
    include_transformed: boolean
        Should transformed variable names be included in return value

    Returns
    -------
    list
        All elements when ``include_transformed`` is true, otherwise only the
        elements whose name is not a transformed name
    """
    if not include_transformed:
        kept = []
        for candidate in var_iterator:
            if is_transformed_name(get_var_name(candidate)):
                continue
            kept.append(candidate)
        return kept
    return list(var_iterator)
def get_var_name(var) -> VarName:
    """Return a plain name for ``var``.

    The ``name`` attribute is used when ``var`` has one, otherwise ``var``
    itself; either way the value is converted with ``str``.
    """
    label = getattr(var, "name", var)
    return VarName(str(label))
def get_transformed(z):
    """Return ``z.transformed`` when ``z`` has that attribute, otherwise ``z`` itself."""
    if not hasattr(z, "transformed"):
        return z
    return z.transformed
def biwrap(wrapper):
    """Let ``wrapper`` be called either directly or with its options first.

    The returned callable invokes ``wrapper(*args, **kwargs)`` when it receives
    positional arguments. When it receives none, it returns a
    ``functools.partial`` of ``wrapper`` holding the given arguments, so the
    positional argument can be supplied by a later call. If the first positional
    argument has an attribute named like ``wrapper``, it is treated as the
    instance of a bound method and not counted.
    """

    @functools.wraps(wrapper)
    def enhanced(*args, **kwargs):
        looks_bound = bool(args) and hasattr(args[0], wrapper.__name__)
        # Number of leading positional arguments that do not count.
        leading = 1 if looks_bound else 0
        if len(args) <= leading:
            return functools.partial(wrapper, *args, **kwargs)
        return wrapper(*args, **kwargs)

    return enhanced
def dataset_to_point_list(
    ds: xarray.Dataset, sample_dims: Sequence[str]
) -> Tuple[List[Dict[str, np.ndarray]], Dict[str, Any]]:
    """Flatten the sample dimensions of a dataset into a list of point dicts.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset whose keys are all strings.
    sample_dims : sequence of str
        Names of the dimensions to flatten. They are moved to the front of
        every variable and collapsed into a single leading axis.

    Returns
    -------
    points : list of dict
        One dict per combination of sample coordinates, mapping each variable
        name to the slice of its values at that position of the leading axis.
    stacked_dims : dict
        ``ds[dim_name]`` for every name in ``sample_dims``.

    Raises
    ------
    ValueError
        If a key of the dataset is not a ``str``.
    """
    # All keys of the dataset must be a str
    var_names = list(ds.keys())
    for vn in var_names:
        if not isinstance(vn, str):
            raise ValueError(f"Variable names must be str, but dataset key {vn} is a {type(vn)}.")
    num_sample_dims = len(sample_dims)
    stacked_dims = {dim_name: ds[dim_name] for dim_name in sample_dims}
    ds = ds.transpose(*sample_dims, ...)
    stacked_dict = {
        vn: da.values.reshape((-1, *da.shape[num_sample_dims:])) for vn, da in ds.items()
    }
    # An integer dtype is forced because the product of an empty list would
    # otherwise be the float 1.0, which ``range`` rejects.
    num_points = int(np.prod([len(coords) for coords in stacked_dims.values()], dtype=int))
    points = [{vn: stacked_dict[vn][i, ...] for vn in var_names} for i in range(num_points)]
    # use the list of points
    return cast(List[Dict[str, np.ndarray]], points), stacked_dims
def drop_warning_stat(idata: arviz.InferenceData) -> arviz.InferenceData:
    """Returns a new ``InferenceData`` object with the "warning" stat removed from sample stats groups.

    This function should be applied to an ``InferenceData`` object obtained with
    ``pm.sample(keep_warning_stat=True)`` before trying to ``.to_netcdf()`` or ``.to_zarr()`` it.

    Every group of ``idata`` is copied over; groups whose name contains
    ``"sample_stat"`` first have the ``warning`` and ``warning_dim_0`` variables
    dropped (silently skipped when absent).
    """
    cleaned = arviz.InferenceData(attrs=idata.attrs)
    for group_name, dataset in idata.items():
        is_stats_group = "sample_stat" in group_name
        if is_stats_group:
            dataset = dataset.drop_vars(names=["warning", "warning_dim_0"], errors="ignore")
        cleaned.add_groups({group_name: dataset}, coords=dataset.coords, dims=dataset.dims)
    return cleaned
def chains_and_samples(data: Union[xarray.Dataset, arviz.InferenceData]) -> Tuple[int, int]:
    """Extract and return number of chains and samples in xarray or arviz traces.

    Parameters
    ----------
    data : xarray.Dataset or arviz.InferenceData
        For ``InferenceData`` the ``"posterior"`` group is inspected.

    Returns
    -------
    nchains, nsamples : tuple of int
        Sizes of the ``"chain"`` and ``"draw"`` coordinates.

    Raises
    ------
    ValueError
        If ``data`` is neither an xarray Dataset nor an arviz InferenceData.
    """
    dataset: xarray.Dataset
    if isinstance(data, xarray.Dataset):
        dataset = data
    elif isinstance(data, arviz.InferenceData):
        dataset = data["posterior"]
    else:
        # Exceptions do not interpolate logging-style ``%s`` arguments, so the
        # offending class is formatted into the message explicitly.
        raise ValueError(
            f"Argument must be xarray Dataset or arviz InferenceData. Got {data.__class__}"
        )
    coords = dataset.coords
    nchains = coords["chain"].sizes["chain"]
    nsamples = coords["draw"].sizes["draw"]
    return nchains, nsamples
def hashable(a=None) -> int:
    """
    Hashes many kinds of objects, including some that are unhashable through the builtin `hash` function.

    Lists and tuples are hashed based on their elements, dicts based on their
    items regardless of insertion order.

    Parameters
    ----------
    a : object, optional
        Object to hash.

    Returns
    -------
    int
        A hash value. For objects that cannot be hashed directly, the hash of
        their ``cloudpickle`` serialization is used; if that fails too, the hash
        of the object's ``__dict__``, and as a last resort ``id(a)``.
    """
    if isinstance(a, dict):
        # first hash the keys and values with hashable
        # then hash the set of int-tuples with the builtin; a frozenset makes
        # the result independent of insertion order, so that dicts that
        # compare equal also hash equal.
        return hash(frozenset((hashable(k), hashable(v)) for k, v in a.items()))
    if isinstance(a, (tuple, list)):
        # lists are mutable and not hashable by default
        # for memoization, we need the hash to depend on the items
        return hash(tuple(hashable(i) for i in a))
    try:
        return hash(a)
    except TypeError:
        pass
    # Not hashable >>>
    try:
        return hash(cloudpickle.dumps(a))
    except Exception:
        # Deliberately best-effort: any serialization failure falls back to
        # weaker notions of identity rather than raising.
        if hasattr(a, "__dict__"):
            return hashable(a.__dict__)
        else:
            return id(a)
def hash_key(*args, **kwargs):
    """Build a cache key: a tuple of ``HashableWrapper`` objects.

    One wrapper is created per positional argument, followed by one per
    ``(name, value)`` pair of the keyword arguments.
    """
    call_items = (*args, *kwargs.items())
    return tuple(HashableWrapper(item) for item in call_items)
class HashableWrapper:
    """Wrap an arbitrary object so that it can be hashed with ``hashable``.

    Equality is delegated to the wrapped object: a wrapper is equal to whatever
    its wrapped object compares equal to.
    """

    __slots__ = ("obj",)

    def __init__(self, obj):
        self.obj = obj

    def __hash__(self):
        return hashable(self.obj)

    def __eq__(self, other):
        wrapped = self.obj
        return wrapped == other

    def __repr__(self):
        return "{}({})".format(type(self).__name__, self.obj)
class WithMemoization:
    """Mixin for objects that keep a ``_cache`` attribute.

    Instances hash by identity, and the ``_cache`` attribute is left out of the
    state returned by ``__getstate__``.
    """

    def __hash__(self):
        return hash(id(self))

    def __getstate__(self):
        # Shallow copy of the instance attributes without the cache.
        return {key: value for key, value in self.__dict__.items() if key != "_cache"}

    def __setstate__(self, state):
        self.__dict__.update(state)
def locally_cachedmethod(f):
    """Decorate method ``f`` so that its results are cached on the instance.

    The caches live in the instance's ``_cache`` attribute, a ``defaultdict``
    created on first use that maps method names to an ``LRUCache(128)``.
    ``hash_key`` is passed to ``cachedmethod`` as the key function.
    """
    from collections import defaultdict

    method_name = f.__name__

    def cache_of(self):
        # Create the per-instance store lazily, then pick this method's cache.
        per_instance = self.__dict__.setdefault("_cache", defaultdict(lambda: LRUCache(128)))
        return per_instance[method_name]

    decorate = cachedmethod(cache_of, key=hash_key)
    return decorate(f)
def check_dist_not_registered(dist, model=None):
    """Check that a dist is not registered in the model already

    Parameters
    ----------
    dist
        Distribution variable to look for.
    model : optional
        Model to check. It is handed to ``modelcontext``, which is presumably
        what resolves ``None`` to the model on the context stack.

    Raises
    ------
    ValueError
        If ``dist`` is one of the model's ``basic_RVs``.
    """
    from pymc.model import modelcontext

    try:
        # Pass the caller's model along instead of always resolving the
        # context model, so that an explicit ``model`` argument is honoured.
        model = modelcontext(model)
    except (TypeError, BlockModelAccessError):
        # No model could be resolved: there is nothing to check against.
        pass
    else:
        if dist in model.basic_RVs:
            raise ValueError(
                f"The dist {dist} was already registered in the current model.\n"
                f"You should use an unregistered (unnamed) distribution created via "
                f"the `.dist()` API instead, such as:\n`dist=pm.Normal.dist(0, 1)`"
            )
def point_wrapper(core_function):
    """Wrap a pytensor compiled function so that it can take point dictionaries.

    The returned function accepts arbitrary keyword arguments and forwards only
    those whose key is the name of a non-shared input of the function's graph;
    all other keys are ignored.
    """
    graph_inputs = core_function.maker.fgraph.inputs
    ins = [inp.name for inp in graph_inputs if not isinstance(inp, SharedVariable)]

    def wrapped(**kwargs):
        input_point = {}
        for key, value in kwargs.items():
            if key in ins:
                input_point[key] = value
        return core_function(**input_point)

    return wrapped
# Seed specification: ``None``, a single integer, or a sequence / array of integers.
RandomSeed = Optional[Union[int, Sequence[int], np.ndarray]]
# Anything accepted as a source of randomness: a seed specification or a NumPy
# ``RandomState`` / ``Generator`` instance.
RandomState = Union[RandomSeed, np.random.RandomState, np.random.Generator]
def _get_seeds_per_chain(
    random_state: RandomState,
    chains: int,
) -> Union[Sequence[int], np.ndarray]:
    """Obtain or validate specified integer seeds per chain.

    This function process different possible sources of seeding and returns one integer
    seed per chain:

    1. If the input is an integer and a single chain is requested, the input is
       returned inside a tuple.
    2. If the input is a sequence or NumPy array with as many entries as chains,
       the input is returned.
    3. If the input is an integer and multiple chains are requested, new unique seeds
       are generated from NumPy default Generator seeded with that integer.
    4. If the input is None new unique seeds are generated from an unseeded NumPy default
       Generator.
    5. If a RandomState or Generator is provided, new unique seeds are generated from it.

    NumPy integer scalars (e.g. ``np.int64``) are treated like Python integers.

    Raises
    ------
    ValueError
        If none of the conditions above are met
    """

    def _get_unique_seeds_per_chain(integers_fn):
        # Redraw until all seeds are distinct.
        seeds = []
        while len(set(seeds)) != chains:
            seeds = [int(seed) for seed in integers_fn(2**30, dtype=np.int64, size=chains)]
        return seeds

    # NumPy integer scalars are not instances of ``int``; without this they
    # would be rejected below as "not array-like".
    is_integer = isinstance(random_state, (int, np.integer))

    if random_state is None or is_integer:
        if chains == 1 and is_integer:
            return (random_state,)
        return _get_unique_seeds_per_chain(np.random.default_rng(random_state).integers)

    if isinstance(random_state, np.random.Generator):
        return _get_unique_seeds_per_chain(random_state.integers)

    if isinstance(random_state, np.random.RandomState):
        return _get_unique_seeds_per_chain(random_state.randint)

    if not isinstance(random_state, (list, tuple, np.ndarray)):
        raise ValueError(f"The `seeds` must be array-like. Got {type(random_state)} instead.")

    if len(random_state) != chains:
        raise ValueError(
            f"Number of seeds ({len(random_state)}) does not match the number of chains ({chains})."
        )

    return random_state
def get_value_vars_from_user_vars(
    vars: Union[Variable, Sequence[Variable]], model
) -> List[Variable]:
    """Converts user "vars" input into value variables.

    Users usually pass random variables, which are mapped to their value
    variables through ``model.rvs_to_values``. Inputs that are not keys of that
    mapping are kept as they are, so value variables may be passed directly
    (e.g. by internal callers or "super-users").

    Parameters
    ----------
    vars : Variable or sequence of Variable
        A single variable or a sequence of variables.
    model
        Model providing ``rvs_to_values`` and ``value_vars``.

    Returns
    -------
    value_vars: list of TensorVariable
        List of model value variables that correspond to the input vars

    Raises
    ------
    ValueError:
        If any of the provided variables do not correspond to any model value variable
    """
    # A single variable is handled like a one-element sequence.
    requested = vars if isinstance(vars, Sequence) else [vars]
    rvs_to_values = model.rvs_to_values
    value_vars = [rvs_to_values.get(var, var) for var in requested]

    # Check that we only have value vars from the model
    known_value_vars = model.value_vars
    unknown = [candidate for candidate in value_vars if candidate not in known_value_vars]
    if not unknown:
        return value_vars

    unknown_names = [get_var_name(candidate) for candidate in unknown]
    # We mention random variables, even though the input may be a wrong value variable
    # because most users don't know about that duality
    raise ValueError(
        "The following variables are not random variables in the model: " + str(unknown_names)
    )
class _FutureWarningValidatingScratchpad(ValidatingScratchpad):
    """Scratchpad that emits a ``FutureWarning`` when a deprecated tag attribute is read.

    The warning is only issued when the deprecated attribute actually exists on
    the scratchpad; looking up a missing one fails as usual without a warning.
    """

    def __getattribute__(self, name):
        for deprecated_names, alternative in (
            (("value_var", "observations"), "model.rvs_to_values[rv]"),
            (("transform",), "model.rvs_to_transforms[rv]"),
        ):
            if name in deprecated_names:
                try:
                    # Probe for the attribute first so that a missing
                    # attribute does not trigger a deprecation warning.
                    super().__getattribute__(name)
                except AttributeError:
                    pass
                else:
                    # stacklevel=2 attributes the warning to the code that
                    # read the attribute rather than to this method.
                    warnings.warn(
                        f"The tag attribute {name} is deprecated. Use {alternative} instead",
                        FutureWarning,
                        stacklevel=2,
                    )
        return super().__getattribute__(name)
def _add_future_warning_tag(var) -> None:
    """Replace ``var.tag`` with a ``_FutureWarningValidatingScratchpad``.

    The entries of the old tag's ``__dict__`` are copied into the new tag
    without overwriting keys the new tag already has. Nothing happens when the
    tag already is a ``_FutureWarningValidatingScratchpad``.
    """
    current_tag = var.tag
    if isinstance(current_tag, _FutureWarningValidatingScratchpad):
        return
    replacement = _FutureWarningValidatingScratchpad("test_value", var.type.filter)
    for key, value in current_tag.__dict__.items():
        replacement.__dict__.setdefault(key, value)
    var.tag = replacement
def makeiter(a):
    """Return ``a`` unchanged if it is a tuple or list, otherwise wrap it in a one-element list."""
    return a if isinstance(a, (tuple, list)) else [a]