# Source: https://github.com/yburda/iwae (raw file)
# Tip revision: 78a6075484bd7dd6f47a4bb9880d2295bc7a5220, authored by yburda-MBP
#   on 26 January 2016, 17:18:52 UTC
# omniglot — tip revision: 78a6075
# train.py
import theano
import theano.tensor as T

import progressbar


def train(model, dataset, optimizer, minibatch_size, n_epochs, srng, **kwargs):
    """Train ``model`` on the 'train' split of ``dataset`` for ``n_epochs`` epochs.

    Compiles a single Theano step function that takes a minibatch index,
    builds that minibatch, computes the model's gradient on it and applies
    the optimizer's updates. The step is then called once per minibatch
    index for every epoch, with a console progress bar.

    Args:
        model: object exposing ``gradIminibatch_srng(minibatch, srng, **kwargs)``.
            It is also handed to the optimizer and returned at the end.
        dataset: object exposing ``get_n_examples('train')`` and
            ``minibatchIindex_minibatch_size(index, minibatch_size,
            srng=..., subdataset='train')``.
        optimizer: object exposing ``learning_rate`` (only printed here) and
            ``updatesIgrad_model(grad, model)``.
        minibatch_size: examples per minibatch. Each epoch uses minibatch
            indices 0 .. floor(n_train / minibatch_size) - 1.
        n_epochs: number of passes over those minibatch indices.
        srng: random stream forwarded to both the dataset and the model
            (presumably a Theano RandomStreams instance -- TODO confirm).
        **kwargs: forwarded verbatim to ``model.gradIminibatch_srng``.

    Returns:
        The same ``model`` object. Training takes effect only through the
        ``updates`` applied by the compiled step (presumably the model's
        shared variables -- verify against the optimizer).
    """
    print("training for {} epochs with {} learning rate".format(
        n_epochs, optimizer.learning_rate))
    # Explicit floor division: only whole minibatches are used, and the count
    # must be an int for xrange and the progress-bar maximum.
    num_minibatches = dataset.get_n_examples('train') // minibatch_size

    # Symbolic minibatch index: the only input of the compiled step below.
    index = T.lscalar('i')
    minibatch = dataset.minibatchIindex_minibatch_size(
        index, minibatch_size, srng=srng, subdataset='train')

    grad = model.gradIminibatch_srng(minibatch, srng, **kwargs)
    updates = optimizer.updatesIgrad_model(grad, model)

    # No outputs: the function exists only to apply `updates` as a side effect.
    train_step = theano.function([index], None, updates=updates)

    pbar = progressbar.ProgressBar(maxval=n_epochs * num_minibatches).start()
    steps_done = 0
    for _ in xrange(n_epochs):
        for i in xrange(num_minibatches):
            train_step(i)
            # Report *completed* steps so the bar reaches maxval on the last
            # step instead of lagging one behind until finish().
            steps_done += 1
            pbar.update(steps_done)
    pbar.finish()
    return model
back to top