https://github.com/mackelab/IdentifyMechanisticModels_2020
Raw File
Tip revision: b93c90ec6156ae5f8afee6aaac7317373e9caf5e authored by Jan-Matthis on 19 October 2020, 16:00:39 UTC
Initial commit
Tip revision: b93c90e
00_neural_net_train.py
import numpy as np
import time
import sys
sys.path.append("model/setup")
sys.path.append("model/simulator")
sys.path.append("model/inference")
sys.path.append("model/visualization")
sys.path.append("model/utils")

import netio
import train_utils as tu
import dill as pickle
import delfi.distribution as dd
from delfi.generator import Default

params_filename = "train_net"
params = netio.load_setup(params_filename)
mlist = netio.create_simulators(params)
dimensions = np.sum(params.use_membrane) + 7


################################################################################
#                                   Load data                                  #
################################################################################

filedir = "results/31D_samples/pyloricsamples_31D_noNaN_3.npz"
pilot_data, _, params_mean, params_std = tu.load_trn_data_normalize(filedir, params)
print('We use', len(pilot_data[0]), 'training samples.')


################################################################################
#                                  Define prior                                #
################################################################################

# create standard prior for the actual conductances
prior = netio.create_prior(params, log=True)
prior_norm = dd.Uniform(lower=-np.sqrt(3)*np.ones(dimensions), upper=np.sqrt(3)*np.ones(dimensions))

################################################################################
#                                  Load pairs                                  #
################################################################################

xo_stats = np.load('results/31D_experimental/190807_summstats_prep845_082_0044.npz')['summ_stats']
obs = xo_stats


################################################################################
#                               Define ss and g                                #
################################################################################

# Using the 'PrinzStats' here, defined in summstats.py. Inferring the summstats
# only between t_burnin and t_burnin+t_window. There are 15 stats.
s = netio.create_summstats(params)

# create the generator object
g = Default(model=mlist[0], prior=prior_norm, summary=s)


################################################################################
#                                Run inference                                 #
################################################################################

from delfi.inference import APT

# create inference object
inf_SNPE = APT(generator=g,
               prior_norm=False,
               n_hiddens=params.n_hiddens,
               pilot_samples=pilot_data,
               obs=obs,
               verbose=False,
               density=params.density_type,
               seed=1)
print('---- Successfully created a SNPE-C object ----')

start_time = time.time()
log, train_data, posteriors = inf_SNPE.run(proposal='prior',
                                           n_train=params.n_train,
                                           n_rounds=params.n_rounds,
                                           epochs=params.n_epochs,
                                           silent_fail=False,
                                           train_on_all=True,
                                           reuse_prior_samples=True,
                                           val_frac=0.1,
                                           verbose=True)
print("Estimated time", time.time()-start_time, "seconds")


################################################################################
#                                  Save data                                   #
################################################################################

inf_SNPE.trn_datasets = []
# with open('results/31D_nets/191001_seed1_Exper11deg.pkl', 'wb') as file:
#     pickle.dump([inf_SNPE, log, params], file)
print('Success')
back to top