https://github.com/sriniiyer/concode
Raw File
Tip revision: 864e30807f6988731ac3b4b98af6562c18bb42ff authored by Srinivasan Iyer on 27 January 2021, 00:44:34 UTC
Merge pull request #2 from sriniiyer/add-license-1
Tip revision: 864e308
predict.ipy
import argparse
import os
import json

# Command-line options for sweeping saved checkpoints over a range of epochs,
# generating predictions with translate.py and scoring them.
parser = argparse.ArgumentParser(description='predict.ipy')

# Prediction
parser.add_argument('-start', type=int, default=5,
                    help='Epoch to start prediction')
parser.add_argument('-end', type=int, default=30,
                    help='Epoch to end prediction (inclusive)')
parser.add_argument('-beam', type=int, default=3,
                    help='Beam size')
# Forwarded to translate.py as -max_sent_length.
parser.add_argument('-tgt_len', type=int, default=3,
                    help='Maximum length of the generated target sequence')

parser.add_argument('-models_dir', type=str, default='',
                    help='Models directory.')
# Read as a single JSON file and also passed to translate.py as -src.
parser.add_argument('-test_file', type=str, default='',
                    help='Test data file (JSON).')
parser.add_argument('-additional', type=str, default='',
                    help='Any additional flags to pass to translate.py.')

opt = parser.parse_args()

# Predictions for every checkpoint are written under <models_dir>/preds/.
# exist_ok=True keeps reruns against an existing directory working, while
# genuine failures (e.g. permissions) are reported here instead of being
# silently swallowed.
os.makedirs(opt.models_dir + '/preds/', exist_ok=True)

# We need a text file with the reference outputs to compute BLEU,
# so extract it out of the json file: one line per example, with the
# example's 'code' tokens joined by single spaces.
with open(opt.test_file, 'r') as test_json:
  test_dataset = json.load(test_json)

# Context managers guarantee both handles are closed even if an example is
# malformed and the loop raises part-way through.
with open('/tmp/test.code', 'w') as test_dataset_targets:
  for example in test_dataset:
    # Drop the 'concodeclass_' / 'concodefunc_' markers from the joined tokens.
    test_dataset_targets.write(' '.join(example['code']).replace('concodeclass_', '').replace('concodefunc_', '') + '\n')

best_bleu, best_exact = (0, 0, 0), (0, 0, 0)
for i in range(opt.start, opt.end + 1):
  fname = !ls {opt.models_dir}/model_acc_*e{i}.pt
  f = os.path.basename(fname[0])
  print(f)
  !rm {opt.models_dir}/preds/{f}.nl.prediction*

  # Prod is just a dummy here
  !python translate.py -beam_size {opt.beam} -gpu 0 -model {fname[0]} -src {opt.test_file} -output {opt.models_dir}/preds/{f}.nl.prediction -max_sent_length {opt.tgt_len} -replace_unk -batch_size 1 -trunc 2000 {opt.additional}

  bleu = !perl tools/multi-bleu.perl -lc /tmp/test.code < {opt.models_dir}/preds/{f}.nl.prediction
  print(bleu)
  bleu_score = float(bleu[0].split(',')[0])

  exact = !python tools/exact.py /tmp/test.code < {opt.models_dir}/preds/{f}.nl.prediction
  print(exact)
  exact_score = float(exact[0])

  if bleu_score > best_bleu[0]:
    best_bleu = (bleu_score, exact_score, i)
  if exact_score > best_exact[1]:
    best_exact = (bleu_score, exact_score, i)
  print ('Best BLEU so far is: {} - Exact is {} - on epoch {}'.format(*best_bleu))
  print ('BLEU is {} - Best Exact so far: is {} - on epoch {}'.format(*best_exact))
back to top