Revision b8f4591dc0e233539a169e7319a751f3dccd2503 authored by Vedanuj Goswami on 01 February 2021, 09:50:35 UTC, committed by Facebook GitHub Bot on 01 February 2021, 09:52:13 UTC
Summary:
Fixes evaluation loop after changes in https://github.com/facebookresearch/mmf/issues/747

Pull Request resolved: https://github.com/facebookresearch/mmf/pull/757

Reviewed By: ytsheng

Differential Revision: D26171041

Pulled By: vedanuj

fbshipit-source-id: f0fd24ef96ef54dd7ea17af5968e05f0301e64a3
1 parent 8237ed2
Raw File
transform.py
# Copyright (c) Facebook, Inc. and its affiliates.

from torch import Tensor


def transform_to_batch_sequence(tensor: Tensor) -> Tensor:
    """Collapse the first two dimensions of a 3-D tensor into one.

    A 2-D tensor is returned unchanged (same object). A 3-D tensor of shape
    ``(d0, d1, d2)`` is returned as a ``(d0 * d1, d2)`` view of a contiguous
    copy (or of the tensor itself when it is already contiguous).

    NOTE(review): the name suggests ``(batch, num_items, seq_len)`` ->
    ``(batch * num_items, seq_len)``; confirm against callers.

    Args:
        tensor: a 2-D or 3-D tensor.

    Returns:
        A 2-D tensor.

    Raises:
        ValueError: if ``tensor`` is neither 2-D nor 3-D.
    """
    if tensor.dim() == 2:
        return tensor
    if tensor.dim() != 3:
        # Explicit raise (not ``assert``) so the check survives ``python -O``.
        raise ValueError(
            "transform_to_batch_sequence expects a 2-D or 3-D tensor, "
            f"got {tensor.dim()}-D tensor with shape {tuple(tensor.size())}"
        )
    d0, d1, d2 = tensor.size()
    # Spell out the merged size instead of using -1: ``view(-1, 0)`` is
    # ambiguous and raises when the last dimension is empty.
    return tensor.contiguous().view(d0 * d1, d2)


def transform_to_batch_sequence_dim(tensor: Tensor) -> Tensor:
    """Collapse the first two dimensions of a 4-D tensor into one.

    A 3-D tensor is returned unchanged (same object). A 4-D tensor of shape
    ``(d0, d1, d2, d3)`` is returned as a ``(d0 * d1, d2, d3)`` view of a
    contiguous copy (or of the tensor itself when it is already contiguous).

    NOTE(review): the name suggests ``(batch, num_items, seq_len, dim)`` ->
    ``(batch * num_items, seq_len, dim)``; confirm against callers.

    Args:
        tensor: a 3-D or 4-D tensor.

    Returns:
        A 3-D tensor.

    Raises:
        ValueError: if ``tensor`` is neither 3-D nor 4-D.
    """
    if tensor.dim() == 3:
        return tensor
    if tensor.dim() != 4:
        # Explicit raise (not ``assert``) so the check survives ``python -O``.
        raise ValueError(
            "transform_to_batch_sequence_dim expects a 3-D or 4-D tensor, "
            f"got {tensor.dim()}-D tensor with shape {tuple(tensor.size())}"
        )
    d0, d1, d2, d3 = tensor.size()
    # Spell out the merged size instead of using -1: inferring -1 is
    # ambiguous and raises when d2 * d3 == 0.
    return tensor.contiguous().view(d0 * d1, d2, d3)
back to top