# Source: https://github.com/jacopobono/learning_cognitive_maps_code
# Tip revision: d86b262545547353c7050bbc2d476c2f4a297989 authored by Jacopo Bono on 26 January 2023, 19:15:51 UTC
# Commit message: somewhat cleaner scripts
# File: supplementary_random_initial_weights.py
import numpy as np
import pylab as plt
from source.utils import parameters_linear_track, run_td_lambda_new_continuousTime, run_MC
import pickle
import datetime, os
from multiprocessing import Pool, cpu_count
from schema import Schema, Optional, And
from types import SimpleNamespace
from matplotlib import cm
from matplotlib.colors import ListedColormap
def fun_to_run(traj, nr_states=4):
    """
    Run one trial in the requested mode and optionally pickle the result.

    Parameters
    ----------
    traj : tuple
        Packed job description
        ``(traj, nr, ini, path, T_lists, mode, store)``: the trajectories of
        the trial, the trial index, the initial trial number, the output
        folder, the per-state times, the mode ('TD' or 'MC') and whether the
        result should be written to disk.
    nr_states : int
        Number of states forwarded to ``parameters_linear_track``.

    Returns
    -------
    The object returned by ``run_td_lambda_new_continuousTime`` (mode 'TD')
    or by ``run_MC`` (mode 'MC').

    Raises
    ------
    ValueError
        If ``mode`` is neither 'TD' nor 'MC'.
    """
    traj, nr, ini, path, T_lists, mode, store = traj

    # Reject unknown modes before doing any work; otherwise the result
    # variable would never be bound and the failure would be an obscure
    # UnboundLocalError at the very end.
    if mode not in ('TD', 'MC'):
        raise ValueError(f"Unsupported mode {mode!r}; expected 'TD' or 'MC'.")

    params, gamma, eta, lambda_var = parameters_linear_track(nr_states)
    params['trajectories'] = traj
    params = update_params(mode, params, T_lists, gamma, eta)

    if mode == 'TD':
        M_stdp = run_td_lambda_new_continuousTime(**params)
    else:
        M_stdp = run_MC(**params)

    if store:
        # File name is indexed by the absolute trial number (ini + nr).
        with open(os.path.join(path, f'{mode}_traj_{ini+nr}.pkl'), 'wb') as f:
            pickle.dump(M_stdp, f)
    return M_stdp
def update_params(mode, params, T_lists, gamma, eta):
    """
    Adapt the parameter dictionary to the selected mode.

    The dictionary is modified in place and also returned.

    Parameters
    ----------
    mode : str
        'TD' or 'MC'; any other value only sets ``params['ini']``.
    params : dict
        Parameter dictionary to update. Must contain 'T' for mode 'TD' and
        'tau_plus' and 'A_plus' for mode 'MC'.
    T_lists
        Stored under ``params['T_lists']`` in mode 'TD'.
    gamma, eta
        Used in mode 'MC' to derive 'temporal_between' and 'eta_stdp'.

    Returns
    -------
    dict
        The same ``params`` object that was passed in.
    """
    params['ini'] = 'random'

    if mode == 'TD':
        params.update(offline=False, T_lists=T_lists)
        # 'T' is replaced by 'T_lists'; a missing key raises KeyError.
        params.pop('T')
    elif mode == 'MC':
        same_gap = 2
        between_gap = -np.log(gamma)*params['tau_plus']
        params['temporal_same'] = same_gap
        params['temporal_between'] = between_gap
        # eta is divided by A_plus * exp(-temporal_same / tau_plus).
        stdp_gain = params['A_plus']*np.exp(-same_gap/params['tau_plus'])
        params['eta_stdp'] = eta/stdp_gain
        params['spike_noise_prob'] = 0.2
        params['pre_offset'] = 5

    return params
def validate_conf(conf: dict) -> SimpleNamespace:
    """
    Validate the hyperparameters in the dictionary.

    Missing keys are filled with the defaults declared in the schema below.

    Parameters
    ----------
    conf : dict
        Dictionary containing hyperparameters to run the simulation.

    Returns
    -------
    SimpleNamespace
        Namespace containing the validated hyperparameters.

    Raises
    ------
    schema.SchemaError
        If a value has the wrong type or is outside its allowed range.
    ValueError
        If ``final_trial_nr`` is not strictly larger than
        ``initial_trial_nr``.
    """
    valid_conf = Schema({
        Optional('state_3_T', default=100): And(int, lambda s: s in [100, 200]),
        Optional('experiment_folder', default="simulation"): str,
        Optional('initial_trial_nr', default=1): And(int, lambda s: s > 0),
        Optional('final_trial_nr', default=2): And(int, lambda s: s > 0),
        Optional('nr_epochs', default=2): And(int, lambda s: s > 0),
        Optional('cpu_percentage', default=0.9): And(float, lambda s: s > 0,
                                                     lambda s: s <= 1),
        Optional('multiprocessing', default=False): bool,
        Optional('store_data', default=False): bool,
        Optional('mode', default="MC"): And(str, lambda s: s in ['TD', 'MC']),
    }).validate(conf)

    # Explicit raise instead of `assert`: assertions are stripped when Python
    # runs with -O, which would silently disable this check.
    if valid_conf['final_trial_nr'] <= valid_conf['initial_trial_nr']:
        raise ValueError('ERROR: Number of final trial should be larger than number of initial trial!')

    return SimpleNamespace(**valid_conf)
def create_path(conf: SimpleNamespace) -> str:
    """
    Build the experiment path and create the folder when data is stored.

    Parameters
    ----------
    conf : SimpleNamespace
        Configuration object; the attributes ``mode``, ``experiment_folder``
        and ``store_data`` are read.

    Returns
    -------
    newpath : str
        Path of the form ``<mode>_runs/<experiment_folder>/<timestamp>``,
        where the timestamp is the current local time formatted as
        ``YYYY_MM_DD_HH_MM_SS``.
    """
    # Timestamp with second resolution, safe to use as a folder name.
    date_time = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
    newpath = os.path.join(conf.experiment_folder, date_time)
    newpath = f'{conf.mode}_runs/' + newpath

    # Create storage folder only when results are going to be written.
    if conf.store_data:
        os.makedirs(newpath)
        print('Folder {} created!'.format(newpath))

    return newpath
def run_simulation(conf):
    """
    Run a simulation on the linear track, then store and plot the results.

    Parameters
    ----------
    conf : dict
        Raw configuration dictionary. It is validated with
        ``validate_conf`` before any value is used.
    """
    # Imported locally: only needed to build a picklable worker for the pool.
    from functools import partial

    absorbing_state = False

    # validate configuration
    conf = validate_conf(conf)

    # Define total number of trials
    tot_trials = conf.final_trial_nr - conf.initial_trial_nr

    # Define trajectory and time per state
    if absorbing_state:
        nr_states = 5
        # the last state is visited twice in every epoch
        epoch = list(range(nr_states)) + [nr_states-1]
    else:
        nr_states = 4
        epoch = list(range(nr_states))
    traj = np.repeat([np.repeat([epoch], conf.nr_epochs, axis=0)], tot_trials, axis=0)
    T_list = [[100]*len(epoch)]*conf.nr_epochs

    # Define function to run and storage path
    newpath = create_path(conf)
    # A partial of a module-level function can be pickled and sent to worker
    # processes; a lambda cannot, which made the parallel path fail.
    worker = partial(fun_to_run, nr_states=nr_states)
    jobs = [(tt, idx, conf.initial_trial_nr, newpath, np.array(T_list),
             conf.mode, conf.store_data)
            for idx, tt in enumerate(traj)]

    # Run in parallel if set
    if conf.multiprocessing:
        # Always request at least one process (the product can truncate to 0).
        cc = max(1, int(conf.cpu_percentage*cpu_count()))
        print(f'Using max {cc} cpus...')
        with Pool(cc) as pool:
            output = pool.map(worker, jobs)
    else:
        output = [worker(job) for job in jobs]
    outputs = np.array(output)

    # cumulative lengths of the epochs of every trial
    cum_lengths = [np.cumsum([len(x) for x in t]) for t in traj]

    params, _, _, _ = parameters_linear_track(nr_states)

    if conf.mode == 'TD':
        # Pick, for every trial, the entries at the cumulative epoch lengths.
        w_stdps = np.array([[trial_out[k] for k in ks]
                            for trial_out, ks in zip(outputs, cum_lengths)])
        # Sum over the second-to-last axis, normalise by N_pre, then average
        # over the last axis.
        w_stdps2 = np.mean(np.sum(w_stdps, axis=-2)/params['N_pre'], axis=-1)
    elif conf.mode == 'MC':
        w_stdps2 = outputs

    #############################
    # STORING AND PLOTTING
    #############################

    if conf.store_data:
        with open(newpath + '/rebuttal_random_init_w_stdps2.pkl', 'wb') as f:
            pickle.dump(w_stdps2, f)

    fs = 16

    # Colormap: 128 'Greys' entries followed by the first 128 of 256 'magma'
    # entries. plt.get_cmap is used because matplotlib.cm.get_cmap was
    # removed in matplotlib 3.9.
    top = plt.get_cmap('Greys', 128)
    bottom = plt.get_cmap('magma', 128*2)
    newcolors = np.vstack((top(np.linspace(0, 1, 128)),
                           bottom(np.linspace(0, 1, 128*2))[:128]))
    newcmp = ListedColormap(newcolors, name='OrangeBlue')

    # Tick positions/labels follow the number of states (1-based labels).
    tick_pos = list(range(nr_states))
    tick_lab = list(range(1, nr_states+1))

    fig3 = plt.figure()
    # Average over the first axis of the second-to-last entry on axis 1.
    plt.imshow(w_stdps2[:, -2, :, :].mean(axis=0), cmap=newcmp, vmin=0, vmax=2)
    cbar = plt.colorbar()
    cbar.ax.tick_params(labelsize=fs)
    plt.yticks(ticks=tick_pos, labels=tick_lab, fontsize=fs)
    plt.xticks(ticks=tick_pos, labels=tick_lab, fontsize=fs)
    plt.xlabel('Future state (CA1)', fontsize=fs)
    plt.ylabel('Current state (CA3)', fontsize=fs)
    if conf.store_data:
        fig3.savefig(newpath+'/fig_sr.eps', bbox_inches='tight', format='eps')

    pre_state = 2
    post_state = 3
    fig4 = plt.figure()
    # Entry (pre_state, post_state) along axis 1, averaged over the first axis.
    plt.plot(w_stdps2[:, :, pre_state, post_state].mean(axis=0))
    plt.yticks(fontsize=fs)
    plt.xticks(fontsize=fs)
    plt.xlabel('Time', fontsize=fs)
    plt.ylabel('Weight', fontsize=fs)
    if conf.store_data:
        fig4.savefig(newpath+f'/fig_{pre_state}_{post_state}.eps', bbox_inches='tight', format='eps')
if __name__ == '__main__':
    # Configuration of the supplementary experiment with random initial weights.
    conf_supp_RandomInit = dict(
        state_3_T=100,                        # time in state 3
        experiment_folder="supp_RandomInit",  # folder to store
        initial_trial_nr=1,                   # initial trial number
        final_trial_nr=2,                     # final trial number
        nr_epochs=300,                        # number of epochs per trial
        cpu_percentage=0.9,                   # percentage of cpus to use when running in parallel
        multiprocessing=False,                # whether to run in parallel using multiprocessing
        store_data=False,                     # whether to store the experiment data
        mode="MC",                            # mode, can be MC, TD
    )

    run_simulation(conf_supp_RandomInit)