https://github.com/jacopobono/learning_cognitive_maps_code
Tip revision: d86b262545547353c7050bbc2d476c2f4a297989 authored by Jacopo Bono on 26 January 2023, 19:15:51 UTC
somewhat cleaner scripts
Tip revision: d86b262
supplementary_2d_environment.py
import numpy as np
import pylab as plt
from source.utils import parameters_linear_track, run_td_lambda_new_continuousTime, run_MC
import pickle
import datetime, os
from multiprocessing import Pool, cpu_count
from schema import Schema, Optional, And
from types import SimpleNamespace
from matplotlib import cm
from matplotlib.colors import ListedColormap
def fun_to_run(traj):
    """
    Run a single trial of the simulation (worker function for ``Pool.map``).

    Parameters
    ----------
    traj : tuple
        ``(trajectories, nr, ini, path, T_lists, mode, store, nr_states)`` where
        ``trajectories`` is the list of state sequences for this trial, ``nr`` the
        0-based trial index, ``ini`` the initial trial number (the stored file is
        numbered ``ini + nr``), ``path`` the output folder, ``T_lists`` the
        per-visit durations (used in TD mode), ``mode`` is ``'TD'`` or ``'MC'``,
        ``store`` whether to pickle the result and ``nr_states`` the number of
        states passed to ``parameters_linear_track``.

    Returns
    -------
    The object returned by ``run_td_lambda_new_continuousTime`` (TD) or
    ``run_MC`` (MC).

    Raises
    ------
    ValueError
        If ``mode`` is neither ``'TD'`` nor ``'MC'`` (previously this surfaced as
        an ``UnboundLocalError`` on ``M_stdp``).
    """
    trajectories, nr, ini, path, T_lists, mode, store, nr_states = traj
    if mode not in ('TD', 'MC'):
        raise ValueError(f"Unsupported mode {mode!r}; expected 'TD' or 'MC'.")

    params, gamma, eta, lambda_var = parameters_linear_track(nr_states)
    params['trajectories'] = trajectories
    params = update_params(mode, params, T_lists, gamma, eta)

    if mode == 'TD':
        M_stdp = run_td_lambda_new_continuousTime(**params)
    else:  # mode == 'MC'
        M_stdp = run_MC(**params)

    if store:
        # One pickle per trial, numbered with the absolute trial number.
        with open(os.path.join(path, f'{mode}_traj_{ini+nr}.pkl'), 'wb') as f:
            pickle.dump(M_stdp, f)
    return M_stdp
def update_params(mode, params, T_lists, gamma, eta):
    """
    Adapt the parameter dict from ``parameters_linear_track`` to the learning mode.

    ``params`` is modified in place and the same object is returned.

    Parameters
    ----------
    mode : str
        ``'MC'`` adds the spike-timing parameters; ``'TD'`` switches to
        online learning with per-visit durations. Any other value leaves
        ``params`` unchanged.
    params : dict
        Parameter dict; MC mode reads ``'tau_plus'`` and ``'A_plus'``,
        TD mode removes ``'T'``.
    T_lists : object
        Stored under ``'T_lists'`` in TD mode.
    gamma : float
        Discount factor, converted to the between-state delay in MC mode.
    eta : float
        Learning rate, rescaled into ``'eta_stdp'`` in MC mode.

    Returns
    -------
    dict
        The (mutated) ``params``.
    """
    if mode == 'MC':
        tau_plus = params['tau_plus']
        same_delay = 2
        # Chosen so that exp(-temporal_between / tau_plus) == gamma.
        between_delay = -np.log(gamma) * tau_plus
        params['temporal_same'] = same_delay
        params['temporal_between'] = between_delay
        # Chosen so that eta_stdp * A_plus * exp(-temporal_same / tau_plus) == eta.
        params['eta_stdp'] = eta / (params['A_plus'] * np.exp(-same_delay / tau_plus))
        params['spike_noise_prob'] = 0.2
        params['pre_offset'] = 5
    elif mode == 'TD':
        params.update(offline=False, T_lists=T_lists)
        params.pop('T')
    return params
def validate_conf(conf: dict) -> SimpleNamespace:
    """
    Validate the hyperparameters in the dictionary and fill in defaults.

    Parameters
    ----------
    conf : dict
        Dictionary containing hyperparameters to run the simulation. Recognised
        keys (all optional) and defaults:
        ``experiment_folder`` ("simulation"), ``initial_trial_nr`` (1),
        ``final_trial_nr`` (2), ``nr_epochs`` (2), ``cpu_percentage`` (0.9,
        a float in (0, 1]), ``multiprocessing`` (False), ``store_data`` (False),
        ``mode`` ("MC"; one of 'TD', 'MC', 'MX').

    Returns
    -------
    SimpleNamespace
        Namespace containing the validated hyperparameters.

    Raises
    ------
    schema.SchemaError
        If a value has the wrong type/range or an unknown key is present.
    ValueError
        If ``final_trial_nr`` is not larger than ``initial_trial_nr``.
    """
    valid_conf = Schema({
        Optional('experiment_folder', default="simulation"): str,
        Optional('initial_trial_nr', default=1): And(int, lambda s: s > 0),
        Optional('final_trial_nr', default=2): And(int, lambda s: s > 0),
        Optional('nr_epochs', default=2): And(int, lambda s: s > 0),
        # Optional('nr_states', default=10): And(int, lambda s: s>2),
        Optional('cpu_percentage', default=0.9): And(float, lambda s: s > 0,
                                                     lambda s: s <= 1),
        Optional('multiprocessing', default=False): bool,
        Optional('store_data', default=False): bool,
        Optional('mode', default="MC"): And(str, lambda s: s in ['TD', 'MC', 'MX']),
    }).validate(conf)
    # An explicit raise (not `assert`) so the check survives `python -O`.
    if valid_conf['final_trial_nr'] <= valid_conf['initial_trial_nr']:
        raise ValueError('ERROR: Number of final trial should be larger than number of initial trial!')
    return SimpleNamespace(**valid_conf)
def create_path(conf: SimpleNamespace):
    """
    Build a time-stamped experiment path and optionally create it on disk.

    The path has the form ``<mode>_runs/<experiment_folder>/YYYY_MM_DD_HH_MM_SS``.

    Parameters
    ----------
    conf : SimpleNamespace
        Validated configuration (as returned by ``validate_conf``); reads
        ``experiment_folder``, ``mode`` and ``store_data``.

    Returns
    -------
    newpath : str
        Path to the experiment folder. The folder is created (and a message
        printed) only when ``conf.store_data`` is true.

    Raises
    ------
    FileExistsError
        If ``store_data`` is set and the folder already exists (e.g. two runs
        started within the same second).
    """
    date_time = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
    newpath = f'{conf.mode}_runs/' + os.path.join(conf.experiment_folder, date_time)
    if conf.store_data:
        os.makedirs(newpath)
        print('Folder {} created!'.format(newpath))
    return newpath
def generate_2d_trajectories(nr_trajectories, rng=None):
    """
    Sample random-walk trajectories in the 2D environment.

    Each walk starts in state 16 and moves uniformly at random to one of the
    neighbouring states until it reaches the reward state 1. States are
    1-based in the adjacency table, but the returned sequences are 0-based
    (state ``s`` is stored as ``s - 1``). The reward state is included as the
    final element.

    Parameters
    ----------
    nr_trajectories : int
        Number of trajectories to sample.
    rng : numpy.random.Generator, optional
        Source of randomness, for reproducible runs. When omitted, the global
        ``np.random`` state is used, as in the original implementation.

    Returns
    -------
    list of list
        ``nr_trajectories`` 0-based state sequences, each starting at 15 and
        ending at 0.
    """
    init_state = 16
    reward_state = 1
    # 1-based adjacency of the grid; state 11 has no entry and is never visited.
    next_states = {1: [17],
                   2: [1, 3, 6],
                   3: [2, 4, 7],
                   4: [3, 8],
                   5: [1, 6, 9],
                   6: [2, 5, 7, 10],
                   7: [3, 6, 8],
                   8: [4, 7, 12],
                   9: [5, 10, 13],
                   10: [6, 9, 14],
                   12: [8, 16],
                   13: [9, 14],
                   14: [10, 13, 15],
                   15: [14, 16],
                   16: [12, 15],
                   }
    choose = np.random.choice if rng is None else rng.choice

    all_trajectories = []
    for _ in range(nr_trajectories):
        traj = []
        current_state = init_state
        while current_state != reward_state:
            traj.append(current_state - 1)
            current_state = choose(next_states[current_state])
        traj.append(reward_state - 1)
        all_trajectories.append(traj)
    return all_trajectories
def _run_trials(conf, traj, T_list, newpath, nr_states):
    """Run ``fun_to_run`` once per trial (in a process pool if requested) and return the outputs."""
    # One argument tuple per trial; each trial gets its OWN duration list T_list[idx]
    # (the serial path previously passed the whole, ragged list of all trials).
    jobs = [(tt, idx, conf.initial_trial_nr, newpath, T_list[idx], conf.mode,
             conf.store_data, nr_states)
            for idx, tt in enumerate(traj)]
    if conf.multiprocessing:
        # int() truncates, so guard against a zero-sized pool on machines with few CPUs.
        cc = max(1, int(conf.cpu_percentage * cpu_count()))
        print(f'Using max {cc} cpus...')
        # fun_to_run is a module-level function so it can be pickled: DON'T USE A LAMBDA HERE.
        with Pool(cc) as pool:
            return pool.map(fun_to_run, jobs)
    return [fun_to_run(job) for job in jobs]


def _two_tone_colormap():
    """Return the grey-then-magma colormap used for every SR heat map."""
    # plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
    top = plt.get_cmap('Greys', 128)
    bottom = plt.get_cmap('magma', 128 * 2)
    newcolors = np.vstack((top(np.linspace(0, 1, 128)),
                           bottom(np.linspace(0, 1, 128 * 2))[:128]))
    return ListedColormap(newcolors, name='OrangeBlue')


def _plot_sr_matrix(matrix, cmap, title, nr_ticks, fs=16):
    """Draw ``matrix`` as a heat map with 1-based state tick labels and return the figure."""
    fig = plt.figure()
    plt.imshow(matrix, cmap=cmap, vmin=0, vmax=2)
    cbar = plt.colorbar()
    cbar.ax.tick_params(labelsize=fs)
    tcks = list(range(nr_ticks))
    labels = [x + 1 for x in tcks]
    plt.yticks(ticks=tcks, labels=labels, fontsize=fs)
    plt.xticks(ticks=tcks, labels=labels, fontsize=fs)
    plt.xlabel('Future state (CA1)', fontsize=fs)
    plt.ylabel('Current state (CA3)', fontsize=fs)
    plt.title(title, fontsize=fs)
    return fig


def _transition_matrix_2d():
    """Return the 17x17 transition matrix of the 2D random walk.

    Row i holds the probabilities of moving out of state i+1. States with
    neighbours move uniformly among them. States 1 and 11 and the final
    state 17 all move to state 17.
    """
    return np.array([[0]*16+[1],
                     [1/3, 0, 1/3, 0, 0, 1/3] + [0]*11,
                     [0, 1/3, 0, 1/3, 0, 0, 1/3] + [0]*10,
                     [0, 0, 1/2, 0, 0, 0, 0, 1/2] + [0]*9,
                     [1/3] + [0]*4 + [1/3, 0, 0, 1/3] + [0]*8,
                     [0, 1/4] + [0]*2 + [1/4, 0, 1/4, 0, 0, 1/4] + [0]*7,
                     [0, 0, 1/3, 0, 0, 1/3, 0, 1/3] + [0]*9,
                     [0, 0, 0, 1/3, 0, 0, 1/3, 0, 0, 0, 0, 1/3] + [0]*5,
                     [0, 0, 0, 0, 1/3, 0, 0, 0, 0, 1/3, 0, 0, 1/3] + [0]*4,
                     [0, 0, 0, 0, 0, 1/3, 0, 0, 1/3, 0, 0, 0, 0, 1/3] + [0]*3,
                     [0]*16+[1],
                     [0]*7+[1/2]+[0]*7+[1/2]+[0],
                     [0]*8+[1/2]+[0]*4+[1/2]+[0]*3,
                     [0]*9+[1/3]+[0]*2+[1/3]+[0]*1+[1/3]+[0]*2,
                     [0]*13+[1/2]+[0]*1+[1/2]+[0],
                     [0]*11+[1/2]+[0]*2+[1/2]+[0]*2,
                     [0]*16+[1],
                     ])


def run_simulation(conf):
    """
    Run the simulation in the 2D environment, then plot and optionally store the results.

    Generates ``final_trial_nr - initial_trial_nr`` trials. Each trial has
    ``nr_epochs`` random trajectories from ``generate_2d_trajectories`` and
    spends 100 time units per visited state. Every trial is learned with
    ``fun_to_run`` in TD or MC mode. The trial-averaged matrix at index -2 of
    the epoch axis, after the identity matrix has been prepended, is plotted
    next to the ground truth ``(I - gamma*T)^-1`` restricted to the first 16
    states.

    Parameters
    ----------
    conf : dict
        Raw configuration; see ``validate_conf`` for the keys and defaults.

    Raises
    ------
    ValueError
        If ``mode`` is neither 'TD' nor 'MC' ('MX' passes ``validate_conf``
        but is not implemented by ``fun_to_run``).
    """
    conf = validate_conf(conf)
    if conf.mode not in ('TD', 'MC'):
        raise ValueError(f"mode {conf.mode!r} is not supported by run_simulation; use 'TD' or 'MC'.")

    # Define total number of trials, trajectories and time per state.
    tot_trials = conf.final_trial_nr - conf.initial_trial_nr
    nr_states = 17
    traj = [generate_2d_trajectories(conf.nr_epochs) for _ in range(tot_trials)]
    T_list = [[[100 for _ in y] for y in x] for x in traj]

    newpath = create_path(conf)
    output = _run_trials(conf, traj, T_list, newpath, nr_states)

    params, gamma_var, _, _ = parameters_linear_track(nr_states)
    if conf.mode == 'TD':
        # Cumulative trajectory lengths: index of the end of each epoch in a trial's output.
        cum_lengths = [np.cumsum([len(x) for x in t]) for t in traj]
        # Index the per-trial outputs directly; stacking them first would fail if their
        # lengths differ between trials.
        w_stdps = np.array([[out[l_i] for l_i in cum] for out, cum in zip(output, cum_lengths)])
        # Presumably (trial, epoch, pre-state, post-state, N_pre, ...) — TODO confirm axis meaning.
        w_stdps2 = np.mean(np.sum(w_stdps, axis=-2) / params['N_pre'], axis=-1)
    else:  # conf.mode == 'MC'
        w_stdps2 = np.array(output)

    #############################
    # STORING AND PLOTTING
    #############################
    if conf.store_data:
        with open(os.path.join(newpath, f'rebuttal_2d_env_{conf.mode}_w_stdps2.pkl'), 'wb') as f:
            pickle.dump(w_stdps2, f)

    # Prepend the identity matrix as each trial's initial weight matrix.
    init = np.tile(np.expand_dims(np.expand_dims(np.eye(nr_states), 0), 0), [len(traj), 1, 1, 1])
    w_stdps2 = np.concatenate([init, w_stdps2], axis=1)

    newcmp = _two_tone_colormap()
    fig3 = _plot_sr_matrix(w_stdps2[:, -2, :-1, :-1].mean(axis=0), newcmp,
                           f'{conf.mode}', nr_states - 1)
    if conf.store_data:
        fig3.savefig(os.path.join(newpath, f'fig_rebuttal_2d_env_{conf.mode}.eps'),
                     bbox_inches='tight', format='eps')

    # Ground truth restricted to the first nr_states-1 states (the final state 17 is dropped).
    T_mat = _transition_matrix_2d()
    M_theor = np.linalg.inv(np.eye(nr_states - 1) - gamma_var * T_mat[:-1, :-1])
    fig4 = _plot_sr_matrix(M_theor, newcmp, 'Ground Truth', nr_states - 1)
    if conf.store_data:
        fig4.savefig(os.path.join(newpath, 'theory.eps'), bbox_inches='tight', format='eps')
if __name__ == '__main__':
    # Configuration of the supplementary 2D-environment experiment.
    conf_supp_2d_env = dict(
        experiment_folder="conf_supp_2d_env",  # sub-folder under <mode>_runs/ for results
        initial_trial_nr=1,       # first trial number
        final_trial_nr=25,        # end of the trial range (exclusive): 25 for MC, 5 for TD
        nr_epochs=20,             # trajectories (epochs) per trial
        cpu_percentage=0.7,       # fraction of CPUs used by the process pool
        multiprocessing=True,     # run trials in parallel
        store_data=True,          # write pickles and figures to disk
        mode="MC",                # 'MC' or 'TD' ('MX' passes validation but run_simulation does not implement it)
    )
    run_simulation(conf_supp_2d_env)