https://github.com/sordonia/hed-qs
Raw File
Tip revision: b217a94387eca1e37975d7dd770d94981ec7eac1 authored by Alessandro Sordoni on 07 April 2017, 14:07:22 UTC
Update README.md
Tip revision: b217a94
train.py
#!/usr/bin/env python

from data_iterator import *
from state import *
from session_encdec import *
from utils import *

import time
import traceback
import os.path
import sys
import argparse
import cPickle
import logging
import search
import pprint
import numpy
import collections
import signal

class Unbuffered:
    def __init__(self, stream):
        self.stream = stream

    def write(self, data):
        self.stream.write(data)
        self.stream.flush()

    def __getattr__(self, attr):
        return getattr(self.stream, attr)

sys.stdout = Unbuffered(sys.stdout)
logger = logging.getLogger(__name__)

### Unique RUN_ID for this execution
RUN_ID = str(time.time())
### Additional measures can be set here
measures = ["train", "valid"]

def init_timings():
    timings = {}
    for m in measures:
        timings[m] = []
    return timings

def save(model, timings):
    print "Saving the model..."

    # ignore keyboard interrupt while saving
    start = time.time()
    s = signal.signal(signal.SIGINT, signal.SIG_IGN)
    model.save(model.state['save_dir'] + '/' + model.state['run_id'] + "_" + model.state['prefix'] + '_model.npz')
    cPickle.dump(model.state, open(model.state['save_dir'] + '/' +  model.state['run_id'] + "_" + model.state['prefix'] + '_state.pkl', 'w'))
    numpy.savez(model.state['save_dir'] + '/' + model.state['run_id'] + "_" + model.state['prefix'] + '_timing.npz', **timings)
    signal.signal(signal.SIGINT, s)
    print "Model saved, took {}".format(time.time() - start)

def load(model, filename):
    print "Loading the model..."

    # ignore keyboard interrupt while saving
    start = time.time()
    s = signal.signal(signal.SIGINT, signal.SIG_IGN)
    model.load(filename)
    signal.signal(signal.SIGINT, s)

    print "Model loaded, took {}".format(time.time() - start)

def main(args):
    logging.basicConfig(
        level = logging.INFO,
        format = "%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    state = eval(args.prototype)()
    timings = init_timings()

    if args.resume != "":
        logger.debug("Resuming %s" % args.resume)
        state_file = args.resume + '_state.pkl'
        timings_file = args.resume + '_timing.npz'
        if os.path.isfile(state_file) and os.path.isfile(timings_file):
            logger.debug("Loading previous state")
            state = cPickle.load(open(state_file, 'r'))
            timings = dict(numpy.load(open(timings_file, 'r')))
            for x, y in timings.items():
                timings[x] = list(y)
        else:
            raise Exception("Cannot resume, cannot find files!")

    logger.info("State:\n{}".format(pprint.pformat(state)))
    logger.info("Timings:\n{}".format(pprint.pformat(timings)))

    model = SessionEncoderDecoder(state)
    rng = model.rng

    if args.resume != "":
        filename = args.resume + '_model.npz'
        if os.path.isfile(filename):
            logger.info("Loading previous model")
            load(model, filename)
        else:
            raise Exception("Cannot resume, cannot find model file!")
    else:
        # assign new run_id key
        model.state['run_id'] = RUN_ID

    logger.info("Compile trainer")
    train_batch = model.build_train_function()
    eval_batch = model.build_eval_function()
    random_sampler = search.RandomSampler(model)

    logger.info("Load data")
    train_data, valid_data = get_batch_iterator(rng, state)
    train_data.start()

    # Start looping through the dataset
    step = 0
    patience = state['patience']
    start_time = time.time()
    train_cost = 0
    train_done = 0
    ex_done = 0
    while step < state['loop_iters'] and patience >= 0:
        # Sample stuff
        if step % 200 == 0:
            for param in model.params:
                print "%s = %.4f" % (param.name,
                                     numpy.sum(param.get_value() ** 2) ** 0.5)
            samples, costs = random_sampler.sample([[]], n_samples=1, n_turns=3)
            print "Sampled : {}".format(samples[0])

        # Training phase
        batch = train_data.next()
        # Train finished
        if not batch:
            # Restart training
            logger.debug("Got None...")
            break
        c = train_batch(
            batch['x'], batch['y'], batch['max_length'], batch['x_mask'])
        if numpy.isinf(c) or numpy.isnan(c):
            logger.warn("Got NaN cost .. skipping")
            continue

        train_cost += c
        train_done += batch['num_preds']

        this_time = time.time()
        if step % state['train_freq'] == 0:
            elapsed = this_time - start_time
            h, m, s = ConvertTimedelta(this_time - start_time)
            print ".. %.2d:%.2d:%.2d %4d mb # %d bs %d maxl %d acc_cost = %.4f" % (h, m, s,\
                                                                             state['time_stop'] - (time.time() - start_time)/60.,\
                                                                             step, \
                                                                             batch['x'].shape[1], \
                                                                             batch['max_length'], \
                                                                             float(train_cost/train_done))
        if valid_data is not None and\
            step % state['valid_freq'] == 0 and step > 1:
                valid_data.start()
                valid_cost = 0
                valid_done = 0
                logger.debug("[VALIDATION START]")
                while True:
                    batch = valid_data.next()
                    # Train finished
                    if not batch:
                        break
                    if numpy.isinf(c) or numpy.isnan(c):
                        continue
                    c = eval_batch(
                        batch['x'], batch['y'], batch['max_length'], batch['x_mask'])
                    valid_cost += c
                    valid_done += batch['num_preds']
                logger.debug("[VALIDATION END]")
                valid_cost /= valid_done
                if len(timings["valid"]) == 0 or valid_cost < numpy.min(numpy.array(timings["valid"])):
                    patience = state['patience']
                    # Saving model if decrease in validation cost
                    save(model, timings)
                elif valid_cost >= timings["valid"][-1] * state['cost_threshold']:
                    patience -= 1

                print "** validation error = %.4f, patience = %d" % (float(valid_cost), patience)
                timings["train"].append(train_cost/train_done)
                timings["valid"].append(valid_cost)

                # Reset train cost and train done
                train_cost = 0
                train_done = 0
        step += 1
    logger.debug("All done, exiting...")

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--resume", type=str, default="", help="Resume training from that state")
    parser.add_argument("--prototype", type=str, help="Use the prototype", default='prototype_state')

    args = parser.parse_args()
    return args

if __name__ == "__main__":
    args = parse_args()
    main(args)
back to top