# Copyright 2020 The GPflow Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" MonitorTask base classes """

from abc import ABC, abstractmethod
from typing import Callable, List, Sequence, Union

import tensorflow as tf

__all__ = ["MonitorTask", "ExecuteCallback", "MonitorTaskGroup", "Monitor"]


class MonitorTask(ABC):
    """
    A base class for a monitoring task.

    All monitoring tasks are callable objects.
    A descendant class must implement the `run` method, which is the body of the monitoring task.
    """

    def __call__(self, step: int, **kwargs):
        """
        Set the current step and then call the `run` method.

        :param step: current step in the optimisation. It is cast to a `tf.int64`
            and stored in `self.current_step` before `run` is called.
        :param kwargs: additional keyword arguments that are forwarded to the `run`
            method of the task. This is particularly handy for passing keyword
            arguments to the callback of `ScalarToTensorBoard`.
        """
        self.current_step = tf.cast(step, tf.int64)
        self.run(**kwargs)

    @abstractmethod
    def run(self, **kwargs):
        """
        Implements the task to be executed on __call__.
        The current step is available through `self.current_step`.

        :param kwargs: keyword arguments available to the run method.
        """
        raise NotImplementedError


class ExecuteCallback(MonitorTask):
    """ Executes a callback as task """

    def __init__(self, callback: Callable[..., None]):
        """
        :param callback: callable to be executed during the task.
            Arguments can be passed using keyword arguments.
        """
        super().__init__()
        self.callback = callback

    def run(self, **kwargs):
        # Forward every keyword argument received by `__call__` to the callback.
        self.callback(**kwargs)


class MonitorTaskGroup:
    """
    Class for grouping `MonitorTask` instances. A group defines
    all the tasks that are run at the same frequency, given by `period`.

    A `MonitorTaskGroup` can consist of a single instance or a list of
    `MonitorTask` instances.
    """

    def __init__(
        self, task_or_tasks: Union[Sequence[MonitorTask], MonitorTask], period: int = 1
    ):
        """
        :param task_or_tasks: a single instance, or a list (or tuple) of
            `MonitorTask` instances. Each `MonitorTask` in the list will be run
            with the given `period`.
        :param period: defines how often to run the tasks; they will execute every
            `period`th step (i.e. whenever `step % period == 0`). For large values
            of `period` the tasks will be run less frequently. Defaults to running
            at every step (`period = 1`). Must be non-zero: a zero period raises
            `ZeroDivisionError` when the group is called.
        """
        self.tasks = task_or_tasks
        self._period = period

    @property
    def tasks(self) -> List[MonitorTask]:
        return self._tasks

    @tasks.setter
    def tasks(self, task_or_tasks: Union[Sequence[MonitorTask], MonitorTask]) -> None:
        """
        Ensures the tasks are stored as a list, even if there is only a single task.

        A list is stored as given (not copied), a tuple is converted to a list,
        and anything else is treated as a single task and wrapped in a list.
        """
        if isinstance(task_or_tasks, list):
            self._tasks = task_or_tasks
        elif isinstance(task_or_tasks, tuple):
            # Previously a tuple was wrapped as one "task" and failed when called.
            self._tasks = list(task_or_tasks)
        else:
            self._tasks = [task_or_tasks]

    def __call__(self, step, **kwargs):
        """Call each task in the group if `step` is a multiple of the period."""
        if step % self._period == 0:
            for task in self.tasks:
                task(step, **kwargs)


class Monitor:
    r"""
    Accepts any number of `MonitorTaskGroup` instances, and runs them
    according to their specified periodicity.

    Example use-case:
        ```
        # Create some monitor tasks
        log_dir = "logs"
        model_task = ModelToTensorBoard(log_dir, model)
        image_task = ImageToTensorBoard(log_dir, plot_prediction, "image_samples")
        lml_task = ScalarToTensorBoard(log_dir, lambda: model.log_marginal_likelihood(), "lml")

        # Plotting tasks can be quite slow, so we want to run them less frequently.
        # We group them in a `MonitorTaskGroup` and set the period to 5.
        slow_tasks = MonitorTaskGroup(image_task, period=5)

        # The other tasks are fast. We run them at each iteration of the optimisation.
        fast_tasks = MonitorTaskGroup([model_task, lml_task], period=1)

        # We pass both groups to the `Monitor`
        monitor = Monitor(fast_tasks, slow_tasks)
        ```
    """

    def __init__(self, *task_groups: MonitorTaskGroup):
        """
        :param task_groups: any number of `MonitorTaskGroup` instances, passed as
            positional arguments. They are called in the order given.
        """
        self.task_groups = task_groups

    def __call__(self, step, **kwargs):
        """Call every task group with the current step and keyword arguments."""
        for group in self.task_groups:
            group(step, **kwargs)