https://github.com/facebookresearch/pythia
Tip revision: 48c0d58529a2ea71ac03b13bc9dcd7a2d6aad388 authored by omkar on 01 July 2020, 17:58:35 UTC
Working training
Working training
Tip revision: 48c0d58
process_answers.py
# Copyright (c) Facebook, Inc. and its affiliates.
import argparse
import json
import os
from mmf.datasets.processors.processors import EvalAIAnswerProcessor
def get_score(occurences):
    """Map an answer's occurrence count to a soft score.

    A count of 0, 1, 2 or 3 maps to 0, 0.3, 0.6 or 0.9 respectively.
    Every other value, including any count above 3, maps to 1.
    """
    score_table = ((0, 0), (1, 0.3), (2, 0.6), (3, 0.9))
    for count, score in score_table:
        if occurences == count:
            return score
    return 1
def multiple_replace(text, wordDict):
    """Apply each ``old -> new`` substitution in ``wordDict`` to ``text``.

    Substitutions are applied one after another in the mapping's iteration
    order. A later substitution can therefore rewrite text produced by an
    earlier one.
    """
    result = text
    for old, new in wordDict.items():
        result = result.replace(old, new)
    return result
def filter_answers(answers_dset, min_occurence, *, answer_processor=None):
    """Return normalized answers that occur in at least ``min_occurence`` questions.

    Each entry's ``"multiple_choice_answer"`` is normalized with
    ``answer_processor``. The function then counts, for each normalized
    answer, the distinct ``"question_id"`` values it appears with. Answers
    are returned in first-seen order.

    Args:
        answers_dset: iterable of annotation dicts; each must provide the
            keys ``"multiple_choice_answer"`` and ``"question_id"``.
        min_occurence: minimum number of distinct question ids an answer
            must be associated with to be kept.
        answer_processor: optional callable mapping a raw answer string to
            its normalized form. Defaults to ``EvalAIAnswerProcessor()``.

    Returns:
        list of normalized answer strings meeting the threshold.
    """
    if answer_processor is None:
        answer_processor = EvalAIAnswerProcessor()

    # Collect question ids rather than raw counts, so repeated entries for
    # the same question id contribute only once to an answer's count.
    question_ids = {}
    for ans_entry in answers_dset:
        gtruth = answer_processor(ans_entry["multiple_choice_answer"])
        question_ids.setdefault(gtruth, set()).add(ans_entry["question_id"])

    answer_list = [
        answer for answer, qids in question_ids.items() if len(qids) >= min_occurence
    ]
    print(
        "Num of answers that appear >= %d times: %d" % (min_occurence, len(answer_list))
    )
    return answer_list
if __name__ == "__main__":
    # Build an answer vocabulary file ("answers_vqa.txt") from the
    # "annotations" list of one or two annotation JSON files.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--annotation_file",
        type=str,
        required=True,
        help="input train annotation json file",
    )
    parser.add_argument(
        "--val_annotation_file",
        type=str,
        required=False,
        help="input val annotation json file",
    )
    parser.add_argument(
        "--out_dir",
        type=str,
        default="./",
        help="output directory, default is current directory",
    )
    parser.add_argument(
        "--min_freq",
        type=int,
        default=0,
        # Implicit string concatenation. A backslash continuation would embed
        # the next line's leading indentation into the help text.
        help="the minimum times of answer occurrence "
        "to be included in vocabulary, default 0",
    )
    args = parser.parse_args()

    train_annotation_file = args.annotation_file
    out_dir = args.out_dir
    min_freq = args.min_freq
    answer_file_name = "answers_vqa.txt"
    os.makedirs(out_dir, exist_ok=True)

    # Context managers close the annotation files as soon as they are parsed.
    with open(train_annotation_file, encoding="utf-8") as f:
        answers = json.load(f)["annotations"]
    if args.val_annotation_file is not None:
        with open(args.val_annotation_file, encoding="utf-8") as f:
            answers = answers + json.load(f)["annotations"]

    answer_list = filter_answers(answers, min_freq)
    # Strip whitespace, drop empty answers, then de-duplicate. Two answers
    # differing only in surrounding whitespace would otherwise both appear.
    answer_list = sorted({t.strip() for t in answer_list if t.strip()})
    # Put "<unk>" first unless it is already one of the answers.
    if "<unk>" not in answer_list:
        answer_list = ["<unk>"] + answer_list

    answer_file = os.path.join(out_dir, answer_file_name)
    # One answer per line; UTF-8 so non-ASCII answers are written portably.
    with open(answer_file, "w", encoding="utf-8") as f:
        f.writelines(w + "\n" for w in answer_list)