# Revision 2438e688ad719eb9870af8c032803a7367fe1140 authored by gbondanelli on 22 March 2021, 15:11:45 UTC, committed by GitHub on 22 March 2021, 15:11:45 UTC
# 1 parent 66019b1
# fit_recurrent_model.py
from myimports import*
# --- Data -------------------------------------------------------------------
# Response array loaded from disk; its third axis is indexed by stimulus
# (see `stim` below). Presumably trial-averaged OFF responses, going by the
# file name -- TODO confirm axis meaning of the first two dimensions.
path = './data'
Data = load(path+'/off_responses_trialavg.npy')

# --- Time base --------------------------------------------------------------
# NOTE(review): DT looks like the sampling interval of the recording (in
# seconds) and 108 / 33 like frame counts -- confirm against the dataset.
DT = 0.031731
total_time = 108 * DT
time_stim_presentation = 33 * DT
nsteps = 1000
dt = total_time / nsteps  # time step handed to the fit (last entry of par1)

## %% fit
# --- Fit configuration ------------------------------------------------------
stim = [1, 5]          # stimulus indices into the third axis of Data
dims = 100             # dimensionality passed to fit_dynamical_system
nchunks = 20
nrandsampl = 20
cv = False
n_free_params = dims * dims
constraint1 = 'none'
l = 1                  # NOTE(review): presumably a regularisation weight -- confirm
rank = 'none'
n_free_params1 = n_free_params  # same value; both names kept for compatibility
# Positional parameter list consumed by fit_dynamical_system; order matters.
par1 = ['random', nchunks, nrandsampl, l, n_free_params1, constraint1, rank, dt]

# --- Simulation / plotting --------------------------------------------------
ntrials = 100                  # number of simulated trajectories per stimulus
col = ['#F14377', '#569FD1']   # one plot colour per entry of `stim`
# For each selected stimulus: fit the linear dynamical system, simulate
# `ntrials` trajectories from noisy initial conditions with the fitted
# matrix, and plot simulated vs. measured trajectories projected on the
# first two principal components of the measured trajectory.
axis_labels = (
    ('PC1 - OFF (8kHz)', 'PC2 - OFF (8kHz)'),
    ('PC1 - OFF (WN)', 'PC2 - OFF (WN)'),
)
for i, color, (x_label, y_label) in zip(stim, col, axis_labels):
    figure(figsize=(2, 2))

    # --- Fit ---------------------------------------------------------------
    X = expand_dims(Data[:, :, i], axis=2)
    _, MT, E = fit_dynamical_system(X, dims, cv, par1)
    # Only entries 2 and 5 of E are needed here.
    # NOTE(review): their exact meaning is defined by fit_dynamical_system;
    # E[2] is used only for its number of time points, E[5] as the reference
    # trajectory (rows = dimensions, columns = time) -- confirm.
    X_PC1 = E[2]
    X_PC2 = E[5]
    n_time = X_PC1.shape[1]

    # --- Simulate ----------------------------------------------------------
    # Initial conditions: Gaussian cloud around the first point of X_PC2.
    X_PC_P = empty((dims, ntrials, n_time))
    x0 = random.multivariate_normal(X_PC2[:, 0], .0005 * identity(dims), ntrials).T
    X_PC_P[:, :, 0] = x0
    # Propagate every initial condition with the matrix exponential of the
    # fitted connectivity: x(t_k) = expm(dt * k * MT.T) @ x(0).
    for step in range(1, n_time):
        X_PC_P[:, :, step] = dot(expm(dt * step * MT.T), x0)

    # --- Axes --------------------------------------------------------------
    ax = subplot(111)
    gca().set_aspect('equal', adjustable='box')
    ax.spines['left'].set_bounds(-.5, .5)
    ax.spines['bottom'].set_bounds(-.5, .5)
    ax.spines['bottom'].set_position(('axes', -0.1))
    ax.spines['left'].set_position(('axes', -0.))

    # --- Projection basis --------------------------------------------------
    d1, V1 = PCA(X_PC2)
    V1[:, 1] = -V1[:, 1]  # flip the sign of the second component for display
    pc1, pc2 = V1[:, 0], V1[:, 1]

    # --- Plot --------------------------------------------------------------
    # Individual simulated trajectories (light grey).
    for i_trial in range(ntrials):
        trial = X_PC_P[:, i_trial, :]
        ln, = ax.plot(dot(pc1, trial), dot(pc2, trial), lw=2, color='#D1D9DC')
        ln.set_solid_capstyle('round')
    # Mean over simulated trajectories (dashed).
    X_PC_Pmean = mean(X_PC_P, axis=1)
    plot(dot(pc1, X_PC_Pmean), dot(pc2, X_PC_Pmean), '--', lw=1, color='#5C7C99')
    # Reference trajectory and its starting point, in the stimulus colour.
    plot(dot(pc1, X_PC2), dot(pc2, X_PC2), lw=1.5, color=color)
    plot(dot(pc1, X_PC2[:, 0]), dot(pc2, X_PC2[:, 0]), '.', markersize=7, color=color)

    xticks([-.5, 0, .5])
    yticks([-.5, 0, .5])
    xlabel(x_label)
    ylabel(y_label)
    # NOTE(review): the original indentation was lost; tight_layout() is
    # assumed to apply to each figure (i.e. to sit inside the loop) -- confirm.
    tight_layout()
# Computing file changes ...