Raw File
multi_dataset_loader.py
# Copyright (c) Facebook, Inc. and its affiliates.
"""
MultiDatasetLoader class is used by DatasetLoader class to load multiple datasets
and more granular
"""
import logging
import warnings
from typing import Dict, Iterator

import numpy as np
import torch
from mmf.common.sample import SampleList, convert_batch_to_sample_list
from mmf.datasets import iteration_strategies
from mmf.utils.build import build_dataloader_and_sampler, build_dataset
from mmf.utils.dataset import dataset_list_from_config
from mmf.utils.distributed import (
    broadcast_scalar,
    get_world_size,
    is_dist_initialized,
    is_master,
    is_xla,
)
from mmf.utils.general import get_batch_size, get_current_device
from omegaconf import OmegaConf
from torch.utils.data.dataloader import DataLoader, Sampler


logger = logging.getLogger(__name__)


class MultiDataLoader:
    """Iterate over several dataloaders as if they were a single one.

    Before each batch, an ``IterationStrategy`` picks which of the wrapped
    loaders the batch comes from. The choice is made on the master process and
    broadcast to the other processes (see ``change_dataloader``).
    """

    def __init__(
        self,
        loaders: Dict[str, DataLoader],
        iteration_strategy: iteration_strategies.IterationStrategy = None,
    ):
        """
        Args:
            loaders: Mapping from dataset name to its dataloader. ``None`` or an
                empty mapping is accepted with a warning; ``None`` is treated as
                an empty mapping.
            iteration_strategy: Decides which loader the next batch is drawn
                from. Defaults to round robin over ``loaders``.
        """
        if loaders is None or len(loaders) == 0:
            warnings.warn(
                "Empty loaders passed into MultiDataLoader. This can have "
                "unintended consequences."
            )
        if loaders is None:
            # Normalise so the bookkeeping below (len, keys) does not crash.
            # An empty dict passed by the caller is kept as-is (not replaced)
            # so that later in-place mutation of it stays visible here.
            loaders = {}

        if iteration_strategy is None:
            iteration_strategy = iteration_strategies.RoundRobinIterationStrategy(
                OmegaConf.create(), loaders
            )

        self._iteration_strategy = iteration_strategy
        self._loaders = loaders
        self._is_master = is_master()
        self._num_datasets = len(self.loaders)
        self.dataset_list = list(loaders.keys())
        self._iterators = {}
        # Names of datasets whose iterator ran out during the current epoch.
        # Only populated for strategies that must exhaust all iterators.
        self._finished_iterators = {}

        self.current_index = 0
        self.set_lengths()
        self.set_samplers()

    def set_lengths(self):
        """Recompute ``_total_length`` as the summed length of all datasets.

        Loaders without a ``dataset`` attribute, and datasets without
        ``__len__``, are skipped.

        Raises:
            AssertionError: if a sized dataset has length zero.
        """
        self._total_length = 0
        for name, loader in self.loaders.items():
            # Some loaders might not have dataset attribute
            # set, in this case we won't consider them in
            # dataset lengths.
            if not hasattr(loader, "dataset"):
                continue

            dataset_instance = loader.dataset

            if hasattr(dataset_instance, "__len__"):
                dataset_instance_length = len(dataset_instance)
                # Identify the dataset by its loader key; this class has no
                # ``dataset_type`` attribute to report.
                assert dataset_instance_length, f"dataset: {name} is empty"
                self._total_length += dataset_instance_length

    def set_samplers(self):
        """Collect the ``sampler`` of every loader that has one, keyed by name."""
        self.samplers: Dict[str, Sampler] = {}
        for key, loader in self.loaders.items():
            if hasattr(loader, "sampler"):
                self.samplers[key] = loader.sampler

    def get_datasets(self):
        """Return the ``dataset`` of every loader, in ``loaders`` order."""
        return [loader.dataset for loader in self.loaders.values()]

    @property
    def loaders(self) -> Dict[str, DataLoader]:
        return self._loaders

    @property
    def samplers(self) -> Dict[str, Sampler]:
        return self._samplers

    @samplers.setter
    def samplers(self, samplers: Dict[str, Sampler]):
        self._samplers = samplers

    @property
    def num_datasets(self) -> int:
        return self._num_datasets

    @property
    def iterators(self) -> Dict[str, Iterator[SampleList]]:
        return self._iterators

    @iterators.setter
    def iterators(self, iterators: Dict[str, Iterator[SampleList]]):
        self._iterators = iterators

    @property
    def current_loader(self) -> DataLoader:
        """Loader of the currently selected dataset (derived from ``current_index``)."""
        return self.loaders[self.current_dataset_name]

    @property
    def iteration_strategy(self) -> iteration_strategies.IterationStrategy:
        return self._iteration_strategy

    @property
    def current_iterator(self) -> Iterator[SampleList]:
        """Live iterator over the currently selected loader (set up by ``__iter__``)."""
        return self.iterators[self.current_dataset_name]

    @property
    def current_dataset_name(self) -> str:
        return self.dataset_list[self.current_index]

    @property
    def current_dataset(self) -> torch.utils.data.Dataset:
        """Dataset of the current loader, or ``None`` if it exposes none."""
        if hasattr(self.current_loader, "dataset"):
            return self.current_loader.dataset
        else:
            return None

    @property
    def first_loader(self) -> DataLoader:
        return list(self.loaders.values())[0]

    def __len__(self) -> int:
        """Return the number of global batches in one pass over all datasets.

        ``get_batch_size`` returns the per-process batch size, so it is
        multiplied by the world size to obtain the global batch size.
        """
        batch_size = get_batch_size() * get_world_size()
        if is_xla():
            # On XLA the batch is split across cores, which must not receive
            # uneven shapes, so the trailing partial batch is dropped
            # (drop_last == True) and the length rounds down.
            logger.info(
                "drop_last is set to True to avoid uneven dimension shapes "
                "across cores."
            )
            return (self._total_length) // batch_size
        else:
            # This assumes drop_last=False for all loaders. See also
            # build_dataloader_and_sampler().
            return (self._total_length + batch_size - 1) // batch_size

    def __iter__(self):
        """Start a new epoch: recreate every iterator and pick a first dataset."""
        # Clear off old iterators
        self._finished_iterators = {}
        self.iterators = {}

        for key, loader in self.loaders.items():
            self.iterators[key] = iter(loader)

        self.change_dataloader()

        return self

    def __next__(self) -> SampleList:
        """Calculation of next batch is performed using following logic.

        The currently chosen iterator is set in the change_dataloader function
        based on the chosen iteration strategy, which is called every time
        prepare_batch is called.

        If we get the next batch from the iterator without a StopIteration
        exception, we return it as it is. Otherwise, we have two cases:

        1. In some iteration strategies (for example size proportional), each
        dataset needs to complete the same number of epochs at any given time,
        so we raise StopIteration only when all iterators are finished. In
        turn, the caller goes back to __iter__, which reignites all of the
        iterators. The code will not reach __iter__ until all iterators are
        exhausted. An iteration strategy specifies this behavior through its
        `should_exhaust_all_iterators` property. If the newly chosen dataset
        turns out to be exhausted as well, another one is chosen, until a
        batch is found or every dataset is finished.

        2. In other iteration strategies, epochs don't make sense. Think of
        random (equal) proportional sampling for datasets x and y where x is
        half the size of y. When x has completed 2 epochs, y will have
        completed only 1. **So please don't use max_epochs or epoch based
        training in this case as it won't be honored**. If an iterator is
        finished, we just reignite it and the finished iterators variable
        isn't used. This means that this case will never reach the __iter__
        function again.

        Returns:
            SampleList: sample list instance from currently selected dataset

        Raises:
            StopIteration: in case 1, once every dataset's iterator is
                exhausted; in case 2, only if a freshly restarted loader
                yields nothing.
        """
        while True:
            try:
                next_batch = next(self.current_iterator)
                break
            except StopIteration:
                if not self.iteration_strategy.should_exhaust_all_iterators:
                    # Case 2: restart the exhausted loader and keep going.
                    iterator = iter(self.current_loader)
                    self.iterators[self.current_dataset_name] = iterator
                    next_batch = next(self.current_iterator)
                    break

                # Case 1: remember this dataset is done for the epoch.
                self._finished_iterators[self.current_dataset_name] = 1
                if len(self._finished_iterators) == self.num_datasets:
                    raise
                # Choose another unfinished dataset and retry. Looping matters
                # when that dataset is also exhausted but not yet marked;
                # trying only once would let StopIteration end the epoch
                # while other datasets still have batches.
                self.change_dataloader()

        # Save dataset name and dataset type beforehand as
        # prepare_batch will change the current index.
        current_dataset_name = self.current_dataset_name
        # Loaders without a dataset are tolerated elsewhere in this class, so
        # don't dereference a missing dataset here.
        current_dataset_type = getattr(self.current_dataset, "dataset_type", None)

        next_batch = self.prepare_batch(next_batch)
        next_batch = convert_batch_to_sample_list(next_batch)

        next_batch.dataset_name = current_dataset_name
        next_batch.dataset_type = current_dataset_type
        return next_batch

    def change_dataloader(self):
        """Select the dataset the next batch will be drawn from.

        With at most one dataset the index is simply 0. Otherwise the master
        process queries the iteration strategy, skipping datasets already
        marked finished in this epoch. The choice is then shared through
        ``broadcast_scalar(choice, 0, ...)``, presumably from rank 0, so that
        all processes agree.
        """
        choice = 0

        if self.num_datasets <= 1:
            self.current_index = choice
            return

        if self._is_master:
            choice = self.iteration_strategy()

            # self._finished_iterators will always be empty in case of
            # non-proportional (equal) sampling
            while self.dataset_list[choice] in self._finished_iterators:
                choice = self.iteration_strategy()

        choice = broadcast_scalar(choice, 0, device=get_current_device())
        self.current_index = choice

    def prepare_batch(self, batch: SampleList) -> SampleList:
        """Apply the current dataset's ``prepare_batch``, if any.

        Then advance to the next dataset via ``change_dataloader``.
        """
        dataset = self.current_dataset
        # Compare against None explicitly: a dataset object's truthiness
        # (e.g. via __len__) says nothing about whether it exists.
        if dataset is not None and hasattr(dataset, "prepare_batch"):
            batch = dataset.prepare_batch(batch)

        self.change_dataloader()
        return batch

    def seed_sampler(self, epoch: int):
        """Call ``set_epoch(epoch)`` on every sampler that supports it.

        Only runs when distributed is initialized. For example,
        ``DistributedSampler`` uses the epoch to vary its shuffling.
        """
        if is_dist_initialized():
            for sampler in self.samplers.values():
                if sampler is not None and hasattr(sampler, "set_epoch"):
                    sampler.set_epoch(epoch)


# TODO: Deprecate in favor of MultiDataModule
class MultiDatasetLoader(MultiDataLoader):
    """
    MultiDatasetLoader class that is used for training on multiple datasets together.

    Datasets and their dataloaders are built from config by ``load`` (or by
    ``build_datasets`` followed by ``build_dataloaders``).
    """

    def __init__(self, dataset_type: str = "train"):
        """
        Args:
            dataset_type: Split name passed to ``build_dataset`` for every dataset.
        """
        self._dataset_type = dataset_type
        self._datasets = []
        super().__init__({})

    @property
    def dataset_type(self):
        return self._dataset_type

    @property
    def datasets(self):
        return self._datasets

    def load(self, config):
        """Build datasets and dataloaders from ``config`` and recompute lengths."""
        self.build_datasets(config)
        self.build_dataloaders()
        self.set_lengths()

    def build_datasets(self, config):
        """Instantiate every dataset listed in ``config``.

        Datasets for which ``build_dataset`` returns ``None`` are skipped. A
        dataset missing from ``config.dataset_config`` is built with an empty
        config and a warning.
        """
        self._datasets = []
        # Reset together with ``_datasets`` so repeated calls don't accumulate
        # duplicate names and indices stay aligned with ``datasets``.
        self.dataset_list = []
        self.config = config
        self._given_datasets = dataset_list_from_config(self.config)

        for dataset in self._given_datasets:
            if dataset in self.config.dataset_config:
                dataset_config = self.config.dataset_config[dataset]
            else:
                warnings.warn(
                    f"Dataset {dataset} is missing from dataset_config"
                    + " in config. Proceeding with empty config."
                )
                dataset_config = OmegaConf.create()

            dataset_instance = build_dataset(dataset, dataset_config, self.dataset_type)
            if dataset_instance is None:
                continue
            self.datasets.append(dataset_instance)
            self.dataset_list.append(dataset)

        self._num_datasets = len(self.datasets)
        self.current_index = 0

        self._infer_dataset_probabilities()

    def build_dataloaders(self):
        """Build a dataloader and sampler for every dataset.

        Both are keyed by ``dataset_instance.name``. This assumes the name
        matches the config key stored in ``dataset_list`` (TODO confirm),
        because ``current_loader`` looks loaders up by ``dataset_list`` entries.
        """
        assert len(self._datasets) > 0, "Call build_datasets first"

        for dataset_instance in self.datasets:
            loader_instance, _ = build_dataloader_and_sampler(
                dataset_instance, self.config.training
            )
            sampler_instance = loader_instance.sampler
            self.loaders[dataset_instance.name] = loader_instance
            self.samplers[dataset_instance.name] = sampler_instance

        # ``current_loader`` is a read-only property derived from
        # ``current_index``, so it is not assigned here.

        # build_datasets() created the iteration strategy before any loader
        # existed. Rebuild it now, since a strategy may inspect the loaders
        # when it is constructed (e.g. for dataset sizes).
        self._infer_dataset_probabilities()

    def verbose_dump(self, *args, **kwargs):
        """Forward to ``verbose_dump`` of the currently selected dataset."""
        self.current_dataset.verbose_dump(*args, **kwargs)

    # Kept for backwards compatibility for now
    # TODO: Remove in future.
    def _infer_dataset_probabilities(self):
        """Pick the iteration strategy from the global ``training``/``multitasking`` config.

        - ``multitasking.enabled`` with ``sampling_ratios``: RatiosIterationStrategy
        - ``training.dataset_size_proportional_sampling`` (default True):
          SizeProportionalIterationStrategy
        - otherwise: RandomIterationStrategy

        Raises:
            AssertionError: if ``max_epochs`` is set without size proportional
                sampling, or if multitasking is enabled together with size
                proportional sampling.
        """
        from mmf.utils.configuration import get_global_config

        training = get_global_config("training")
        proportional_sampling = training.get("dataset_size_proportional_sampling", True)

        multitasking = get_global_config("multitasking")
        multitasking_enabled = multitasking.get("enabled", False)

        assert (
            proportional_sampling is True or training.get("max_epochs", None) is None
        ), "Epoch based training can only be used with size proportional sampling"

        assert not (proportional_sampling and multitasking_enabled), (
            "Multitasking (manually-specified) per-dataset ratios cannot be used "
            "with size proportional sampling"
        )

        if multitasking_enabled and "sampling_ratios" in multitasking:
            self._iteration_strategy = iteration_strategies.RatiosIterationStrategy(
                OmegaConf.create(
                    {
                        "sampling_ratios": multitasking.sampling_ratios,
                        "datasets": self._given_datasets,
                    }
                ),
                self.loaders,
            )
        elif proportional_sampling is True:
            strategy = iteration_strategies.SizeProportionalIterationStrategy
            self._iteration_strategy = strategy(OmegaConf.create(), self.loaders)
        else:
            self._iteration_strategy = iteration_strategies.RandomIterationStrategy(
                OmegaConf.create(), self.loaders
            )
back to top