https://github.com/facebookresearch/pythia
Tip revision: ff5c63f8fafa0320d6646340a434497bf7e22718 authored by Vedanuj Goswami on 17 December 2020, 16:21:30 UTC
[docs] Docs for various MMF Transformer configurations
Tip revision: ff5c63f
concat_dataset.py
# Copyright (c) Facebook, Inc. and its affiliates.
import functools
import types
from torch.utils.data import ConcatDataset
class MMFConcatDataset(ConcatDataset):
    """``ConcatDataset`` that forwards unknown attributes to its child datasets.

    Attribute lookups that fail on the concatenated dataset itself are
    resolved against the first child dataset. If the resolved attribute is a
    bound method, a wrapper is returned instead that calls the method of the
    same name on every child dataset in order (see
    ``_call_all_datasets_func``). Any other attribute value is returned as-is
    from the first child.
    """

    # These functions should only be called once even if they return nothing
    _SINGLE_CALL_FUNCS = []

    def __init__(self, datasets):
        super().__init__(datasets)
        # Snapshot of the names visible on this object right after
        # construction. ``__getattr__`` uses it to tell "one of our own
        # attributes whose lookup failed" apart from "a name to forward to
        # the child datasets".
        self._dir_representation = dir(self)

    def __getattr__(self, name):
        """Resolve ``name`` after normal attribute lookup has failed.

        Args:
            name: The attribute name being looked up.

        Returns:
            The attribute of the first child dataset. If that attribute is a
            bound method, a ``functools.partial`` that dispatches the call to
            every child dataset is returned instead.

        Raises:
            AttributeError: If ``name`` is one of this object's own attributes
                but its lookup fails (e.g. a property raising
                ``AttributeError``), or if no child dataset is available or
                the first child has no attribute ``name``.
        """
        # Read instance state through ``__dict__`` only: these attributes may
        # not exist yet (e.g. while ``super().__init__`` is still running, before
        # ``_dir_representation`` is assigned), and ``self.<attr>`` would
        # re-enter ``__getattr__``.
        instance_dict = self.__dict__
        if (
            "_dir_representation" in instance_dict
            and name in instance_dict["_dir_representation"]
        ):
            # The name belongs to this object, yet normal lookup failed (for
            # example a property raised AttributeError). Re-run the default
            # lookup so the original AttributeError propagates; calling
            # ``getattr(self, name)`` here would re-enter ``__getattr__`` and
            # recurse until RecursionError.
            return super().__getattribute__(name)

        datasets = instance_dict.get("datasets")
        if datasets:
            missing = object()
            # Single lookup (instead of hasattr + getattr) so properties on
            # the child are evaluated only once.
            attr = getattr(datasets[0], name, missing)
            if attr is not missing:
                # Bound methods are wrapped so the call is dispatched to each
                # of the child datasets, not only the first one.
                if isinstance(attr, types.MethodType):
                    attr = functools.partial(self._call_all_datasets_func, name)
                return attr

        raise AttributeError(name)

    def _get_single_call_funcs(self):
        """Return the class-level list of single-call function names."""
        return MMFConcatDataset._SINGLE_CALL_FUNCS

    def _call_all_datasets_func(self, name, *args, **kwargs):
        """Call method ``name`` with the given arguments on each child dataset.

        Children are called in order. Iteration stops early when a child
        returns a non-``None`` value (that value is returned) or when the
        child defines ``get_single_call_funcs()`` and ``name`` is in its
        result. Otherwise every child is called and ``None`` is returned.
        """
        for dataset in self.datasets:
            value = getattr(dataset, name)(*args, **kwargs)
            if value is not None:
                # TODO: Log a warning here -- returning the first non-None
                # value means the remaining datasets are not called.
                return value
            if (
                hasattr(dataset, "get_single_call_funcs")
                and name in dataset.get_single_call_funcs()
            ):
                return None
        return None