https://github.com/GPflow/GPflow
Tip revision: 36c1a01f1d980454782cc1b19c90e88a2c71f6fe authored by Artem Artemev on 07 October 2019, 08:43:59 UTC
Update docstring of the helper
Tip revision: 36c1a01
mcmc.py
# Copyright 2019 Artem Artemev @awav
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Callable, List, Optional, TypeVar, Union

import tensorflow as tf

import gpflow
__all__ = ["positive_parameter", "SamplingHelper"]
def positive_parameter(value: tf.Tensor):
    """
    Wrap `value` as a positively-constrained `gpflow.Parameter`.

    Values that already are `tf.Variable` or `gpflow.Parameter` instances are
    returned untouched; anything else is wrapped in a new `gpflow.Parameter`
    with the default positive transform.
    """
    if not isinstance(value, (tf.Variable, gpflow.Parameter)):
        value = gpflow.Parameter(value, transform=gpflow.positive())
    return value
ModelParameters = List[TypeVar("ModelParameter", tf.Variable, gpflow.Parameter)]
@dataclass(frozen=True)
class SamplingHelper:
    """
    Helper that reads from variables being set with a prior and writes values back to
    the same variables.

    Example:

        model = <Create GPflow model>
        hmc_helper = SamplingHelper(
            lambda: -model.neg_log_marginal_likelihood(), model.trainable_parameters)
        target_fn = hmc_helper.make_posterior_log_prob_fn()

        @tf.function
        def run_chain_fn():
            hmc = mcmc.HamiltonianMonteCarlo(
                target_log_prob_fn=target_fn, num_leapfrog_steps=10, step_size=0.01)
            adaptive_hmc = mcmc.SimpleStepSizeAdaptation(
                hmc, num_adaptation_steps=10, target_accept_prob=0.75,
                adaptation_rate=0.1)
            return mcmc.sample_chain(num_results=hmc_parameters.num_samples,
                                     num_burnin_steps=100,
                                     current_state=hmc_helper.variables,
                                     kernel=adaptive_hmc)

    Args:
        target_log_prob_fn: Python callable which represents log-density under the
            target distribution; called with no arguments (it closes over the model).
        parameters: List of `tf.Variable`s or gpflow `Parameter`s used as a state of
            the Markov chain.
    """

    # Called with no arguments (see `make_posterior_log_prob_fn`), hence `...`.
    target_log_prob_fn: Callable[..., tf.Tensor]
    parameters: ModelParameters

    @property
    def variables(self):
        """
        Returns the same list as the `parameters` property, but with every gpflow
        `Parameter` replaced by its `unconstrained_variable`; plain `tf.Variable`s
        are passed through unchanged.
        """
        return [
            p.unconstrained_variable if isinstance(p, gpflow.Parameter) else p
            for p in self.parameters
        ]

    def assign_values(self, *values, unconstrained: bool = True):
        """
        Assigns (constrained or unconstrained) values to the parameters' variables.

        Args:
            values: One value per parameter, in the same order as `parameters`.
            unconstrained: When True (default), `values` live in the unconstrained
                space and are written to the entries of the `variables` property;
                otherwise they are constrained values written directly to
                `parameters`.

        Raises:
            ValueError: If the number of values differs from the number of
                parameters.
        """
        targets = self.variables if unconstrained else self.parameters
        # Explicit exception instead of `assert`: input validation must survive
        # running Python with the -O flag.
        if len(values) != len(targets):
            raise ValueError("Number of values must match the number of parameters.")
        for target, value in zip(targets, values):
            target.assign(value)

    def convert_to_constrained_values(self, *unconstrained_values):
        """
        Converts `unconstrained_values` to their constrained versions.

        Each value corresponds to the parameter at the same position; when that
        parameter is a `gpflow.Parameter` with a transform, the transform's
        `forward` method is applied, otherwise the value is taken as-is. Results
        are collected via `.numpy()`, so values are assumed to be eager tensors.
        """
        samples = []
        for param, value in zip(self.parameters, unconstrained_values):
            if isinstance(param, gpflow.Parameter) and param.transform is not None:
                sample = param.transform.forward(value)
            else:
                sample = value
            samples.append(sample.numpy())
        return samples

    def make_posterior_log_prob_fn(self):
        """
        Makes a differentiable posterior log-probability function using the helper's
        `target_log_prob_fn` with respect to the passed `parameters`.

        The returned callable takes unconstrained values (one per parameter), writes
        them into the model, and returns the target log-probability; its custom
        gradient is computed with respect to the unconstrained variables.
        """

        @tf.custom_gradient
        def log_prob_fn(*values):
            # Push the proposed chain state into the model before evaluating
            # the target density.
            self.assign_values(*values)
            variables_to_watch = self.variables
            with tf.GradientTape(watch_accessed_variables=False) as tape:
                tape.watch(variables_to_watch)
                log_prob = self.target_log_prob_fn()

            @tf.function
            def grad_fn(in_grad: tf.Tensor, variables: Optional[tf.Variable] = None):
                grad = tape.gradient(log_prob, variables_to_watch)
                # NOTE(review): `variables` is the list tf.custom_gradient passes
                # when the wrapped function touches variables; this assumes it is
                # never None here - confirm, as `len(None)` would raise.
                return grad, [None] * len(variables)

            return log_prob, grad_fn

        return log_prob_fn