https://github.com/duanders/mpi_learn
Raw File
Tip revision: 21bc15e292b736ee7e5baa81c6304b871d5fdbcd authored by vlimant on 23 April 2019, 12:02:58 UTC
Merge pull request #35 from vloncar/logging
Tip revision: 21bc15e
simple_train.py
import optparse

parser = optparse.OptionParser()
parser.add_option('--restart',action='store_true')
parser.add_option('--train',action='store_true')
parser.add_option('--test',action='store_true')
parser.add_option('--fresh',action='store_true')
parser.add_option('--fom', action='store_true')
parser.add_option('--tag',default='')
parser.add_option('--max',type='int', default = 0)
parser.add_option('--lr',type='float',default=0.0)
parser.add_option('--epochs', type='int', default=3)
parser.add_option('--cache',default=None)

(options,args) = parser.parse_args()

from mpi_learn.train.GanModel import GANModel
from mpi_learn.train.data import H5Data
import h5py
import setGPU
import json
import time


# Keyword arguments forwarded to GANModel. Every override below is commented
# out, so the model is currently built with GANModel's own defaults.
gan_args = {
    #'tell': False,
    #'reversedorder' : False,
    #'heavycheck' : False,
    #'show_values' : False,
    #'gen_bn' : True,
    #'checkpoint' : False,
    #'onepass' : False,
    #'show_loss' : True,
    #'with_fixed_disc' : True ## could switch back to False and check
    }

gm = GANModel(**gan_args)

restart = options.restart
fresh = options.fresh
# `tag` is built up below from the chosen options and later inserted into the
# names of the weight files and the JSON history file written by this script.
tag = (options.tag+'_') if options.tag else ''
lr = options.lr

if restart:
    tag+='reload_'
    print ("Reloading")
    # A non-zero --lr compiles with prop=False and that learning rate;
    # otherwise prop=True is requested explicitly.
    # NOTE(review): the tags suggest prop switches between SGD and RMSprop
    # inside GANModel.compile -- confirm against that method.
    if lr:
        gm.compile( prop = False, lr=lr)
        tag+='sgd%s_'%lr
    else:
        gm.compile( prop = True)
        tag+='rmsprop_'

    ## start from an existing model
    # No fallback here: a missing or incompatible file aborts the script.
    gm.generator.load_weights('FullRunApr3/simple_generator.h5')
    gm.discriminator.load_weights('FullRunApr3/simple_discriminator.h5')
    gm.combined.load_weights('FullRunApr3/simple_combined.h5')
else:
    if lr:
        gm.compile(prop = False, lr=lr)
        tag+='sgd%s_'%lr
    else:
        # Unlike the restart branch, compile() is called without arguments.
        gm.compile()
        tag+='rmsprop_'

    if not fresh:
        # Best effort: continue from weights left in the working directory,
        # falling back to the freshly initialised model when that fails.
        try:
            gm.generator.load_weights('simple_generator.h5')
            gm.discriminator.load_weights('simple_discriminator.h5')
        except Exception as err:
            # `Exception` (rather than a bare except) lets KeyboardInterrupt
            # and SystemExit propagate; the reason is reported instead of
            # being silently discarded.
            # NOTE(review): if only the discriminator load fails, the
            # generator keeps the weights that were already loaded.
            print ("fresh weights")
            print ("could not load previous weights: %s"%err)
    else:
        tag+='fresh_'

print (tag,"is the option")

# Read the list of input files: one path per line, empty lines dropped.
# The context manager closes the list file as soon as it has been read
# instead of leaving the handle open until garbage collection.
with open('train_3d.list') as list_file:
    files = [line for line in list_file.read().split('\n') if line]

# Batch provider over the listed HDF5 files, reading dataset 'X' as features
# and 'y' as labels in batches of 100.
data = H5Data( batch_size = 100,
               cache = options.cache,
               preloading=0,
               features_name='X', labels_name='y')
data.set_file_names(files)

# The triple-quoted string below is disabled code: as a bare string literal it
# is evaluated and discarded, never executed.
# NOTE(review): it reads options.inmem, but no --inmem option is registered on
# the parser in this file, and it rebinds `files` only after
# data.set_file_names(files) has already been called -- both would need
# attention before re-enabling it.
"""

if options.inmem:
    import os
    relocated = []
    os.system('mkdir /dev/shm/vlimant/')
    for fn in files:
        relocate = '/dev/shm/vlimant/'+fn.split('/')[-1]
        if not os.path.isfile( relocate ):
            print ("copying %s to %s"%( fn , relocate))
            if os.system('cp %s %s'%( fn ,relocate))==0:
                relocated.append( relocate )
    files = relocated
"""
# Per-epoch records, keyed by epoch index; filled by the training loop and
# written out by dump().
history = {}     # losses returned by gm.train_on_batch
thistory = {}    # losses returned by gm.test_on_batch
fhistory = {}    # values returned by gm.figure_of_merit
etimes=[]        # duration of each completed epoch, in seconds
# time.time() is a true epoch timestamp with sub-second resolution;
# time.mktime(time.gmtime()) re-interprets the UTC fields as local time and
# only resolves whole seconds.
start = time.time()


train_me = options.train     # run gm.train_on_batch on every batch
over_test= options.test      # also run gm.test_on_batch while training
max_batch = options.max      # batch limit; 0 disables the limit
ibatch=0                     # batches seen so far, counted across all epochs
def dump():
    """Write the accumulated histories and epoch timings to a JSON file.

    Serialises the module-level ``history``, ``thistory``, ``fhistory`` and
    ``etimes`` under the keys ``'h'``, ``'th'``, ``'fh'`` and ``'et'`` and
    writes the result to ``'simple_train_%s.json' % tag``, replacing any
    previous content of that file.
    """
    # Serialise before opening the file: opening with 'w' truncates it, so a
    # failure in json.dumps would otherwise leave an empty file behind.
    payload = json.dumps(
        {
            'h':history,
            'th':thistory,
            'fh':fhistory,
            'et':etimes,
            } )
    # The context manager flushes and closes the handle deterministically.
    with open('simple_train_%s.json'%tag,'w') as out:
        out.write(payload)

nepochs = options.epochs

# Accumulator passed to gm.update_history together with each batch's logs.
histories={}
for e in range(nepochs):
    history[e] = []
    thistory[e] = []
    fhistory[e] = []
    # time.time() gives sub-second wall-clock resolution, so short epochs are
    # no longer reported as 0 s; time.mktime(time.gmtime()) only resolved
    # whole seconds.
    e_start = time.time()
    for sub_X,sub_Y in data.generate_data():
        ibatch+=1
        # Evaluate when explicitly requested, or when not training at all.
        if over_test or not train_me:
            t_losses = gm.test_on_batch(sub_X,sub_Y)
            logs = gm.get_logs( t_losses  ,val=True)
            gm.update_history( logs , histories)
            # Convert to plain floats so the record can be JSON-serialised.
            t_losses = [list(map(float,row)) for row in t_losses]
            thistory[e].append( t_losses )
        if train_me:
            losses = gm.train_on_batch(sub_X,sub_Y)
            logs = gm.get_logs( losses )
            gm.update_history( logs , histories)
            losses = [list(map(float,row)) for row in losses]
            history[e].append( losses )
        # NOTE(review): ibatch is incremented before this test and compared
        # with '>', so max_batch+1 batches are processed before stopping --
        # confirm whether '>=' was intended.
        if max_batch and ibatch>max_batch:
            break

    # The figure of merit is evaluated once per epoch, not per batch.
    if options.fom:
        fom = gm.figure_of_merit()
        print ("figure of merit",fom)
        fhistory[e].append( fom )
    # Checkpoint the weights and the histories at the end of every epoch.
    gm.generator.save_weights('simple_generator_%s.h5'%tag)
    gm.discriminator.save_weights('simple_discriminator_%s.h5'%tag)
    gm.combined.save_weights('simple_combined_%s.h5'%tag)
    dump()
    # Leave the epoch loop too once the batch limit has been hit; the
    # truncated epoch is not timed.
    if max_batch and ibatch>max_batch:
        break

    e_stop = time.time()
    print (e_stop - e_start,"[s] for epoch",e)
    etimes.append( e_stop - e_start)
# Final write so that the timing of the last completed epoch is recorded.
dump()
back to top