https://github.com/GPflow/GPflow
Revision 7e717ccb69bf4a3e563f176b28ff2f516d65482c authored by st-- on 31 March 2020, 12:17:12 UTC, committed by GitHub on 31 March 2020, 12:17:12 UTC
This gives all `BayesianModel` subclasses a consistent interface for both optimization (MLE/MAP) and MCMC. Models are required to implement `maximum_log_likelihood_objective`, which is the quantity maximized during model training.
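
A minimal pure-Python sketch of this contract, not GPflow's code: `ToyBayesianModel`, `ToyGaussianMean` and `Incomplete` are hypothetical names, and `abc` is used only to show that a subclass which omits `maximum_log_likelihood_objective` cannot be instantiated. The zero log prior is an assumption of the sketch.

```python
import abc
import math


class ToyBayesianModel(abc.ABC):
    """Hypothetical stand-in for a BayesianModel base class (not GPflow's code)."""

    @abc.abstractmethod
    def maximum_log_likelihood_objective(self, *args):
        """Quantity to be maximized for training; every subclass must define it."""

    def log_prior_density(self):
        # Assumption for this sketch: no priors, so the log prior density is 0.
        return 0.0


class ToyGaussianMean(ToyBayesianModel):
    """Hypothetical model: data ~ N(mean, 1), with the data held internally."""

    def __init__(self, data, mean=0.0):
        self.data = list(data)
        self.mean = mean

    def maximum_log_likelihood_objective(self):
        # Unit-variance Gaussian log-likelihood summed over all points.
        return sum(-0.5 * math.log(2 * math.pi) - 0.5 * (x - self.mean) ** 2 for x in self.data)


class Incomplete(ToyBayesianModel):
    pass  # forgets to implement maximum_log_likelihood_objective


try:
    Incomplete()
except TypeError as err:
    print("cannot instantiate:", err)
```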

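Optimization: The `_training_loss` method is defined as `-(maximum_log_likelihood_objective + log_prior_density)`, so minimizing it maximizes the objective plus the log prior density (MAP, or MLE when no priors are set). It is exposed as `training_loss` by the `InternalDataTrainingLossMixin` and `ExternalDataTrainingLossMixin` classes.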
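
To make the relationship concrete, here is a hedged pure-Python sketch of the two mixins. The `Toy…` class names are hypothetical and the real GPflow classes work on TensorFlow tensors and variables; the sketch only mirrors the behaviour stated above: `_training_loss` negates objective plus log prior, and the two mixins differ only in whether `training_loss` takes data.

```python
import math


class ToyModel:
    """Hypothetical base: subclasses provide maximum_log_likelihood_objective."""

    def log_prior_density(self):
        return 0.0  # assumption: no priors unless a subclass overrides this

    def _training_loss(self, *args):
        # Negate so that minimizing the loss maximizes objective + log prior.
        return -(self.maximum_log_likelihood_objective(*args) + self.log_prior_density())


class ToyInternalDataTrainingLossMixin:
    def training_loss(self):
        # Data lives on the model, so the loss takes no arguments.
        return self._training_loss()


class ToyExternalDataTrainingLossMixin:
    def training_loss(self, data):
        # Data is supplied by the caller on every evaluation.
        return self._training_loss(data)


def gaussian_log_lik(data, mean):
    # Unit-variance Gaussian log-likelihood of all points.
    return sum(-0.5 * math.log(2 * math.pi) - 0.5 * (x - mean) ** 2 for x in data)


class ToyInternalModel(ToyInternalDataTrainingLossMixin, ToyModel):
    def __init__(self, data, mean=0.0):
        self.data, self.mean = list(data), mean

    def maximum_log_likelihood_objective(self):
        return gaussian_log_lik(self.data, self.mean)


class ToyExternalModel(ToyExternalDataTrainingLossMixin, ToyModel):
    def __init__(self, mean=0.0):
        self.mean = mean

    def maximum_log_likelihood_objective(self, data):
        return gaussian_log_lik(data, self.mean)


data = [1.0, 2.0, 3.0]
internal = ToyInternalModel(data, mean=2.0)
external = ToyExternalModel(mean=2.0)
print(internal.training_loss(), external.training_loss(data))  # identical values
```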

For models that hold their data internally, `training_loss` can be passed directly as a closure to an optimizer's `minimize`, for example:
```python
model = gpflow.models.GPR(data, ...)
gpflow.optimizers.Scipy().minimize(model.training_loss, model.trainable_variables)
```

If the model objective requires data to be passed in, a closure can be constructed on the fly using `model.training_loss_closure(data)`, which returns a no-argument closure:
```python
model = gpflow.models.SVGP(...)
gpflow.optimizers.Scipy().minimize(
    model.training_loss_closure(data), model.trainable_variables, ...
)
```
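
The point of `training_loss_closure(data)` is that the optimizer only ever sees a zero-argument callable. The sketch below shows that shape in plain Python under stated assumptions: `ToyVariable`, `ToyExternalModel` and `toy_minimize` are hypothetical, fixed-step gradient descent with a central-difference gradient stands in for `gpflow.optimizers.Scipy`, and a standard normal prior on the mean is added so the minimizer is the MAP estimate, `sum(data) / (len(data) + 1)` for unit noise.

```python
import math


class ToyVariable:
    """Hypothetical mutable scalar, standing in for a trainable variable."""

    def __init__(self, value):
        self.value = value


class ToyExternalModel:
    """Hypothetical model: data ~ N(mean, 1), prior mean ~ N(0, 1); data passed in."""

    def __init__(self):
        self.mean = ToyVariable(0.0)
        self.trainable_variables = [self.mean]

    def maximum_log_likelihood_objective(self, data):
        m = self.mean.value
        return sum(-0.5 * math.log(2 * math.pi) - 0.5 * (x - m) ** 2 for x in data)

    def log_prior_density(self):
        m = self.mean.value
        return -0.5 * math.log(2 * math.pi) - 0.5 * m ** 2

    def training_loss(self, data):
        return -(self.maximum_log_likelihood_objective(data) + self.log_prior_density())

    def training_loss_closure(self, data):
        # Bind the data now; the returned callable takes no arguments.
        def closure():
            return self.training_loss(data)
        return closure


def toy_minimize(closure, variables, step=0.1, iters=200, h=1e-5):
    """Fixed-step gradient descent; it only ever calls closure() with no arguments."""
    for _ in range(iters):
        grads = []
        for v in variables:
            old = v.value
            v.value = old + h
            up = closure()
            v.value = old - h
            down = closure()
            v.value = old
            grads.append((up - down) / (2 * h))
        for v, g in zip(variables, grads):
            v.value -= step * g
    return closure()


data = [1.0, 2.0, 3.0]
model = ToyExternalModel()
toy_minimize(model.training_loss_closure(data), model.trainable_variables)
print(model.mean.value)  # close to sum(data) / (len(data) + 1) = 1.5
```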

The `training_loss_closure()` method provided by both `InternalDataTrainingLossMixin` and `ExternalDataTrainingLossMixin` takes a boolean `compile` argument (default: `True`) that wraps the returned closure in `tf.function()`. If the `minimize()` step is run several times, cache the returned closure in a variable and reuse it; otherwise every call creates a new closure that has to be compiled again.
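
Why caching matters can be sketched without TensorFlow. In the hypothetical code below, `toy_compile` stands in for `tf.function`: each wrapper it returns does a simulated "trace" on its first call and is cheap afterwards, which mimics (but does not reproduce) how a fresh `tf.function` object traces. Building the closure inside the loop therefore traces once per step, whereas building it once traces once in total.

```python
TRACES = 0  # counts simulated compilations


def toy_compile(fn):
    """Hypothetical stand-in for tf.function: traces on the first call of each wrapper."""
    traced = False

    def wrapper():
        nonlocal traced
        global TRACES
        if not traced:
            TRACES += 1
            traced = True
        return fn()

    return wrapper


class ToyModel:
    """Hypothetical model with an external-data loss."""

    def training_loss(self, data):
        return sum(x * x for x in data)

    def training_loss_closure(self, data, compile=True):
        closure = lambda: self.training_loss(data)
        return toy_compile(closure) if compile else closure


model, data, steps = ToyModel(), [1.0, 2.0], 5

# Anti-pattern: a new compiled closure on every step.
for _ in range(steps):
    model.training_loss_closure(data)()
print(TRACES)  # 5

TRACES = 0
# Preferred: create the closure once and reuse it.
loss_fn = model.training_loss_closure(data)
for _ in range(steps):
    loss_fn()
print(TRACES)  # 1
```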

MCMC: The `log_posterior_density` method can be passed directly to `SamplingHelper`. By default, `log_posterior_density` is implemented as `maximum_log_likelihood_objective + log_prior_density`; models can override this if needed. Example:
```python
import gpflow
import tensorflow_probability as tfp

model = gpflow.models.GPMC(...)
hmc_helper = gpflow.optimizers.SamplingHelper(
    model.log_posterior_density, model.trainable_parameters
)
hmc = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=hmc_helper.target_log_prob_fn, ...
)
```
In this case, wrap the function that runs the MCMC chain in `tf.function()` (see the MCMC notebook).
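
The default above can be sketched in plain Python (hypothetical `Toy…` classes, not GPflow's implementation). With no override, `log_posterior_density` is exactly the negated `_training_loss`, so MAP optimization and MCMC target the same unnormalized density. The override shown, dropping additive constants, is only an illustrative assumption; it is harmless for MCMC because samplers depend on differences in log density, not its absolute level.

```python
import math


class ToyModel:
    """Hypothetical model: data ~ N(mean, 1), prior mean ~ N(0, 1)."""

    def __init__(self, data, mean=0.0):
        self.data, self.mean = list(data), mean

    def maximum_log_likelihood_objective(self):
        return sum(-0.5 * math.log(2 * math.pi) - 0.5 * (x - self.mean) ** 2 for x in self.data)

    def log_prior_density(self):
        return -0.5 * math.log(2 * math.pi) - 0.5 * self.mean ** 2

    def log_posterior_density(self):
        # Default described above: objective + log prior density.
        return self.maximum_log_likelihood_objective() + self.log_prior_density()

    def _training_loss(self):
        return -(self.maximum_log_likelihood_objective() + self.log_prior_density())


class ToyModelWithoutConstants(ToyModel):
    """Hypothetical override: drop additive constants, which MCMC does not need."""

    def log_posterior_density(self):
        return -0.5 * sum((x - self.mean) ** 2 for x in self.data) - 0.5 * self.mean ** 2


data = [1.0, 2.0, 3.0]
a, b = ToyModel(data, mean=1.0), ToyModel(data, mean=3.0)
c, d = ToyModelWithoutConstants(data, mean=1.0), ToyModelWithoutConstants(data, mean=3.0)
print(a.log_posterior_density() == -a._training_loss())  # True
# The override shifts the density by a constant, so differences agree (up to rounding).
print(a.log_posterior_density() - b.log_posterior_density())
print(c.log_posterior_density() - d.log_posterior_density())  # 4.0
```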
1 parent e61ee69
refactor training objective methods (#1276)
.gitignore
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
#  Usually these files are written by a Python script from a template
#  before PyInstaller builds the exe, so as to inject date/other info into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover

# Translations
*.mo
*.pot

# Django stuff:
*.log

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Emacs backups
*~

# Pycharm IDE directory
.idea

# IPython Notebooks
.ipynb_checkpoints

# VSCode
.vscode

# OSX
.DS_Store