Revision aeb0ab472296c0298c2b007c30af2705a75a89f8 authored by ST John on 18 June 2019, 09:46:26 UTC, committed by ST John on 18 June 2019, 09:48:10 UTC
1 parent 4ad6260
Raw File
paramlist.py
# Copyright 2016 James Hensman, Mark van der Wilk,
#                Valentine Svensson, alexggmatthews,
#                PabloLeon, fujiisoup
# Copyright 2017 Artem Artemev @awav
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from ..core.tensor_converter import TensorConverter
from ..core.errors import GPflowError
from ..core.compilable import Build

from .parameter import Parameter
from .parameterized import Parameterized


class ParamList(Parameterized):
    """
    ParamList is a special case of a parameterized object. It implements a
    different access pattern for its children: instead of saving node-like
    objects as attributes, it keeps them in a list, providing indexed access
    and acting as a simple list.

    :param list_of_params: list of node-like objects. Items that are not
        parameter-like are wrapped into a `Parameter`.
    :param trainable: Boolean flag. It is passed to `Parameter` for every
        item of `list_of_params` that has to be wrapped.
    :param name: ParamList name. It is forwarded to the `Parameterized`
        constructor.

    :raises ValueError: if `list_of_params` is not a `list`, or if an item
        cannot be converted to a `Parameter`.
    """

    def __init__(self, list_of_params, trainable=True, name=None):
        # Forward the caller-supplied name to the base class. Passing a
        # hard-coded `None` here would silently discard the documented
        # `name` argument.
        super(ParamList, self).__init__(name=name)
        if not isinstance(list_of_params, list):
            raise ValueError('Not acceptable argument type at list_of_params.')
        self._list = []
        for value in list_of_params:
            # Wrap raw values with the constructor's `trainable` flag first;
            # `append` validates again, but by then the item is already
            # parameter-like and is passed through unchanged.
            item = self._valid_list_input(value, trainable)
            self.append(item)

    @property
    def _children(self):
        """
        Mapping from the stringified list index to the item stored at
        that index.
        """
        return {str(i): v for i, v in enumerate(self._list)}

    def _store_child(self, index, child):
        """
        Store `child` at position `index` of the internal list.

        An index equal to the current length appends the child, any other
        index overwrites the existing entry (and raises `IndexError` from
        the underlying list when it is out of range).
        """
        if index == len(self):
            self._list.append(child)
        else:
            self._list[index] = child

    @property
    def params(self):
        """
        Generator over the stored items in list order.
        """
        for item in self._list:
            yield item

    def append(self, item):
        """
        Append an item to the end of the list.

        :param item: parameter-like object, or a value that `Parameter`
            accepts. Such a value is wrapped using this object's
            `trainable` attribute.
        :raises ValueError: if the item cannot be converted to a `Parameter`.
        """
        value = self._valid_list_input(item, self.trainable)
        # The new node goes to the position right after the last element.
        # NOTE(review): `_set_node` is defined outside this class; it is
        # presumably what ends up calling `_store_child` - confirm.
        index = len(self)
        self._set_node(index, value)

    def _get_node(self, name):
        """
        Return the item stored under `name`, which is used directly as an
        index into the internal list.
        """
        return self._list[name]

    def _replace_node(self, index, old, new):
        """
        Replace the `old` node at position `index` with `new`.
        """
        # `_set_parent` is called without arguments on the old node;
        # presumably this detaches it from this list - confirm in the
        # base class.
        old._set_parent()
        self._set_child(index, new)

    def _valid_list_input(self, value, trainable):
        """
        Return `value` in a form that can be stored in the list.

        Parameter-like values (as decided by `Parameterized._is_param_like`)
        are returned unchanged; anything else is wrapped into a `Parameter`
        with the given `trainable` flag.

        :raises ValueError: if `Parameter` rejects the value with a
            `ValueError`.
        """
        if not Parameterized._is_param_like(value):
            try:
                return Parameter(value, trainable=trainable)
            except ValueError:
                raise ValueError(
                    'A list item must be either parameter, '
                    'tensorflow variable, an array or a scalar.')
        return value

    def __len__(self):
        """
        Number of items stored in the list.
        """
        return len(self._list)

    def __getitem__(self, index):
        """
        Return the item at `index`.

        When `TensorConverter.tensor_mode` is truthy for this object and the
        item is a `Parameter`, the result of
        `Parameterized._tensor_mode_parameter` is returned instead of the
        item itself.
        """
        e = self._get_node(index)
        if TensorConverter.tensor_mode(self) and isinstance(e, Parameter):
            return Parameterized._tensor_mode_parameter(e)
        return e

    def __setitem__(self, index, value):
        """
        Assign the parameter `value` to position `index`.

        :raises ValueError: if `value` is not a `Parameter`.
        :raises GPflowError: if the list is not empty and
            `is_built_coherence(value.graph)` reports `Build.YES`.
        """
        if not isinstance(value, Parameter):
            raise ValueError(
                'Non parameter type cannot be assigned to the list.')
        if not self.empty and self.is_built_coherence(value.graph) is Build.YES:
            raise GPflowError(
                'ParamList is compiled and items are not modifiable.')
        self._update_node(index, value)
back to top