https://github.com/GPflow/GPflow
Raw File
Tip revision: 3b6200307f2133f3c99930850b7e5106a686d656 authored by uri-granta on 07 February 2024, 13:16:15 UTC
Bump version to 2.9.1 in preparation for release (#2103)
Tip revision: 3b62003
test_set_trainable.py
import tensorflow as tf

from gpflow import Parameter, set_trainable


def _module() -> tf.Module:
    """Build a fresh module holding one plain ``tf.Variable`` and one GPflow ``Parameter``.

    The checks at the end confirm that a freshly built module has two variables,
    and that both are trainable. Every test therefore starts from this known state.

    Returns:
        A new ``tf.Module`` with exactly two variables, both trainable.
    """

    class _Mod(tf.Module):
        def __init__(self) -> None:
            super().__init__()
            # One raw TensorFlow variable and one GPflow Parameter, so that
            # set_trainable is exercised on both kinds of variable holder.
            self.var = tf.Variable(0.0)
            self.param = Parameter(0.0)

    instance = _Mod()

    expected_count = 2
    assert len(instance.trainable_variables) == expected_count
    assert len(instance.variables) == expected_count

    return instance


def test_can_set_not_trainable() -> None:
    """Freezing a single module leaves it with no trainable variables."""
    frozen = _module()
    set_trainable(frozen, False)
    # trainable_variables is a tuple, so an empty one is falsy.
    assert not frozen.trainable_variables


def test_can_set_not_trainable_then_trainable_again() -> None:
    """Freezing and then unfreezing a module makes every one of its variables trainable again."""
    module = _module()
    set_trainable(module, False)
    set_trainable(module, True)
    # Pin the absolute variable count first. The equality below on its own
    # would pass vacuously if set_trainable had dropped the module's variables
    # (0 == 0).
    assert len(module.variables) == 2
    assert len(module.trainable_variables) == len(module.variables)


def test_can_set_not_trainable_iterable() -> None:
    """``set_trainable`` also accepts a list of modules and freezes each of them."""
    modules = [_module() for _ in range(3)]
    set_trainable(modules, False)
    for frozen in modules:
        assert len(frozen.trainable_variables) == 0


def test_can_set_not_trainable_then_trainable_iterable() -> None:
    """Freezing then unfreezing a list of modules makes each module fully trainable again."""
    modules = [_module(), _module(), _module()]
    set_trainable(modules, False)
    set_trainable(modules, True)
    for m in modules:
        # Pin the variable count so the equality check cannot pass vacuously
        # when both collections are empty.
        assert len(m.variables) == 2
        assert len(m.trainable_variables) == len(m.variables)
back to top