https://github.com/facebookresearch/pythia
Tip revision: dabf95f523cd07e93380c6931e5140ade0f50b2f authored by Sethu Sankaran on 26 October 2021, 19:18:43 UTC
Revert D30704069: [feat] Add a refiner head that can be used with MMFT
Revert D30704069: [feat] Add a refiner head that can be used with MMFT
Tip revision: dabf95f
build.py
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
import os
import warnings
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union
import mmf
import pytorch_lightning as pl
import torch
from mmf.common.meter import Meter
from mmf.common.registry import registry
from mmf.datasets.iteration_strategies import (
ConstantIterationStrategy,
IterationStrategy,
SizeProportionalIterationStrategy,
)
from mmf.datasets.processors.processors import Processor
from mmf.utils.configuration import Configuration, get_global_config
from mmf.utils.distributed import is_dist_initialized, is_main, is_xla, synchronize
from mmf.utils.general import get_optimizer_parameters
from omegaconf import DictConfig, OmegaConf
from packaging import version
# torch_xla is an optional dependency that is only needed when running on XLA
# devices. Bind BOTH aliases in the fallback so that every name imported here
# always exists at module level (previously ``xla_pl`` was left undefined,
# which would surface as a NameError instead of an obvious ``None``).
try:
    import torch_xla.core.xla_model as xm  # noqa
    import torch_xla.distributed.parallel_loader as xla_pl  # noqa
except ImportError:
    xm = None
    xla_pl = None
# Type alias: mapping from a processor key to a Processor instance
# (used as the return annotation of build_processors).
ProcessorDict = Dict[str, Processor]
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def build_config(configuration: Configuration, *args, **kwargs) -> DictConfig:
    """Freeze a configuration and publish it through the registry.

    The configuration object is frozen first, then both the resulting config
    (under ``"config"``) and the configuration object itself (under
    ``"configuration"``) are registered.

    Args:
        configuration (Configuration): Configuration object from which the
            config is obtained.

    Returns:
        (DictConfig): The config returned by ``configuration.get_config()``.
    """
    configuration.freeze()
    frozen_config = configuration.get_config()
    registrations = (("config", frozen_config), ("configuration", configuration))
    for registry_key, value in registrations:
        registry.register(registry_key, value)
    return frozen_config
def build_trainer(config: DictConfig) -> Any:
    """Instantiate the trainer selected by ``config.training.trainer``.

    The trainer class is looked up in the registry by that name and is
    constructed with the full config.

    Args:
        config (DictConfig): Configuration used both to pick the trainer
            class and to construct it.

    Returns:
        (BaseTrainer): A trainer instance
    """
    trainer_cls = registry.get_trainer_class(config.training.trainer)
    return trainer_cls(config)
def build_lightning_model(
    config: Union[DictConfig, "mmf.models.base_model.BaseModel.Config"],
    checkpoint_path: str = None,
) -> "mmf.models.base_model.BaseModel":
    """Build a model flagged for use with PyTorch Lightning.

    Without a checkpoint path the model is created through ``build_model``.
    With one, the registered model class restores itself via
    ``load_from_checkpoint`` (non-strict) and its losses are initialised.

    Args:
        config: Model configuration, either a DictConfig or a
            ``BaseModel.Config`` instance (converted with
            ``OmegaConf.structured``).
        checkpoint_path (str, optional): Checkpoint to load. Falsy values
            select the plain ``build_model`` path.

    Returns:
        The model, with ``is_pl_enabled`` set to ``True``.

    Raises:
        RuntimeError: If no model class is registered for ``config.model``.
    """
    from mmf.models.base_model import BaseModel

    if checkpoint_path:
        # Structured configs are converted to OmegaConf objects first.
        if isinstance(config, BaseModel.Config) and not isinstance(
            config, DictConfig
        ):
            config = OmegaConf.structured(config)

        model_name = config.model
        model_class = registry.get_model_class(model_name)
        if model_class is None:
            raise RuntimeError(f"No model registered for name: {model_name}")

        # model.build is called inside on_load_checkpoint as suggested here:
        # https://github.com/PyTorchLightning/pytorch-lightning/issues/5410
        #
        # The main process loads requirements and the checkpoint before the
        # sync point; all other processes wait at the sync point and load
        # afterwards.
        on_main = is_main()
        if on_main:
            model_class.load_requirements(model_class, config=config)
        else:
            synchronize()
        model = model_class.load_from_checkpoint(
            checkpoint_path, config=config, strict=False
        )
        if on_main:
            synchronize()
        model.init_losses()
    else:
        model = build_model(config)

    model.is_pl_enabled = True
    return model
def build_model(
    config: Union[DictConfig, "mmf.models.base_model.BaseModel.Config"],
) -> "mmf.models.base_model.BaseModel":
    """Instantiate (and, when supported, build) the model named in ``config``.

    Args:
        config: Model configuration, either a DictConfig or a
            ``BaseModel.Config`` instance (converted with
            ``OmegaConf.structured``).

    Returns:
        The constructed model instance.

    Raises:
        RuntimeError: If no model class is registered for ``config.model``.
    """
    from mmf.models.base_model import BaseModel

    # Structured configs are converted to OmegaConf objects first.
    if isinstance(config, BaseModel.Config) and not isinstance(config, DictConfig):
        config = OmegaConf.structured(config)

    model_name = config.model
    model_class = registry.get_model_class(model_name)
    if model_class is None:
        raise RuntimeError(f"No model registered for name: {model_name}")

    model = model_class(config)
    if not hasattr(model, "build"):
        return model

    # Model build involves checkpoint loading; if the checkpoint is not
    # available the underlying methods try to download it. The main process
    # therefore builds first (downloading whatever is needed) and only then
    # reaches the sync point, while every other rank waits at the sync point
    # and builds afterwards from the already downloaded checkpoint.
    on_main = is_main()
    if on_main:
        model_class.load_requirements(model_class, config=config)
    else:
        synchronize()
    model.build()
    if on_main:
        synchronize()
    model.init_losses()
    return model
def build_dataset(
    dataset_key: str, config=None, dataset_type="train"
) -> torch.utils.data.Dataset:
    """Build a single dataset through its registered datamodule/builder.

    When ``config`` is falsy, the builder's default YAML (from
    ``config_path()``) is loaded and the ``dataset_config.<dataset_key>``
    node is used; an empty config is used if there is no default path or the
    node is missing. When ``config`` is given and contains ``dataset_key``
    (i.e. it is a global dataset config), that sub-config is used.

    Args:
        dataset_key (str): Key of dataset to build.
        config (DictConfig, optional): Configuration used to create the
            dataset. Defaults to None, meaning the dataset's default config.
        dataset_type (str, optional): Type of the dataset to build,
            train|val|test. Defaults to "train".

    Returns:
        (torch.utils.data.Dataset): Whatever the builder's ``load_dataset``
        returns for the resolved config and ``dataset_type``.
    """
    # Kept from the original even though the name is unused in this function.
    from mmf.datasets.base_dataset_builder import BaseDatasetBuilder  # noqa: F401
    from mmf.utils.configuration import load_yaml_with_defaults

    datamodule_instance = build_datamodule(dataset_key)

    if config:
        # Handle global config: narrow down to this dataset's section.
        if dataset_key in config:
            config = config[dataset_key]
    else:
        # No config supplied, fall back to the builder's default one.
        default_path = datamodule_instance.config_path()
        if default_path is not None:
            defaults = load_yaml_with_defaults(default_path)
            config = OmegaConf.select(defaults, f"dataset_config.{dataset_key}")
            if config is None:
                config = OmegaConf.create()
            OmegaConf.set_struct(config, True)
        else:
            # If config path wasn't defined, continue with an empty config
            # rather than forcing the dataset to define one.
            warnings.warn(
                f"Config path not defined for {dataset_key}, "
                + "continuing with empty config"
            )
            config = OmegaConf.create()

    datamodule_instance.build_dataset(config)
    built_dataset = datamodule_instance.load_dataset(config, dataset_type)
    if hasattr(datamodule_instance, "update_registry_for_model"):
        datamodule_instance.update_registry_for_model(config)
    return built_dataset
# TODO: move dataset_type enum to typings
def build_datasets(
    dataset_list: List[str], dataset_config: DictConfig, dataset_type="train"
) -> List[torch.utils.data.Dataset]:
    """Build every dataset in ``dataset_list`` and return the non-None ones.

    Args:
        dataset_list (List[str]): Keys of the datasets to build.
        dataset_config (DictConfig): Config holding one section per dataset
            key. A missing section triggers a warning and an empty config.
        dataset_type (str, optional): Split passed through to
            ``build_dataset``. Defaults to "train".

    Returns:
        List[torch.utils.data.Dataset]: Built datasets in the order of
        ``dataset_list``; entries for which ``build_dataset`` returned None
        are skipped.
    """
    datasets = []
    for dataset in dataset_list:
        # Use a per-dataset variable: rebinding ``dataset_config`` here would
        # make every later dataset be looked up inside the previous dataset's
        # section instead of the full config.
        if dataset in dataset_config:
            current_config = dataset_config[dataset]
        else:
            warnings.warn(
                f"Dataset {dataset} is missing from dataset_config"
                + " in config. Proceeding with empty config."
            )
            current_config = OmegaConf.create()

        dataset_instance = build_dataset(dataset, current_config, dataset_type)
        if dataset_instance is None:
            continue
        datasets.append(dataset_instance)
    return datasets
def build_datamodule(dataset_key) -> pl.LightningDataModule:
    """Instantiate the builder class registered for ``dataset_key``.

    Args:
        dataset_key: Key used to look up the builder class in the registry.

    Returns:
        pl.LightningDataModule: A new instance of the registered builder,
        created with no arguments.

    Raises:
        AssertionError: If the registry lookup returns a falsy value.
    """
    builder_cls = registry.get_builder_class(dataset_key)
    assert (
        builder_cls
    ), f"Key {dataset_key} doesn't have a registered dataset builder"
    datamodule: pl.LightningDataModule = builder_cls()
    return datamodule
def build_multiple_datamodules(
    dataset_list: List[str], all_dataset_config: DictConfig
) -> Dict[str, pl.LightningDataModule]:
    """Create, prepare and set up one datamodule per dataset key.

    For each key the datamodule is built, ``prepare_data`` runs on the main
    process only, all processes synchronize, and then ``setup`` runs
    everywhere. If the datamodule defines ``update_registry_for_model`` it is
    called with the dataset's config.

    Args:
        dataset_list (List[str]): Dataset keys to build datamodules for.
        all_dataset_config (DictConfig): Config holding one section per
            dataset key. A missing section triggers a warning and an empty
            config.

    Returns:
        Dict[str, pl.LightningDataModule]: Mapping from dataset key to its
        datamodule instance.
    """
    built: Dict[str, pl.LightningDataModule] = {}
    for key in dataset_list:
        module = build_datamodule(key)

        if key not in all_dataset_config:
            warnings.warn(
                f"Dataset {key} is missing from dataset_config"
                + " in config. Proceeding with empty config."
            )
            module_config = OmegaConf.create()
        else:
            module_config = all_dataset_config[key]

        # Only the main process prepares data; everyone waits for it to finish.
        if is_main():
            module.prepare_data(module_config)
        synchronize()

        module.setup(config=module_config)
        if hasattr(module, "update_registry_for_model"):
            module.update_registry_for_model(module_config)
        built[key] = module
    return built
def build_dataloader_and_sampler(
    dataset_instance: torch.utils.data.Dataset, datamodule_config: DictConfig
) -> Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
    """Builds and returns a dataloader along with its sampler

    Dataloader parameters are read from ``datamodule_config`` first and fall
    back to the global ``training`` config (``num_workers``, ``pin_memory``,
    ``persistent_workers``); ``shuffle`` and ``batch_size`` come from
    ``datamodule_config`` only and are completed by
    ``_add_extra_args_for_dataloader`` for non-iterable datasets.

    Args:
        dataset_instance (torch.utils.data.Dataset): Instance of dataset for which
            dataloader has to be created
        datamodule_config (omegaconf.DictConfig): Datamodule configuration; required
            for inferring params for dataloader

    Returns:
        Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
            Tuple of Dataloader and Sampler instance (the sampler is None
            when none was added to the dataloader arguments)
    """
    from mmf.common.batch_collator import BatchCollator

    training_config = get_global_config("training")
    # Support params coming in from dataloader params
    other_args = {
        "num_workers": datamodule_config.get(
            "num_workers", training_config.get("num_workers", 4)
        ),
        "pin_memory": datamodule_config.get(
            "pin_memory", training_config.get("pin_memory", False)
        ),
        "shuffle": datamodule_config.get("shuffle", None),
        "batch_size": datamodule_config.get("batch_size", None),
    }
    if version.parse(torch.__version__) >= version.parse("1.8"):
        # only use persistent workers in PyTorch 1.8 or higher
        # (PyTorch 1.7 also has this option but doesn't support it correctly due to
        # https://github.com/pytorch/pytorch/issues/48370)
        #
        # This must be the plain configured value. Wrapping it in a 1-tuple
        # (as a stray trailing comma previously did) is always truthy, which
        # ignored an explicit ``persistent_workers: false``.
        other_args["persistent_workers"] = datamodule_config.get(
            "persistent_workers", training_config.get("persistent_workers", True)
        )
        if other_args["persistent_workers"] and other_args["num_workers"] == 0:
            logger.warning(
                "persistent_workers cannot be used together with num_workers == 0; "
                "setting persistent_workers to False"
            )
            other_args["persistent_workers"] = False

    # IterableDataset returns batches directly, so no need to add Sampler
    # or batch size as user is expected to control those. This is a fine
    # assumption for now to not support single item based IterableDataset
    # as it will add unnecessary complexity and config parameters
    # to the codebase
    if not isinstance(dataset_instance, torch.utils.data.IterableDataset):
        other_args = _add_extra_args_for_dataloader(dataset_instance, other_args)
    else:
        other_args.pop("shuffle")

    # Set drop_last=True when using XLA to have constant batch size.
    # In this case we also need to set drop_last=True in DistributedSampler.
    loader = torch.utils.data.DataLoader(
        dataset=dataset_instance,
        collate_fn=BatchCollator(
            dataset_instance.dataset_name, dataset_instance.dataset_type
        ),
        drop_last=is_xla(),  # see also MultiDatasetLoader.__len__
        **other_args,
    )

    if is_xla():
        device = xm.xla_device()
        loader = xla_pl.MpDeviceLoader(loader, device)

    # NOTE(review): this condition holds for every non-negative worker count,
    # so the environment variable is effectively always set; kept as-is to
    # preserve behavior. Confirm whether ``> 0`` was intended.
    if other_args["num_workers"] >= 0:
        # Suppress leaking semaphore warning
        os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

    loader.dataset_type = dataset_instance.dataset_type

    return loader, other_args.get("sampler", None)
def build_test_reporter(
    datamodules: List[pl.LightningDataModule],
    config: DictConfig = None,
    dataset_type: str = "train",
):
    """Instantiate the test reporter class selected by ``config``.

    Args:
        datamodules (List[pl.LightningDataModule]): Datamodules handed to the
            test reporter.
        config (DictConfig, optional): Test reporter config. Its ``type``
            (default ``"default"``) selects the class and its ``params`` are
            passed to the constructor. When falsy, the ``"default"`` reporter
            is used with an empty params config and a warning is emitted.
        dataset_type (str, optional): Passed through to the test reporter.
            Defaults to "train".

    Returns:
        An instance of the registered test reporter class.

    Raises:
        AssertionError: If no class is registered for the selected key.
    """
    test_reporter_key = "default"
    if config:
        test_reporter_key = config.get("type", "default")

    # NOTE(review): the accessor name is kept exactly as spelled at this call
    # site; the registry is defined elsewhere, so it is not renamed here.
    test_reporter_class = registry.get_test_rerporter_class(test_reporter_key)
    assert (
        test_reporter_class
    ), f"Key {test_reporter_key} doesn't have a registered test_reporter class"

    if not config:
        # The two message fragments need a separating space, otherwise the
        # warning reads "test_reportercontinuing".
        warnings.warn(
            f"Config not provided for {test_reporter_key}, test_reporter "
            + "continuing with empty config"
        )
        params_config = OmegaConf.create()
    else:
        params_config = config.params

    return test_reporter_class(datamodules, params_config, dataset_type)
def _add_extra_args_for_dataloader(
    dataset_instance: torch.utils.data.Dataset, other_args: Dict[str, Any] = None
) -> Dict[str, Any]:
    """Fill in shuffle, sampler and batch size for a map-style dataset.

    The passed dict is updated in place and returned.

    Args:
        dataset_instance (torch.utils.data.Dataset): Dataset the dataloader
            is being built for; its ``dataset_type`` decides the default
            shuffle value.
        other_args (Dict[str, Any], optional): Dataloader keyword arguments
            collected so far. ``None`` is treated as an empty dict.

    Returns:
        Dict[str, Any]: The dataloader keyword arguments. ``shuffle`` is
        removed whenever a ``sampler`` is added, since the two are mutually
        exclusive.
    """
    from mmf.utils.general import get_batch_size

    # The declared default is None; treat it as "no arguments yet" instead of
    # failing on the first subscript.
    if other_args is None:
        other_args = {}

    dataset_type = dataset_instance.dataset_type

    # When shuffle was not configured, shuffle everything except the test split.
    shuffle = other_args.get("shuffle")
    if shuffle is None:
        shuffle = dataset_type != "test"
    other_args["shuffle"] = shuffle

    # In distributed mode, we use DistributedSampler from PyTorch
    if is_dist_initialized():
        other_args["sampler"] = torch.utils.data.DistributedSampler(
            dataset_instance, shuffle=shuffle
        )
        # Shuffle is mutually exclusive with sampler, let DistributedSampler
        # take care of shuffle and pop from main args
        other_args.pop("shuffle", None)

    if is_xla():
        # ``shuffle`` is read from the local variable because the key may
        # already have been removed by the distributed branch above.
        other_args["sampler"] = torch.utils.data.DistributedSampler(
            dataset_instance,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=shuffle,
            drop_last=True,
        )
        other_args.pop("shuffle", None)

    if other_args.get("batch_size") is None:
        other_args["batch_size"] = get_batch_size()

    return other_args
def build_optimizer(model, config):
    """Create the optimizer described by ``config.optimizer`` for ``model``.

    The optimizer class is taken from ``torch.optim`` when an attribute with
    the configured type name exists there, otherwise from the registry.

    Args:
        model: Model whose parameters (as returned by
            ``get_optimizer_parameters``) will be optimized.
        config: Config with an ``optimizer`` section containing ``type`` and
            optionally ``params`` and ``enable_state_sharding``.

    Returns:
        The optimizer instance; wrapped in fairscale's ``OSS`` when
        ``enable_state_sharding`` is set.

    Raises:
        ValueError: If the optimizer config has no ``type`` key, or no
            optimizer class can be found for that type.
        ImportError: If state sharding is requested but fairscale is missing.
        AssertionError: If state sharding is requested outside distributed
            mode.
    """
    optimizer_config = config.optimizer
    if "type" not in optimizer_config:
        raise ValueError(
            "Optimizer attributes must have a 'type' key "
            "specifying the type of optimizer. "
            "(Custom or PyTorch, e.g. 'adam_w' or 'SGD')"
        )
    optimizer_type = optimizer_config.type

    if "params" not in optimizer_config:
        warnings.warn("optimizer attributes has no params defined, defaulting to {}.")

    params = optimizer_config.get("params", {})

    if hasattr(torch.optim, optimizer_type):
        optimizer_class = getattr(torch.optim, optimizer_type)
    else:
        optimizer_class = registry.get_optimizer_class(optimizer_type)
        if optimizer_class is None:
            # Interpolate the requested type; the message used to contain a
            # literal, never-formatted "{}" placeholder.
            raise ValueError(
                f"No optimizer class of type {optimizer_type} present in "
                "either torch or registered to registry"
            )

    parameters = get_optimizer_parameters(model, config)

    if optimizer_config.get("enable_state_sharding", False):
        # TODO(vedanuj): Remove once OSS is moved to PT upstream
        try:
            from fairscale.optim.oss import OSS
        except ImportError:
            print(
                "Optimizer state sharding requires fairscale. "
                + "Install using pip install fairscale."
            )
            raise

        assert (
            is_dist_initialized()
        ), "Optimizer state sharding can only be used in distributed mode."

        is_fp16 = config.get("training", {}).get("fp16", False)
        optimizer = OSS(
            params=parameters, optim=optimizer_class, broadcast_fp16=is_fp16, **params
        )
    else:
        optimizer = optimizer_class(parameters, **params)
    return optimizer
def build_lightning_optimizers(model, config):
    """Build the optimizer and, if enabled, a per-step LR scheduler.

    Args:
        model: Model to build the optimizer for.
        config: Config passed to ``build_optimizer`` / ``build_scheduler``;
            ``config.training.lr_scheduler`` toggles the scheduler.

    Returns:
        The bare optimizer when the scheduler is disabled, otherwise a dict
        with the optimizer and an ``lr_scheduler`` entry using the ``"step"``
        interval.
    """
    optimizer = build_optimizer(model, config)
    if not config.training.lr_scheduler:
        return optimizer

    scheduler = build_scheduler(optimizer, config)
    scheduler_spec = {"scheduler": scheduler, "interval": "step"}
    return {"optimizer": optimizer, "lr_scheduler": scheduler_spec}
def build_scheduler(optimizer, config):
    """Create the LR scheduler described by ``config.scheduler``.

    Args:
        optimizer: Optimizer the scheduler is attached to.
        config: Config whose optional ``scheduler`` section provides ``type``
            (default ``"pythia"``) and ``params`` (default ``{}``). A warning
            is emitted for each of the two keys that is missing.

    Returns:
        An instance of the scheduler class registered for the chosen type.
    """
    scheduler_config = config.get("scheduler", {})

    type_missing = "type" not in scheduler_config
    if type_missing:
        warnings.warn(
            "No type for scheduler specified even though lr_scheduler is True, "
            "setting default to 'Pythia'"
        )
    scheduler_type = scheduler_config.get("type", "pythia")

    params_missing = "params" not in scheduler_config
    if params_missing:
        warnings.warn("scheduler attributes has no params defined, defaulting to {}.")
    scheduler_params = scheduler_config.get("params", {})

    scheduler_cls = registry.get_scheduler_class(scheduler_type)
    return scheduler_cls(optimizer, **scheduler_params)
def build_classifier_layer(config, *args, **kwargs):
    """Create a ``ClassifierLayer`` and return its wrapped ``module``.

    Args:
        config: Config providing ``type`` (first constructor argument) and
            ``params`` (expanded as keyword arguments).
        *args: Extra positional arguments for ``ClassifierLayer``.
        **kwargs: Extra keyword arguments for ``ClassifierLayer``.

    Returns:
        The ``module`` attribute of the constructed ``ClassifierLayer``.
    """
    from mmf.modules.layers import ClassifierLayer

    return ClassifierLayer(config.type, *args, **config.params, **kwargs).module
def build_text_encoder(config, *args, **kwargs):
    """Deprecated, please do not use"""
    # Prefer the ``mmf.modules.fb`` factory when it can be imported, otherwise
    # use the one from ``mmf.modules.encoders``.
    try:
        from mmf.modules.fb.encoders import TextEncoderFactory as factory_cls
    except ImportError:
        from mmf.modules.encoders import TextEncoderFactory as factory_cls

    return factory_cls(config, *args, **kwargs).module
def build_image_encoder(config, direct_features=False, **kwargs):
    """Deprecated, please do not use"""
    from mmf.modules.encoders import ImageEncoderFactory, ImageFeatureEncoderFactory

    # ``direct_features`` selects the feature-encoder factory; extra keyword
    # arguments are accepted but not forwarded.
    factory_cls = ImageFeatureEncoderFactory if direct_features else ImageEncoderFactory
    return factory_cls(config).module
def build_encoder(config: Union[DictConfig, "mmf.modules.encoders.Encoder.Config"]):
    """Instantiate the encoder registered under the name given in ``config``.

    Two config layouts are supported:

    * ``type``/``params`` form, e.g.::

          encoder:
            type: identity  # noqa
            params:
              in_dim: 256

      where ``type`` may be an ``Enum`` (its value is used) and ``params`` is
      optional;
    * structured form, where ``config.name`` is the encoder name and the
      whole config is passed as the params.

    Args:
        config: Encoder configuration, either a DictConfig or an
            ``Encoder.Config`` instance (converted with
            ``OmegaConf.structured``).

    Returns:
        An instance of the registered encoder class constructed with the
        resolved params.
    """
    from mmf.modules.encoders import Encoder

    # Structured configs are converted to OmegaConf objects first.
    if isinstance(config, Encoder.Config) and not isinstance(config, DictConfig):
        config = OmegaConf.structured(config)

    if "type" not in config:
        # Structured Config support
        name, params = config.name, config
    else:
        name = config.type
        if isinstance(name, Enum):
            name = name.value
        params = config.get("params", None)

    encoder_cls = registry.get_encoder_class(name)

    # If params were not passed, try generating them from the encoder
    # class's default config
    if params is None:
        default_config = getattr(encoder_cls, "Config", {})
        params = OmegaConf.structured(default_config)

    return encoder_cls(params)
def build_processors(
    processors_config: DictConfig, registry_key: str = None, *args, **kwargs
) -> ProcessorDict:
    """Build the processors listed in a config and return them keyed by name.

    Entries whose params are falsy are skipped. For every other entry a
    registry lookup is attempted first (when ``registry_key`` is given) and a
    new ``Processor`` is created only if that lookup yields nothing.

    Args:
        processors_config (omegaconf.DictConfig): OmegaConf DictConfig describing
            the parameters and type of each processor passed here
        registry_key (str, optional): If passed, ``.format(processor_key)`` is
            called on this string and the result is looked up in the registry;
            a found instance is reused. Defaults to None.

    Returns:
        ProcessorDict: Dictionary containing key to
            processor mapping
    """
    from mmf.datasets.processors.processors import Processor

    built = {}
    for key, params in processors_config.items():
        if params:
            instance = (
                registry.get(registry_key.format(key), no_warning=True)
                if registry_key is not None
                else None
            )
            if instance is None:
                # We don't register back here as in case of hub interface, we
                # want the processors to be instantiated every time.
                # BaseDataset can register at its own end
                instance = Processor(params, *args, **kwargs)
            built[key] = instance
    return built
def build_iteration_strategy(
    config: DictConfig,
    dataloaders: Dict[str, torch.utils.data.DataLoader],
    *args,
    **kwargs,
) -> IterationStrategy:
    """Create the iteration strategy used to iterate over several dataloaders.

    Args:
        config (DictConfig): Strategy config. ``enabled`` (default True)
            toggles the configured strategy; when enabled, ``type`` is
            required and ``params`` (default ``{}``) is passed to the
            strategy class.
        dataloaders (Dict[str, torch.utils.data.DataLoader]): Dataloaders the
            strategy operates on.
        *args: Forwarded to the strategy constructor / ``from_params``.
        **kwargs: Forwarded to the strategy constructor / ``from_params``.

    Returns:
        IterationStrategy: ``ConstantIterationStrategy.from_params(...)`` when
        disabled, otherwise an instance of the registered strategy class.

    Raises:
        AssertionError: If the strategy is enabled but ``type`` is missing.
    """
    if config.get("enabled", True):
        assert (
            "type" in config
        ), "multitasking config must define 'type' attribute if enabled"
        # This assumes all dataloaders will have same dataset type
        strategy_cls = registry.get_iteration_strategy_class(config.type)
        strategy_params = config.get("params", {})
        # val and test splits won't be affected as test reporter iterates
        # over the datasets one by one without using any iteration strategy
        return strategy_cls(strategy_params, dataloaders, *args, **kwargs)

    return ConstantIterationStrategy.from_params(dataloaders, *args, **kwargs)
def build_meters(run_type: str) -> List[Meter]:
    """Create the train/val/test meters required by ``run_type``.

    Args:
        run_type (str): Run type string, checked for the substrings
            ``"train"``, ``"val"``, ``"inference"`` and ``"test"``.

    Returns:
        A ``(train_meter, val_meter, test_meter)`` triple; each entry is a
        ``Meter`` when needed for ``run_type`` and ``None`` otherwise.
    """
    wants_train = "train" in run_type
    # A val meter is also needed for validation after the training loop.
    wants_val = wants_train or "val" in run_type or "inference" in run_type
    wants_test = "test" in run_type

    train_meter = Meter() if wants_train else None
    val_meter = Meter() if wants_val else None
    test_meter = Meter() if wants_test else None
    return train_meter, val_meter, test_meter