https://github.com/pymc-devs/pymc3
Tip revision: 5c600c7117c98d296244c16f27b2fcdb0d37cd70 authored by Ricardo Vieira on 26 July 2023, 09:31:30 UTC
Use graph_replace instead of clone_replace in VI
Use graph_replace instead of clone_replace in VI
Tip revision: 5c600c7
blocking.py
# Copyright 2023 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
pymc.blocking
Classes for working with subsets of parameters.
"""
from __future__ import annotations
from functools import partial
from typing import (
Any,
Callable,
Dict,
Generic,
List,
NamedTuple,
Optional,
Sequence,
TypeVar,
Union,
)
import numpy as np
from typing_extensions import TypeAlias
__all__ = ["DictToArrayBijection"]
T = TypeVar("T")
PointType: TypeAlias = Dict[str, np.ndarray]
StatsDict: TypeAlias = Dict[str, Any]
StatsType: TypeAlias = List[StatsDict]
StatDtype: TypeAlias = Union[type, np.dtype]
StatShape: TypeAlias = Optional[Sequence[Optional[int]]]
# `point_map_info` is a tuple of tuples containing `(name, shape, dtype)` for
# each of the raveled variables.
class RaveledVars(NamedTuple):
data: np.ndarray
point_map_info: tuple[tuple[str, tuple[int, ...], np.dtype], ...]
class Compose(Generic[T]):
"""
Compose two functions in a pickleable way
"""
def __init__(self, fa: Callable[[PointType], T], fb: Callable[[RaveledVars], PointType]):
self.fa = fa
self.fb = fb
def __call__(self, x: RaveledVars) -> T:
return self.fa(self.fb(x))
class DictToArrayBijection:
"""Map between a `dict`s of variables to an array space.
Said array space consists of all the vars raveled and then concatenated.
"""
@staticmethod
def map(var_dict: PointType) -> RaveledVars:
"""Map a dictionary of names and variables to a concatenated 1D array space."""
vars_info = tuple((v, k, v.shape, v.dtype) for k, v in var_dict.items())
raveled_vars = [v[0].ravel() for v in vars_info]
if raveled_vars:
result = np.concatenate(raveled_vars)
else:
result = np.array([])
return RaveledVars(result, tuple(v[1:] for v in vars_info))
@staticmethod
def rmap(
array: RaveledVars,
start_point: PointType | None = None,
) -> PointType:
"""Map 1D concatenated array to a dictionary of variables in their original spaces.
Parameters
----------
array
The array to map.
start_point
An optional dictionary of initial values.
"""
if start_point:
result = dict(start_point)
else:
result = {}
if not isinstance(array, RaveledVars):
raise TypeError("`array` must be a `RaveledVars` type")
last_idx = 0
for name, shape, dtype in array.point_map_info:
arr_len = np.prod(shape, dtype=int)
var = array.data[last_idx : last_idx + arr_len].reshape(shape).astype(dtype)
result[name] = var
last_idx += arr_len
return result
@classmethod
def mapf(
cls, f: Callable[[PointType], T], start_point: PointType | None = None
) -> Callable[[RaveledVars], T]:
"""Create a callable that first maps back to ``dict`` inputs and then applies a function.
function f: DictSpace -> T to ArraySpace -> T
Parameters
----------
f: dict -> T
Returns
-------
f: array -> T
"""
return Compose(f, partial(cls.rmap, start_point=start_point))