Raw File
test_fp16.py
# Copyright (c) Facebook, Inc. and its affiliates.

import unittest

import torch
from tests.test_utils import SimpleModel, skip_if_no_cuda
from tests.trainers.test_training_loop import TrainerTrainingLoopMock


class SimpleModelWithFp16Assert(SimpleModel):
    def forward(self, sample_list):
        batch_tensor = sample_list[list(sample_list.keys())[0]]
        # Start should be fp32
        assert batch_tensor.dtype == torch.float32
        batch_tensor = self.classifier(batch_tensor)

        # In between operation should be fp16
        assert batch_tensor.dtype == torch.float16
        loss = torch.sum(batch_tensor)

        # Sum should convert it back to fp32
        assert loss.dtype == torch.float32

        model_output = {"losses": {"loss": loss}}
        return model_output


class MMFTrainerMock(TrainerTrainingLoopMock):
    def __init__(
        self, num_train_data, max_updates, max_epochs, device="cuda", fp16_model=False
    ):
        super().__init__(num_train_data, max_updates, max_epochs, fp16=True)
        self.device = torch.device(device)
        if fp16_model:
            assert (
                torch.cuda.is_available()
            ), "MMFTrainerMock fp16 requires cuda enabled"
            self.model = SimpleModelWithFp16Assert({"in_dim": 1})
            self.model.build()
            self.model = self.model.cuda()
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=1e-3)


class TestFp16(unittest.TestCase):
    @skip_if_no_cuda
    def test_fp16_works(self):
        trainer = MMFTrainerMock(100, 2, 0.04)
        trainer.load_fp16_scaler()
        self.assertTrue(isinstance(trainer.scaler, torch.cuda.amp.GradScaler))
        self.assertEqual(trainer.current_iteration, 0)
        trainer.training_loop()
        self.assertEqual(trainer.current_iteration, 4)

    @skip_if_no_cuda
    def test_fp16_values(self):
        trainer = MMFTrainerMock(100, 2, 0.04, fp16_model=True)
        trainer.load_fp16_scaler()
        trainer.training_loop()
back to top