https://github.com/facebookresearch/pythia
Raw File
Tip revision: 2192818943499b83c2b0702a01384ba5fe3b5f61 authored by sash on 30 June 2021, 18:34:36 UTC
[fix]: clean up lightning test code
Tip revision: 2192818
m4c_captioner.py
# Copyright (c) Facebook, Inc. and its affiliates.
from mmf.common.registry import registry
from mmf.models.m4c import M4C


@registry.register_model("m4c_captioner")
class M4CCaptioner(M4C):
    """M4C variant registered under the name ``m4c_captioner``.

    Behaves like the parent ``M4C`` model, except that the score of the
    answer processor's ``UNK_IDX`` entry can be suppressed so that ``<unk>``
    is not emitted in generated captions. The behavior is toggled by the
    ``remove_unk_in_pred`` field of the model config.
    """

    def __init__(self, config):
        """Build the parent model and cache the ``remove_unk_in_pred`` flag.

        Args:
            config: model configuration, forwarded unchanged to ``M4C``.
        """
        super().__init__(config)
        # Read from ``self.config`` (not the raw argument), after the parent
        # constructor has run.
        model_config = self.config
        self.remove_unk_in_pred = model_config.remove_unk_in_pred

    @classmethod
    def config_path(cls):
        """Return the path of this model's default configuration file."""
        default_config = "configs/models/m4c_captioner/defaults.yaml"
        return default_config

    def _forward_output(self, sample_list, fwd_results):
        """Run the parent's output step, then optionally mask ``<unk>``.

        Args:
            sample_list: passed through unchanged to the parent implementation.
            fwd_results: mapping of intermediate forward results; expected to
                hold a ``"scores"`` entry once the parent step has run.

        Returns:
            The same ``fwd_results`` object that was passed in.
        """
        super()._forward_output(sample_list, fwd_results)

        if not self.remove_unk_in_pred:
            return fwd_results

        # Avoid outputting <unk> in the generated captions: overwrite its
        # slot along the last axis of the scores with a very negative value.
        scores = fwd_results["scores"]
        unk_index = self.answer_processor.UNK_IDX
        scores[..., unk_index] = -1e10
        return fwd_results
back to top