https://github.com/yburda/iwae
Tip revision: 78a6075484bd7dd6f47a4bb9880d2295bc7a5220 authored by yburda-MBP on 26 January 2016, 17:18:52 UTC
omniglot
omniglot
Tip revision: 78a6075
train.py
import theano
import theano.tensor as T
import progressbar
def train(model, dataset, optimizer, minibatch_size, n_epochs, srng, **kwargs):
    """Compile a Theano training step and run minibatch SGD for ``n_epochs``.

    Args:
        model: model exposing ``gradIminibatch_srng`` (gradient given a
            minibatch and a random stream); updated in place via the
            compiled Theano updates.
        dataset: dataset exposing ``get_n_examples`` and
            ``minibatchIindex_minibatch_size`` (minibatch given an index
            and a minibatch size).
        optimizer: optimizer exposing ``learning_rate`` and
            ``updatesIgrad_model`` (parameter updates given gradient/model).
        minibatch_size: number of training examples per minibatch.
        n_epochs: number of passes over the training subdataset.
        srng: Theano random stream forwarded to minibatch selection and
            the gradient computation.
        **kwargs: forwarded verbatim to ``model.gradIminibatch_srng``.

    Returns:
        The trained model (the same object, mutated by the updates).
    """
    print("training for {} epochs with {} learning rate".format(
        n_epochs, optimizer.learning_rate))
    # Floor division: a trailing partial minibatch is silently dropped.
    # The original `/` truncated only under Python 2; `//` makes the intent
    # explicit and keeps the count an int under Python 3 as well.
    num_minibatches = dataset.get_n_examples('train') // minibatch_size
    # Build the symbolic graph once; `index` selects which minibatch a
    # compiled call trains on.
    index = T.lscalar('i')
    minibatch = dataset.minibatchIindex_minibatch_size(
        index, minibatch_size, srng=srng, subdataset='train')
    grad = model.gradIminibatch_srng(minibatch, srng, **kwargs)
    updates = optimizer.updatesIgrad_model(grad, model)
    # No explicit outputs: the function is run purely for its updates.
    train_step = theano.function([index], None, updates=updates)
    pbar = progressbar.ProgressBar(maxval=n_epochs * num_minibatches).start()
    for epoch in xrange(n_epochs):
        for i in xrange(num_minibatches):
            train_step(i)
            pbar.update(epoch * num_minibatches + i)
    pbar.finish()
    return model