iteration_strategies.py
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
import warnings
from dataclasses import dataclass
from typing import Dict
import numpy as np
from mmf.common.registry import registry
from mmf.utils.configuration import get_global_config
from mmf.utils.dataset import dataset_list_from_config
from omegaconf import MISSING, OmegaConf
from torch.utils.data import DataLoader
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
class IterationStrategy:
    """
    Base class for defining iteration strategies that will be used
    for iterating over multiple datasets during multitasking.

    An IterationStrategy implementation should define a `__call__` method
    which returns the index of the dataset from which the next batch must
    be pulled.

    A class can also override the `should_exhaust_all_iterators` property,
    which defines whether all iterators should be exhausted before
    restarting the next round of iterators. For example, in the size
    proportional iteration strategy, all iterators must be finished
    before starting a new round so that all of them get an equal
    opportunity to present themselves according to their size.

    Args:
        config (Config): Object of type Config which should be defined
            for each iteration strategy for configurable parameters.
        dataloaders (Dict[str, DataLoader]): A dictionary containing
            mapping from dataset key to its dataloader.

    Usage::

        from dataclasses import dataclass
        from mmf.common.registry import registry
        from mmf.datasets.iterators import IterationStrategy


        @registry.register_iteration_strategy("my_iteration_strategy")
        class MyStrategy(IterationStrategy):
            @dataclass
            class Config(IterationStrategy.Config):
                name: str = "my_iteration_strategy"

            def __init__(self, config, dataloaders, *args, **kwargs):
                super().__init__(config, dataloaders, *args, **kwargs)
                ...
    """

    @dataclass
    class Config:
        # Name of the strategy; concrete subclasses override this default.
        name: str = MISSING

    def __init__(
        self, config: Config, dataloaders: Dict[str, DataLoader], *args, **kwargs
    ):
        # Layer the caller's config over this class's structured defaults so
        # that fields absent from `config` take the values declared on Config.
        config = OmegaConf.merge(OmegaConf.structured(self.Config), config)
        self.config = config
        self.dataloaders = dataloaders

    @classmethod
    def from_params(cls, dataloaders: Dict[str, DataLoader], **kwargs):
        """
        Alternate constructor: build ``cls.Config`` from keyword arguments
        and instantiate the strategy with it and ``dataloaders``.
        """
        config = OmegaConf.structured(cls.Config(**kwargs))
        return cls(config, dataloaders)

    @property
    def should_exhaust_all_iterators(self) -> bool:
        """Whether every iterator must finish before a new round starts."""
        return False

    def _check_not_epoch_training(self):
        """
        Assert that training is not epoch based (``training.max_epochs`` unset).

        Having this as a method allows easy override of the check in
        non-MMF use cases.
        """
        training = get_global_config("training")
        # NOTE(review): get_global_config may return None when the global
        # config has no `training` section (assumption — verify); treat that
        # as "not epoch based" rather than failing with AttributeError.
        max_epochs = (
            training.get("max_epochs", None) if training is not None else None
        )
        assert (
            max_epochs is None
        ), f"{self.__class__.__name__} doesn't make sense with epoch based training"

    def __call__(self, *args, **kwargs):
        """Return the index of the dataset to pull the next batch from."""
        raise NotImplementedError("__call__ hasn't been implemented")
@registry.register_iteration_strategy("constant")
class ConstantIterationStrategy(IterationStrategy):
    """
    Strategy that hands back the same dataset index on every call.

    Handy for mimicking single task training inside a multitask setup,
    e.g. for verification or as a default. The index that is returned
    comes from the `idx` config parameter (0 unless overridden).
    """

    @dataclass
    class Config(IterationStrategy.Config):
        name: str = "constant"
        # Dataset index returned by every call.
        idx: int = 0

    def __init__(
        self, config: Config, dataloaders: Dict[str, DataLoader], *args, **kwargs
    ):
        super().__init__(config, dataloaders, *args, **kwargs)
        # Read the configured index once; __call__ simply returns it.
        fixed_index = self.config.idx
        self._idx = fixed_index

    @property
    def should_exhaust_all_iterators(self) -> bool:
        """Always True for this strategy."""
        return True

    def __call__(self, *args, **kwargs):
        """Return the configured constant index; arguments are ignored."""
        return self._idx
@registry.register_iteration_strategy("round_robin")
class RoundRobinIterationStrategy(IterationStrategy):
    """
    Samples datasets one by one in round robin fashion.

    The start index can be specified in config as `start_idx` (default 0).
    It is wrapped into ``[0, len(dataloaders))`` so that the first index
    returned is always a valid dataset index.

    Round robin doesn't make sense with validation and testing splits, as
    they need to finish one complete epoch; those splits are expected to
    use size proportional sampling instead (that fallback is not handled
    in this class). This strategy also refuses epoch based training.
    """

    @dataclass
    class Config(IterationStrategy.Config):
        name: str = "round_robin"
        start_idx: int = 0

    def __init__(
        self, config: Config, dataloaders: Dict[str, DataLoader], *args, **kwargs
    ):
        super().__init__(config, dataloaders, *args, **kwargs)
        self._check_not_epoch_training()
        # Always initialize the cursor; defaulting to 0 keeps `__call__` from
        # hitting an AttributeError if `start_idx` is absent from the config.
        start_idx = self.config.get("start_idx", 0)
        num_loaders = len(self.dataloaders)
        # Wrap out-of-range (or negative) start indices into range; skip the
        # wrap when there are no loaders to avoid a modulo by zero here.
        self._current_idx = start_idx % num_loaders if num_loaders else start_idx

    def __call__(self, *args, **kwargs):
        """Return the current index and advance the cursor cyclically."""
        nxt = self._current_idx
        self._current_idx = (self._current_idx + 1) % len(self.dataloaders)
        return nxt
@registry.register_iteration_strategy("random")
class RandomIterationStrategy(IterationStrategy):
    """
    Picks a dataset index uniformly at random on every call.

    Like RoundRobin, this strategy refuses epoch based training, and
    validation/testing splits follow the same handling as RoundRobin.
    """

    @dataclass
    class Config(IterationStrategy.Config):
        name: str = "random"

    def __init__(
        self, config: Config, dataloaders: Dict[str, DataLoader], *args, **kwargs
    ):
        super().__init__(config, dataloaders, *args, **kwargs)
        self._check_not_epoch_training()

    def __call__(self, *args, **kwargs):
        """Return a uniformly sampled index in ``[0, len(dataloaders))``."""
        num_loaders = len(self.dataloaders)
        # Draw a single sample and unwrap it from the returned array.
        sampled = np.random.choice(num_loaders, 1)
        return sampled[0]
@registry.register_iteration_strategy("size_proportional")
class SizeProportionalIterationStrategy(IterationStrategy):
    """
    Samples index based on size of each dataset. Bigger datasets
    are sampled more and this strategy requires completing
    all iterators before starting new ones. Default in MMF.

    The probability of picking dataset ``i`` is
    ``len(dataset_i) / sum of all dataset lengths``.
    """

    @dataclass
    class Config(IterationStrategy.Config):
        name: str = "size_proportional"

    def __init__(
        self, config: Config, dataloaders: Dict[str, DataLoader], *args, **kwargs
    ):
        super().__init__(config, dataloaders, *args, **kwargs)
        self._per_dataset_lengths = []
        self._total_length = 0
        for dataset_key, loader in self.dataloaders.items():
            # Some loaders might not have the dataset attribute set; in that
            # case we fail loudly, as lengths can't be calculated.
            assert hasattr(loader, "dataset"), (
                "loaders need dataset objects to work with "
                + "'size_proportional' sampling"
            )
            dataset_instance = loader.dataset
            assert hasattr(dataset_instance, "__len__"), (
                "all datasets should have __len__ defined "
                + "to work with proportional sampling iterator"
            )
            dataset_instance_length = len(dataset_instance)
            # The message is only built when the assertion fails. Fall back
            # to the loader key for datasets without `dataset_type` so that
            # building the message can't raise AttributeError and hide the
            # real "empty dataset" error.
            assert dataset_instance_length, (
                "dataset: "
                f"{getattr(dataset_instance, 'dataset_type', dataset_key)} is empty"
            )
            self._per_dataset_lengths.append(dataset_instance_length)
            self._total_length += dataset_instance_length

        # Every length is asserted non-zero above, so `_total_length` is
        # positive whenever there is at least one loader.
        self._dataset_probabilities = [
            length / self._total_length for length in self._per_dataset_lengths
        ]

    def __call__(self, *args, **kwargs):
        """Sample a dataset index with probability proportional to its size."""
        choice = np.random.choice(
            len(self.dataloaders), 1, p=self._dataset_probabilities
        )[0]
        return choice

    @property
    def should_exhaust_all_iterators(self) -> bool:
        """All iterators must finish before a new round starts."""
        return True
@registry.register_iteration_strategy("ratios")
class RatiosIterationStrategy(IterationStrategy):
    """
    Samples based on ratios specified as `sampling_ratios` parameter
    in the config. Follows the same validation/test handling as
    RoundRobin and refuses epoch based training.

    `sampling_ratios` defines a dictionary mapping each dataset key to
    a non-negative float ratio specifying how much the dataset should be
    sampled. The ratios are normalized to sum to one, so they need not
    sum to one themselves, but at least one must be positive.

    The datasets to sample are taken from the global config's `datasets`
    entry when it is present and non-empty; that list should be a subset
    of, or the same as, `sampling_ratios.keys()`. Otherwise the keys of
    `sampling_ratios` are used.
    """

    @dataclass
    class Config(IterationStrategy.Config):
        name: str = "ratios"
        sampling_ratios: Dict[str, float] = MISSING

    def __init__(
        self, config: Config, dataloaders: Dict[str, DataLoader], *args, **kwargs
    ):
        super().__init__(config, dataloaders, *args, **kwargs)
        self._check_not_epoch_training()
        given_datasets = self._get_given_datasets()
        sampling_ratios = self.config.get("sampling_ratios", {})
        probabilities = []
        for dataset in given_datasets:
            assert (
                dataset in sampling_ratios
            ), f"{dataset} must be specified in sampling_ratios param for multitasking"
            probabilities.append(sampling_ratios[dataset])

        # normalize the sampling ratios to sum up to 1
        prob_sum = sum(probabilities)
        assert all(prob >= 0 for prob in probabilities) and prob_sum > 0, (
            "sampling_ratios param for multitasking must be all non-negative "
            "and at least one of them needs to be positive."
        )

        self._dataset_probabilities = [prob / prob_sum for prob in probabilities]
        logger.info("Using per-dataset sampling probabilities:")
        for dataset, prob in zip(given_datasets, self._dataset_probabilities):
            logger.info(f"\t{dataset}: {prob}")

    def __call__(self, *args, **kwargs):
        """Sample a dataset index according to the normalized ratios."""
        # NOTE(review): probability i is paired with index i, which assumes
        # `given_datasets` is ordered like `self.dataloaders` — confirm.
        choice = np.random.choice(
            len(self.dataloaders), 1, p=self._dataset_probabilities
        )[0]
        return choice

    def _get_given_datasets(self):
        """
        Return the list of dataset keys to build probabilities for.

        Uses the global config's `datasets` entry when it is present and
        non-empty; otherwise warns and falls back to the keys of
        `sampling_ratios`.
        """
        config = registry.get("config")
        datasets = None
        # Only read the dataset list when the global config actually has it;
        # otherwise fall through to the `sampling_ratios` keys below.
        if config is not None and "datasets" in config:
            datasets = dataset_list_from_config(config)

        if datasets is None or len(datasets) == 0:
            warnings.warn(
                "Either 'datasets' key not in global config or is a empty list. "
                + "Moving forward with dataset list same as sampling ratios"
            )
            return list(self.config.get("sampling_ratios", {}).keys())
        else:
            return datasets