# Source: https://github.com/facebookresearch/pythia
# Raw File
# Tip revision: dabf95f523cd07e93380c6931e5140ade0f50b2f authored by Sethu Sankaran on 26 October 2021, 19:18:43 UTC
# Revert D30704069: [feat] Add a refiner head that can be used with MMFT
# Tip revision: dabf95f
# dataset.py
# Copyright (c) Facebook, Inc. and its affiliates.
import warnings
from typing import List

import torch
from mmf.common.sample import Sample
from omegaconf import DictConfig


def build_bbox_tensors(infos, max_length):
    """Pack bounding-box annotations into fixed-size, zero-padded tensors.

    Args:
        infos: Sequence of annotation dicts. Each must hold a
            ``"bounding_box"`` mapping with ``"width"``, ``"height"`` and the
            top-left corner stored either as ``"top_left_x"``/``"top_left_y"``
            or as ``"topLeftX"``/``"topLeftY"``.
        max_length: Number of rows in every returned tensor. Only the first
            ``max_length`` entries of ``infos`` are used; extra entries are
            dropped and missing ones are left as zeros.

    Returns:
        A ``Sample`` with the attributes:
            * ``coordinates``: float tensor ``(max_length, 4)`` holding
              ``x, y, x + width, y + height`` per box.
            * ``width`` / ``height``: float tensors of length ``max_length``.
            * ``bbox_types``: list of ``max_length`` ``"xyxy"`` strings.

    Raises:
        KeyError: If a bounding box has neither spelling of a corner key, or
            lacks ``"width"``/``"height"``.
    """
    num_bbox = min(max_length, len(infos))

    # Rows at index >= num_bbox are never written, so they stay zero (padding).
    coord_tensor = torch.zeros((max_length, 4), dtype=torch.float)
    width_tensor = torch.zeros(max_length, dtype=torch.float)
    height_tensor = torch.zeros(max_length, dtype=torch.float)
    bbox_types = ["xyxy"] * max_length

    sample = Sample()

    for idx, info in enumerate(infos[:num_bbox]):
        bbox = info["bounding_box"]
        # Check membership explicitly: ``bbox.get(a, bbox[b])`` evaluates the
        # fallback lookup eagerly and raises KeyError when only the snake_case
        # key is present.
        x = bbox["top_left_x"] if "top_left_x" in bbox else bbox["topLeftX"]
        y = bbox["top_left_y"] if "top_left_y" in bbox else bbox["topLeftY"]
        width = bbox["width"]
        height = bbox["height"]

        # Convert (x, y, width, height) into corner format (x1, y1, x2, y2).
        coord_tensor[idx][0] = x
        coord_tensor[idx][1] = y
        coord_tensor[idx][2] = x + width
        coord_tensor[idx][3] = y + height

        width_tensor[idx] = width
        height_tensor[idx] = height

    sample.coordinates = coord_tensor
    sample.width = width_tensor
    sample.height = height_tensor
    sample.bbox_types = bbox_types

    return sample


def build_dataset_from_multiple_imdbs(config, dataset_cls, dataset_type):
    """Build one concatenated dataset from every imdb file of a dataset type.

    Args:
        config: Dataset config exposing ``imdb_files``, indexable by dataset
            type.
        dataset_cls: Callable invoked as
            ``dataset_cls(dataset_type, imdb_idx, config)`` once per imdb file.
        dataset_type: Key looked up in ``config.imdb_files``.

    Returns:
        An ``MMFConcatDataset`` wrapping one dataset per imdb file, or ``None``
        (after emitting a warning) when ``dataset_type`` is not listed in
        ``config.imdb_files``.
    """
    from mmf.datasets.concat_dataset import MMFConcatDataset

    if dataset_type in config.imdb_files:
        imdb_entries = config["imdb_files"][dataset_type]
        # Each dataset is identified by the position of its imdb file.
        per_imdb_datasets = [
            dataset_cls(dataset_type, position, config)
            for position in range(len(imdb_entries))
        ]
        return MMFConcatDataset(per_imdb_datasets)

    message = (
        "Dataset type {} is not present in "
        "imdb_files of dataset config. Returning None. "
        "This dataset won't be used.".format(dataset_type)
    )
    warnings.warn(message)
    return None


def dataset_list_from_config(config: DictConfig) -> List[str]:
    """Return the dataset names selected in ``config``.

    Args:
        config: Configuration that may contain a ``datasets`` entry, given
            either as a comma-separated string or as a sequence of names.

    Returns:
        The dataset names. A string value is split on commas with surrounding
        whitespace stripped from each name; any other value is returned
        unchanged. When ``datasets`` is absent, a warning is emitted and
        ``["vqa2"]`` is returned.
    """
    if "datasets" not in config:
        warnings.warn("No datasets attribute present. Setting default to vqa2.")
        datasets = "vqa2"
    else:
        datasets = config.datasets

    # isinstance (rather than an exact type comparison) also accepts str
    # subclasses, which would otherwise be returned unsplit.
    if isinstance(datasets, str):
        datasets = [name.strip() for name in datasets.split(",")]

    return datasets
# back to top