https://github.com/sriniiyer/concode
Tip revision: 864e30807f6988731ac3b4b98af6562c18bb42ff authored by Srinivasan Iyer on 27 January 2021, 00:44:34 UTC
Merge pull request #2 from sriniiyer/add-license-1
Merge pull request #2 from sriniiyer/add-license-1
Tip revision: 864e308
predict.ipy
import argparse
import os
import json
# Command-line options for scoring a range of saved epoch checkpoints.
parser = argparse.ArgumentParser(description='predict.ipy')
# Prediction
parser.add_argument('-start', type=int, default=5,
                    help='Epoch to start prediction')
parser.add_argument('-end', type=int, default=30,
                    help='Epoch to end prediction')
parser.add_argument('-beam', type=int, default=3,
                    help='Beam size')
# Forwarded to translate.py as -max_sent_length.
parser.add_argument('-tgt_len', type=int, default=3,
                    help='Maximum length of a predicted sequence')
parser.add_argument('-models_dir', type=str, default='',
                    help='Models directory.')
# This path is opened and parsed as a single JSON file, not a directory.
parser.add_argument('-test_file', type=str, default='',
                    help='Test data JSON file.')
parser.add_argument('-additional', type=str, default='',
                    help='Any additional flags.')
opt = parser.parse_args()
# Make sure the directory that receives the prediction files exists.
try:
    os.makedirs(opt.models_dir + '/preds/')
except OSError:
    # An already-existing directory is expected on re-runs; any other failure
    # (permissions, bad path) would break every later step, so surface it now.
    if not os.path.isdir(opt.models_dir + '/preds/'):
        raise
# The BLEU and exact-match scripts are given a plain-text reference file, so
# write one line per example from the JSON test set to /tmp/test.code.
with open(opt.test_file, 'r') as test_json:
    test_dataset = json.load(test_json)
with open('/tmp/test.code', 'w') as test_dataset_targets:
    for example in test_dataset:
        # Space-join the code tokens and drop the 'concodeclass_' and
        # 'concodefunc_' markers.
        test_dataset_targets.write(' '.join(example['code']).replace('concodeclass_', '').replace('concodefunc_', '') + '\n')
best_bleu, best_exact = (0, 0, 0), (0, 0, 0)
for i in range(opt.start, opt.end + 1):
fname = !ls {opt.models_dir}/model_acc_*e{i}.pt
f = os.path.basename(fname[0])
print(f)
!rm {opt.models_dir}/preds/{f}.nl.prediction*
# Prod is just a dummy here
!python translate.py -beam_size {opt.beam} -gpu 0 -model {fname[0]} -src {opt.test_file} -output {opt.models_dir}/preds/{f}.nl.prediction -max_sent_length {opt.tgt_len} -replace_unk -batch_size 1 -trunc 2000 {opt.additional}
bleu = !perl tools/multi-bleu.perl -lc /tmp/test.code < {opt.models_dir}/preds/{f}.nl.prediction
print(bleu)
bleu_score = float(bleu[0].split(',')[0])
exact = !python tools/exact.py /tmp/test.code < {opt.models_dir}/preds/{f}.nl.prediction
print(exact)
exact_score = float(exact[0])
if bleu_score > best_bleu[0]:
best_bleu = (bleu_score, exact_score, i)
if exact_score > best_exact[1]:
best_exact = (bleu_score, exact_score, i)
print ('Best BLEU so far is: {} - Exact is {} - on epoch {}'.format(*best_bleu))
print ('BLEU is {} - Best Exact so far: is {} - on epoch {}'.format(*best_exact))