https://github.com/pymc-devs/pymc3
Raw File
Tip revision: 982e4c4202e7c6c1196c59c228fd12903b79d32e authored by Morgan Pihl on 20 December 2022, 18:23:25 UTC
replaces numpy sqrt method with pytensor equivalent (#6405)
Tip revision: 982e4c4
initial_point.py
#   Copyright 2021 The PyMC Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
import functools
import warnings

from typing import Callable, Dict, List, Optional, Sequence, Set, Union

import numpy as np
import pytensor
import pytensor.tensor as at

from pytensor.graph.basic import Variable
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.var import TensorVariable

from pymc.logprob.transforms import RVTransform
from pymc.pytensorf import compile_pymc, find_rng_nodes, replace_rng_nodes, reseed_rngs
from pymc.util import get_transformed_name, get_untransformed_name, is_transformed_name

StartDict = Dict[Union[Variable, str], Union[np.ndarray, Variable, str]]
PointType = Dict[str, np.ndarray]


def convert_str_to_rv_dict(
    model, start: StartDict
) -> Dict[TensorVariable, Optional[Union[np.ndarray, Variable, str]]]:
    """Helper function for converting a user-provided start dict with str keys of (transformed) variable names
    to a dict mapping the RV tensors to untransformed initvals.
    TODO: Deprecate this functionality and only accept TensorVariables as keys
    """
    initvals = {}
    for key, initval in start.items():
        if isinstance(key, str):
            if is_transformed_name(key):
                rv = model[get_untransformed_name(key)]
                initvals[rv] = model.rvs_to_transforms[rv].backward(initval, *rv.owner.inputs)
            else:
                initvals[model[key]] = initval
        else:
            initvals[key] = initval
    return initvals


def make_initial_point_fns_per_chain(
    *,
    model,
    overrides: Optional[Union[StartDict, Sequence[Optional[StartDict]]]],
    jitter_rvs: Optional[Set[TensorVariable]] = None,
    chains: int,
) -> List[Callable]:
    """Create an initial point function for each chain, as defined by initvals

    If a single initval dictionary is passed, the function is replicated for each
    chain, otherwise a unique function is compiled for each entry in the dictionary.

    Parameters
    ----------
    overrides : optional, list or dict
        Initial value strategy overrides that should take precedence over the defaults from the model.
        A sequence of None or dicts will be treated as chain-wise strategies and must have the same length as `seeds`.
    jitter_rvs : set, optional
        Random variable tensors for which U(-1, 1) jitter shall be applied.
        (To the transformed space if applicable.)

    Raises
    ------
    ValueError
        If the number of entries in initvals is different than the number of chains

    """
    if isinstance(overrides, dict) or overrides is None:
        # One strategy for all chains
        # Only one function compilation is needed.
        ipfns = [
            make_initial_point_fn(
                model=model,
                overrides=overrides,
                jitter_rvs=jitter_rvs,
                return_transformed=True,
            )
        ] * chains
    elif len(overrides) == chains:
        ipfns = [
            make_initial_point_fn(
                model=model,
                jitter_rvs=jitter_rvs,
                overrides=chain_overrides,
                return_transformed=True,
            )
            for chain_overrides in overrides
        ]
    else:
        raise ValueError(
            f"Number of initval dicts ({len(overrides)}) does not match the number of chains ({chains})."
        )

    return ipfns


def make_initial_point_fn(
    *,
    model,
    overrides: Optional[StartDict] = None,
    jitter_rvs: Optional[Set[TensorVariable]] = None,
    default_strategy: str = "moment",
    return_transformed: bool = True,
) -> Callable:
    """Create seeded function that computes initial values for all free model variables.

    Parameters
    ----------
    jitter_rvs : set
        The set (or list or tuple) of random variables for which a U(-1, +1) jitter should be
        added to the initial value. Only available for variables that have a transform or real-valued support.
    default_strategy : str
        Which of { "moment", "prior" } to prefer if the initval setting for an RV is None.
    overrides : dict
        Initial value (strategies) to use instead of what's specified in `Model.initial_values`.
    return_transformed : bool
        If `True` the returned variables will correspond to transformed initial values.
    """

    sdict_overrides = convert_str_to_rv_dict(model, overrides or {})
    initval_strats = {
        **model.rvs_to_initial_values,
        **sdict_overrides,
    }

    initial_values = make_initial_point_expression(
        free_rvs=model.free_RVs,
        rvs_to_transforms=model.rvs_to_transforms,
        initval_strategies=initval_strats,
        jitter_rvs=jitter_rvs,
        default_strategy=default_strategy,
        return_transformed=return_transformed,
    )

    # Replace original rng shared variables so that we don't mess with them
    # when calling the final seeded function
    initial_values = replace_rng_nodes(initial_values)
    func = compile_pymc(inputs=[], outputs=initial_values, mode=pytensor.compile.mode.FAST_COMPILE)

    varnames = []
    for var in model.free_RVs:
        transform = model.rvs_to_transforms[var]
        if transform is not None and return_transformed:
            name = get_transformed_name(var.name, transform)
        else:
            name = var.name
        varnames.append(name)

    def make_seeded_function(func):

        rngs = find_rng_nodes(func.maker.fgraph.outputs)

        @functools.wraps(func)
        def inner(seed, *args, **kwargs):
            reseed_rngs(rngs, seed)
            values = func(*args, **kwargs)
            return dict(zip(varnames, values))

        return inner

    return make_seeded_function(func)


def make_initial_point_expression(
    *,
    free_rvs: Sequence[TensorVariable],
    rvs_to_transforms: Dict[TensorVariable, RVTransform],
    initval_strategies: Dict[TensorVariable, Optional[Union[np.ndarray, Variable, str]]],
    jitter_rvs: Set[TensorVariable] = None,
    default_strategy: str = "moment",
    return_transformed: bool = False,
) -> List[TensorVariable]:
    """Creates the tensor variables that need to be evaluated to obtain an initial point.

    Parameters
    ----------
    free_rvs : list
        Tensors of free random variables in the model.
    rvs_to_values : dict
        Mapping of free random variable tensors to value variable tensors.
    initval_strategies : dict
        Mapping of free random variable tensors to initial value strategies.
        For example the `Model.initial_values` dictionary.
    jitter_rvs : set
        The set (or list or tuple) of random variables for which a U(-1, +1) jitter should be
        added to the initial value. Only available for variables that have a transform or real-valued support.
    default_strategy : str
        Which of { "moment", "prior" } to prefer if the initval strategy setting for an RV is None.
    return_transformed : bool
        Switches between returning the tensors for untransformed or transformed initial points.

    Returns
    -------
    initial_points : list of TensorVariable
        PyTensor expressions for initial values of the free random variables.
    """
    from pymc.distributions.distribution import moment

    if jitter_rvs is None:
        jitter_rvs = set()

    initial_values = []
    initial_values_transformed = []

    for variable in free_rvs:
        strategy = initval_strategies.get(variable, None)

        if strategy is None:
            strategy = default_strategy

        if isinstance(strategy, str):
            if strategy == "moment":
                try:
                    value = moment(variable)
                except NotImplementedError:
                    warnings.warn(
                        f"Moment not defined for variable {variable} of type "
                        f"{variable.owner.op.__class__.__name__}, defaulting to "
                        f"a draw from the prior. This can lead to difficulties "
                        f"during tuning. You can manually define an initval or "
                        f"implement a moment dispatched function for this "
                        f"distribution.",
                        UserWarning,
                    )
                    value = variable
            elif strategy == "prior":
                value = variable
            else:
                raise ValueError(
                    f'Invalid string strategy: {strategy}. It must be one of ["moment", "prior"]'
                )
        else:
            value = at.as_tensor(strategy, dtype=variable.dtype).astype(variable.dtype)

        transform = rvs_to_transforms.get(variable, None)

        if transform is not None:
            value = transform.forward(value, *variable.owner.inputs)

        if variable in jitter_rvs:
            jitter = at.random.uniform(-1, 1, size=value.shape)
            jitter.name = f"{variable.name}_jitter"
            value = value + jitter

        value = value.astype(variable.dtype)
        initial_values_transformed.append(value)

        if transform is not None:
            value = transform.backward(value, *variable.owner.inputs)

        initial_values.append(value)

    all_outputs: List[TensorVariable] = []
    all_outputs.extend(free_rvs)
    all_outputs.extend(initial_values)
    all_outputs.extend(initial_values_transformed)

    copy_graph = FunctionGraph(outputs=all_outputs, clone=True)

    n_variables = len(free_rvs)
    free_rvs_clone = copy_graph.outputs[:n_variables]
    initial_values_clone = copy_graph.outputs[n_variables:-n_variables]
    initial_values_transformed_clone = copy_graph.outputs[-n_variables:]

    # We now replace all rvs by the respective initial_point expressions
    # in the constrained (untransformed) space. We do this in reverse topological
    # order, so that later nodes do not reintroduce expressions with earlier
    # rvs that would need to once again be replaced by their initial_points
    graph = FunctionGraph(outputs=free_rvs_clone, clone=False)
    replacements = reversed(list(zip(free_rvs_clone, initial_values_clone)))
    graph.replace_all(replacements, import_missing=True)

    if not return_transformed:
        return graph.outputs
    # Because the unconstrained (transformed) expressions are a subgraph of the
    # constrained initial point they were also automatically updated inplace
    # when calling graph.replace_all above, so we don't need to do anything else
    return initial_values_transformed_clone
back to top