Revision 86d380144b3f85c8951923de873893583bd25edf authored by narendramukherjee on 01 October 2020, 17:54:51 UTC, committed by GitHub on 01 October 2020, 17:54:51 UTC
Change marking of trial start time
variational_HMM_implement.py
# Import stuff!
import multiprocessing as mp
import os
import pickle
import shutil
import sys

import easygui
import numpy as np
import pylab as plt
import tables

# Import PyHMM
sys.path.append('/home/narendra/Desktop/PyHMM/PyHMM')
import DiscreteHMM as dhmm
import variationalHMM as vhmm
from hinton import hinton
# Read blech.dir: a list of its lines (newlines kept); the first line is the
# data directory this script changes into further down
with open('blech.dir', 'r') as blech_dir_file:
    dir_name = list(blech_dir_file)
#---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
def implement_categorical(data = None, restarts = None, num_states = None, num_emissions = None, n_cpu = None, max_iter = None, threshold = None):
    """Run `restarts` independent MAP + variational categorical HMM fits in parallel.

    Each restart calls run_categorical with its restart index as the random
    seed, in a pool of `n_cpu` worker processes.

    Returns:
        A list with one (model_MAP, model_VI) tuple per restart, ordered by
        restart index.
    """
    # The pool is used as a context manager so its worker processes are torn
    # down after every call (previously a new pool was leaked on each call).
    # Pool.__exit__ calls terminate(), which is safe here because every result
    # has already been collected with .get() inside the with-block.
    with mp.Pool(processes = n_cpu) as pool:
        pending = [pool.apply_async(run_categorical, args = (data, num_states, num_emissions, restart, max_iter, threshold))
                   for restart in range(restarts)]
        return [job.get() for job in pending]
def run_categorical(data = None, num_states = None, num_emissions = None, restart = None, max_iter = None, threshold = None):
    """Fit one restart: a MAP categorical HMM, then a variational HMM seeded by it.

    Args:
        data: binned emission sequences, passed unchanged to both models' fit().
        num_states, num_emissions: model dimensions.
        restart: restart index, used as the NumPy random seed for this restart.
        max_iter, threshold: fitting limits forwarded to both models.

    Returns:
        (model_MAP, model_VI): the fitted DiscreteHMM and variationalHMM models.
    """
    np.random.seed(restart)
    draw = np.random.random

    map_model = dhmm.CategoricalHMM(num_states = num_states, num_emissions = num_emissions, max_iter = max_iter, threshold = threshold)
    # Random initial parameters and pseudocounts for the MAP fit. They are drawn
    # in exactly this order (after constructing the model) so that a given seed
    # always yields the same initialisation - do not reorder.
    transition_init = draw((num_states, num_states))
    emission_init = draw((num_states, num_emissions))
    start_init = draw(num_states)
    transition_prior = draw((num_states, num_states))
    emission_prior = draw((num_states, num_emissions))
    start_prior = draw(num_states)
    map_model.fit(data = data, p_transitions = transition_init, p_emissions = emission_init, p_start = start_init,
                  transition_pseudocounts = transition_prior, emission_pseudocounts = emission_prior,
                  start_pseudocounts = start_prior, verbose = False)

    # The variational model starts from counts proportional to the MAP estimates.
    # NOTE(review): start counts are scaled by 8 while transition/emission counts
    # are scaled by 80 - confirm this difference is intentional.
    vb_model = vhmm.CategoricalHMM(num_states = num_states, num_emissions = num_emissions, max_iter = max_iter, threshold = threshold)
    vb_model.fit(data = data, transition_hyperprior = 1, emission_hyperprior = 1, start_hyperprior = 1,
                 initial_emission_counts = 80 * map_model.p_emissions,
                 initial_transition_counts = 80 * map_model.p_transitions,
                 initial_start_counts = 8 * map_model.p_start, verbose = False)
    return map_model, vb_model
#---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Number of worker processes used for the HMM restarts. On an SGE cluster use
# int(os.getenv('NSLOTS')) instead, or take it from the command line (sys.argv[1]).
n_cpu = mp.cpu_count()

# Change to the data directory named on the first line of blech.dir. strip()
# drops the trailing newline; the old [:-1] slice cut off a real character
# whenever that line had no newline.
os.chdir(dir_name[0].strip())

# Find the .h5 file and the .hmm_params / .hmm_units files in the data directory
# (if several files match a suffix, the last one in listing order is used)
file_list = os.listdir('./')
hdf5_name = ''
params_file = ''
units_file = ''
for files in file_list:
    if files.endswith('h5'):
        hdf5_name = files
    if files.endswith('hmm_params'):
        params_file = files
    if files.endswith('hmm_units'):
        units_file = files
for suffix, found in (('h5', hdf5_name), ('hmm_params', params_file), ('hmm_units', units_file)):
    if not found:
        raise FileNotFoundError('No *%s file found in %s' % (suffix, os.getcwd()))

# Read the .hmm_params file: one value per line, in the order unpacked below
with open(params_file, 'r') as f:
    params = f.readlines()
min_states = int(params[0])
max_states = int(params[1])
max_iterations = int(params[2])
threshold = float(params[3])
seeds = int(params[4])
taste = int(params[5])
pre_stim = int(params[6])
bin_size = int(params[7])
pre_stim_hmm = int(params[8])
post_stim_hmm = int(params[9])
# The analysis window starts pre_stim_hmm samples before index pre_stim; a larger
# value would make the slice start negative and silently select the wrong samples
if pre_stim_hmm > pre_stim:
    raise ValueError('pre_stim_hmm (%i) must not exceed pre_stim (%i)' % (pre_stim_hmm, pre_stim))

# Read the chosen units: one integer unit index per line (blank lines ignored)
with open(units_file, 'r') as f:
    chosen_units = np.array([int(line) for line in f if line.strip()], dtype = int)

# Open up hdf5 file and get the node for the required taste/input
hf5 = tables.open_file(hdf5_name, 'r+')
dig_in = hf5.get_node('/spike_trains/dig_in_%i' % taste)
# Slice out the chosen units and the window [pre_stim - pre_stim_hmm, pre_stim + post_stim_hmm)
spikes = dig_in.spike_array[:]
spikes = spikes[:, chosen_units, pre_stim - pre_stim_hmm:pre_stim + post_stim_hmm]

# Bin the spikes: a bin's symbol is 0 if none of the chosen units fired in it,
# otherwise 1 + the index of one of the units that fired, picked at random.
# Only complete bins are used, so binned_spikes and time always have the same
# length (a trailing partial bin used to index past the end of binned_spikes).
n_bins = spikes.shape[2] // bin_size
binned_spikes = np.zeros((spikes.shape[0], n_bins))
# Start time of each bin relative to the stimulus (sample index pre_stim)
time = [b * bin_size - pre_stim_hmm for b in range(n_bins)]
for i in range(spikes.shape[0]):
    for b in range(n_bins):
        k = b * bin_size
        n_firing_units = np.where(np.sum(spikes[i, :, k:k + bin_size], axis = 1) > 0)[0]
        if n_firing_units.size:
            n_firing_units = n_firing_units + 1
        else:
            n_firing_units = [0]
        binned_spikes[i, b] = np.random.choice(n_firing_units)

# Split the trials of this taste by laser duration (> 0: laser on, == 0: laser off)
laser_durations = dig_in.laser_durations[:]
on_trials = np.where(laser_durations > 0.0)[0]
off_trials = np.where(laser_durations == 0.0)[0]

# Recreate the categorical_vb_hmm_results group under /spike_trains/dig_in_(taste)/
results_path = '/spike_trains/dig_in_%i/categorical_vb_hmm_results' % taste
if results_path in hf5:
    hf5.remove_node(results_path, recursive = True)
hf5.create_group('/spike_trains/dig_in_%i' % taste, 'categorical_vb_hmm_results')
hf5.flush()

# Recreate the folder for plots of the categorical HMM analysis of this taste
categorical_plot_dir = 'variational_HMM_plots/dig_in_%i/Categorical' % taste
shutil.rmtree(categorical_plot_dir, ignore_errors = True)
os.makedirs(categorical_plot_dir)
# Running the laser off and laser on trials-----------------------------------------------------------------------------------------------------------------------
# Colors of the raster ticks, keyed by the unit-type fields of /unit_descriptor
raster_colors = {'regular_spiking': 'red', 'fast_spiking': 'blue', 'multi_unit': 'black'}


def _select_best_seeds(hmm_results):
    """Keep, for each number of states, the converged restart with the highest final ELBO.

    hmm_results: one list per number of states, each holding the
        (model_MAP, model_VI) tuples returned by implement_categorical.
    Returns a list of (model_MAP, model_VI) tuples. A number of states for
    which no variational restart converged is dropped.
    """
    cleaned_results = []
    for restarts in hmm_results:
        # Convergence and ELBO are judged on the variational model (index 1)
        converged = [seed for seed in restarts if seed[1].converged]
        if not converged:
            continue
        elbos = [seed[1].ELBO[-1] for seed in converged]
        cleaned_results.append(converged[np.argmax(elbos)])
    return cleaned_results


def _unit_raster_colors():
    """Return the raster color of every chosen unit, based on /unit_descriptor."""
    colors = []
    for unit in chosen_units:
        descriptor = hf5.root.unit_descriptor[unit]
        if descriptor['regular_spiking'] == 1:
            colors.append(raster_colors['regular_spiking'])
        elif descriptor['fast_spiking'] == 1:
            colors.append(raster_colors['fast_spiking'])
        else:
            colors.append(raster_colors['multi_unit'])
    return colors


def _plot_trial(posterior, trial_idx, trial, unit_colors, label, out_path):
    """Plot one trial's state posteriors over its spike raster and save it to out_path.

    posterior: state posteriors indexed [state, trial_idx, time bin].
    trial_idx: position of the trial within the modelled subset (axis 1 of posterior).
    trial: index of the same trial in spikes / dig_in.
    label: 'VB' or 'MAP', used in the title.
    """
    fig = plt.figure()
    # Posteriors are scaled by the number of units so they span the raster's height
    for state in range(posterior.shape[0]):
        plt.plot(time, len(chosen_units) * posterior[state, trial_idx, :])
    # One vlines call per unit, drawing a tick at every sample with a spike
    for unit, color in enumerate(unit_colors):
        spike_times = np.where(spikes[trial, unit, :] > 0)[0]
        if spike_times.size:
            plt.vlines(spike_times - pre_stim_hmm, unit, unit + 0.5, color = color, linewidth = 0.5)
    plt.xlabel('Time post stimulus (ms)')
    plt.ylabel('Probability of HMM states')
    plt.title('%s_Trial %i, Dur: %ims, Lag:%ims' % (label, trial + 1, dig_in.laser_durations[trial], dig_in.laser_onset_lag[trial]) + '\n' + 'RSU: red, FS: blue, Multi: black')
    fig.savefig(out_path)
    plt.close("all")


def _run_condition(trials, condition):
    """Fit, save and plot the categorical HMMs for one subset of trials.

    trials: indices (into the trial axis of binned_spikes / spikes) of the trials to model.
    condition: name of the HDF5 group and plot folder ('laser_off' or 'laser_on').
    """
    results_group = '/spike_trains/dig_in_%i/categorical_vb_hmm_results' % taste
    condition_group = '%s/%s' % (results_group, condition)
    plot_dir = 'variational_HMM_plots/dig_in_%i/Categorical/%s' % (taste, condition)

    # Fit every number of states in [min_states, max_states], unless there are no
    # trials in this condition (e.g. a session without laser trials).
    # Symbols in binned_spikes are used directly as emission indices, so the model
    # needs max + 1 emissions. Counting distinct values (np.unique) undercounted
    # whenever some unit's symbol never occurred, leaving larger symbols out of range.
    hmm_results = []
    if trials.size:
        num_emissions = int(binned_spikes.max()) + 1
        for n_states in range(min_states, max_states + 1):
            hmm_results.append(implement_categorical(data = binned_spikes[trials, :], restarts = seeds, num_states = n_states, num_emissions = num_emissions, n_cpu = n_cpu, max_iter = max_iterations, threshold = threshold))
    cleaned_results = _select_best_seeds(hmm_results)

    # Recreate the <condition> group under categorical_vb_hmm_results
    if condition_group in hf5:
        hf5.remove_node(condition_group, recursive = True)
    hf5.create_group(results_group, condition)
    hf5.flush()

    # Recreate the plot folder for this condition
    shutil.rmtree(plot_dir, ignore_errors = True)
    os.makedirs(plot_dir)

    unit_colors = _unit_raster_colors()
    for model_MAP, model_VI in cleaned_results:
        state_name = 'states_%i' % model_VI.num_states
        state_group = '%s/%s' % (condition_group, state_name)
        state_plot_dir = '%s/%s' % (plot_dir, state_name)
        os.mkdir(state_plot_dir)
        hf5.create_group(condition_group, state_name)

        # Variational solution: counts, state posteriors (from an E-step) and final ELBO
        _, _, _, expected_latent_state, _ = model_VI.E_step()
        hf5.create_array(state_group, 'start_counts', model_VI.start_counts)
        hf5.create_array(state_group, 'transition_counts', model_VI.transition_counts)
        hf5.create_array(state_group, 'emission_counts', model_VI.emission_counts)
        posterior_VB = hf5.create_array(state_group, 'posterior_proba_VB', expected_latent_state)[:]
        hf5.create_array(state_group, 'ELBO', model_VI.ELBO[-1])
        hf5.flush()
        # MAP solution: only the state posteriors are written
        _, _, _, expected_latent_state, _ = model_MAP.E_step()
        posterior_MAP = hf5.create_array(state_group, 'posterior_proba_MAP', expected_latent_state)[:]
        hf5.flush()

        # Trial-wise posterior + raster plots, variational solution first
        for trial_idx, trial in enumerate(trials):
            for label, posterior in (('VB', posterior_VB), ('MAP', posterior_MAP)):
                _plot_trial(posterior, trial_idx, trial, unit_colors, label,
                            '%s/Trial_%i_%s.png' % (state_plot_dir, trial + 1, label))

        # Also pickle the model objects themselves to file in the plotting directory
        with open('%s/MAP_model.out' % state_plot_dir, 'wb') as f:
            pickle.dump(model_MAP, f, pickle.HIGHEST_PROTOCOL)
        with open('%s/variational_model.out' % state_plot_dir, 'wb') as f:
            pickle.dump(model_VI, f, pickle.HIGHEST_PROTOCOL)


try:
    _run_condition(off_trials, 'laser_off')
    _run_condition(on_trials, 'laser_on')
finally:
    # Close the HDF5 file even if a fit or a plot fails, so it is not left open
    hf5.close()
Computing file changes ...