Revision 86d380144b3f85c8951923de873893583bd25edf authored by narendramukherjee on 01 October 2020, 17:54:51 UTC, committed by GitHub on 01 October 2020, 17:54:51 UTC
Change marking of trial start time
2 parents c98da81 + f1007b5
Raw File
variational_HMM_implement.py
# Import stuff!
import numpy as np
import tables
import easygui
import sys
import os
import pylab as plt
import multiprocessing as mp
import pickle
# Import PyHMM
sys.path.append('/home/narendra/Desktop/PyHMM/PyHMM')
import DiscreteHMM as dhmm
import variationalHMM as vhmm
from hinton import hinton

# Read blech.dir - the first line holds the data directory path.
# A with-block guarantees the file handle is closed even if reading fails.
with open('blech.dir', 'r') as f:
	dir_name = f.readlines()  # each entry keeps its trailing newline; callers strip it with [:-1]

#---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
def implement_categorical(data = None, restarts = None, num_states = None, num_emissions = None, n_cpu = None, max_iter = None, threshold = None):
	"""Fit `restarts` categorical HMMs to `data` in parallel, one random restart each.

	Each worker runs run_categorical with a distinct restart index (which seeds its RNG).
	Returns a list of (MAP model, variational model) tuples, one per restart.
	"""
	# The with-block closes and joins the pool even if a worker raises;
	# the original leaked the pool (no close()/join()).
	with mp.Pool(processes = n_cpu) as pool:
		results = [pool.apply_async(run_categorical, args = (data, num_states, num_emissions, restart, max_iter, threshold,)) for restart in range(restarts)]
		# .get() blocks until each worker finishes and re-raises any worker exception here
		output = [p.get() for p in results]
	return output

def run_categorical(data = None, num_states = None, num_emissions = None, restart = None, max_iter = None, threshold = None):
	"""Fit one MAP-EM categorical HMM, then a variational HMM warm-started from it.

	Returns the (MAP model, variational model) pair for this restart.
	"""
	# Seed with the restart index so every restart is reproducible but distinct
	np.random.seed(restart)

	# Random initial parameters and Dirichlet pseudocounts for the MAP fit.
	# NOTE: these draws happen in a fixed order so the RNG stream matches across runs.
	init_transitions = np.random.random((num_states, num_states))
	init_emissions = np.random.random((num_states, num_emissions))
	init_start = np.random.random(num_states)
	transition_pseudo = np.random.random((num_states, num_states))
	emission_pseudo = np.random.random((num_states, num_emissions))
	start_pseudo = np.random.random(num_states)

	model_MAP = dhmm.CategoricalHMM(num_states = num_states, num_emissions = num_emissions, max_iter = max_iter, threshold = threshold)
	model_MAP.fit(data = data, p_transitions = init_transitions, p_emissions = init_emissions,
		p_start = init_start, transition_pseudocounts = transition_pseudo,
		emission_pseudocounts = emission_pseudo, start_pseudocounts = start_pseudo, verbose = False)

	# Variational fit, warm-started with scaled MAP estimates as prior counts
	model_VI = vhmm.CategoricalHMM(num_states = num_states, num_emissions = num_emissions, max_iter = max_iter, threshold = threshold)
	model_VI.fit(data = data, transition_hyperprior = 1, emission_hyperprior = 1, start_hyperprior = 1,
		initial_emission_counts = 80*model_MAP.p_emissions, initial_transition_counts = 80*model_MAP.p_transitions,
		initial_start_counts = 8*model_MAP.p_start, verbose = False)

	return model_MAP, model_VI

#---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

# Pull out the NSLOTS - number of CPUs allotted
# (alternatives kept for cluster use: NSLOTS env var, or a command-line argument)
#n_cpu = int(os.getenv('NSLOTS'))
n_cpu = mp.cpu_count()
#n_cpu = int(sys.argv[1])

# Change to the data directory and find the .hmm_params, .hmm_units and hdf5 (.h5) files.
# dir_name[0][:-1] strips the trailing newline kept by readlines().
os.chdir(dir_name[0][:-1])
hdf5_name = ''
params_file = ''
units_file = ''
for files in os.listdir('./'):
	# str.endswith is the idiomatic (and equivalent) form of the original slice comparisons
	if files.endswith('h5'):
		hdf5_name = files
	if files.endswith('hmm_params'):
		params_file = files
	if files.endswith('hmm_units'):
		units_file = files

# Read the .hmm_params file (one parameter per line, in a fixed order)
with open(params_file, 'r') as f:
	params = f.readlines()

# Assign the params to variables
min_states = int(params[0])      # smallest number of HMM states to try
max_states = int(params[1])      # largest number of HMM states to try
max_iterations = int(params[2])  # iteration cap for each fit
threshold = float(params[3])     # convergence threshold
seeds = int(params[4])           # number of random restarts per state count
taste = int(params[5])           # dig_in index to analyze
pre_stim = int(params[6])        # ms of pre-stimulus data in the spike array
bin_size = int(params[7])        # bin width (ms) for the categorical emissions
pre_stim_hmm = int(params[8])    # ms before stimulus to include in the HMM window
post_stim_hmm = int(params[9])   # ms after stimulus to include in the HMM window

# Read the chosen units. The original never closed this file handle;
# the with-block fixes that leak.
with open(units_file, 'r') as f:
	chosen_units = np.array([int(line) for line in f])

# Open up hdf5 file
hf5 = tables.open_file(hdf5_name, 'r+')

# Get the spike array from the required taste/input.
# getattr replaces the original exec() string - same node lookup, no dynamic code execution.
spikes = getattr(hf5.root.spike_trains, 'dig_in_%i' % taste).spike_array[:]

# Slice out the required portion of the spike array, and bin it
spikes = spikes[:, chosen_units, pre_stim - pre_stim_hmm:pre_stim + post_stim_hmm]
binned_spikes = np.zeros((spikes.shape[0], int((pre_stim_hmm + post_stim_hmm)/bin_size)))
# Bin start times relative to stimulus onset (ms). These are identical for every
# trial, so compute them once instead of rebuilding the list per trial.
time = [k - pre_stim_hmm for k in range(0, spikes.shape[2], bin_size)]
for i in range(spikes.shape[0]):
	for k in range(0, spikes.shape[2], bin_size):
		# Indices of units that fired at least once in this bin
		n_firing_units = np.where(np.sum(spikes[i, :, k:k+bin_size], axis = 1) > 0)[0]
		if n_firing_units.size:
			# Shift unit indices by +1 so emission code 0 can mean "no unit fired"
			n_firing_units = n_firing_units + 1
		else:
			n_firing_units = [0]
		# Categorical emission for this bin: one randomly chosen firing unit (or 0 for silence)
		binned_spikes[i, int(k/bin_size)] = np.random.choice(n_firing_units)

# Get the laser and non-laser trials for this taste.
# getattr replaces the original exec() string - same node lookup, no dynamic code execution.
dig_in = getattr(hf5.root.spike_trains, 'dig_in_%i' % taste)
on_trials = np.where(dig_in.laser_durations[:] > 0.0)[0]
off_trials = np.where(dig_in.laser_durations[:] == 0.0)[0]

# Delete the categorical_vb_hmm_results node under /spike_trains/dig_in_(taste)/ if it exists
try:
	hf5.remove_node('/spike_trains/dig_in_%i/categorical_vb_hmm_results' % taste, recursive = True)
except tables.NoSuchNodeError:
	# Node absent on a first run - nothing to delete
	pass

# Then create the categorical_vb_hmm_results group
hf5.create_group('/spike_trains/dig_in_%i' % taste, 'categorical_vb_hmm_results')
hf5.flush()

# Delete the Categorical folder within variational_HMM_plots if it exists for this taste.
# os.system returns a status code instead of raising, so no try/except is needed.
os.system("rm -r ./variational_HMM_plots/dig_in_%i/Categorical" % taste)

# Make a folder for plots of the categorical HMM analysis
os.mkdir("variational_HMM_plots/dig_in_%i/Categorical" % taste)

# Running laser off trials first---------------------------------------------------------------------------------------------------------------------------------

# Implement a variational categorical HMM for every state count from min_states to max_states.
# The emission alphabet size is the number of distinct codes seen across all trials.
n_emissions = np.unique(binned_spikes).shape[0]
hmm_results = [implement_categorical(data = binned_spikes[off_trials, :], restarts = seeds, num_states = n_states, num_emissions = n_emissions, n_cpu = n_cpu, max_iter = max_iterations, threshold = threshold)
	for n_states in range(min_states, max_states + 1)]

# Clean up the results from the HMM analysis: for each number of states, keep only
# the converged seed with the highest ELBO; drop state counts with no converged seed.
cleaned_results = []
for result in hmm_results:
	converged = [seed for seed in result if seed[1].converged]
	if converged:
		# max() returns the first maximum, matching np.argmax over the ELBO list
		cleaned_results.append(max(converged, key = lambda seed: seed[1].ELBO[-1]))

# Delete the laser_off node under /spike_trains/dig_in_(taste)/categorical_vb_hmm_results/ if it exists.
# Direct calls replace the original pointless exec() wrappers.
try:
	hf5.remove_node('/spike_trains/dig_in_%i/categorical_vb_hmm_results/laser_off' % taste, recursive = True)
except tables.NoSuchNodeError:
	# Node absent on a first run - nothing to delete
	pass

# Then create the laser_off node under the categorical_vb_hmm_results group
hf5.create_group('/spike_trains/dig_in_%i/categorical_vb_hmm_results' % taste, 'laser_off')
hf5.flush()

# Delete the laser_off folder within variational_HMM_plots/(taste)/Categorical if it exists.
# os.system returns a status code instead of raising, so no try/except is needed.
os.system("rm -r ./variational_HMM_plots/dig_in_%i/Categorical/laser_off" % taste)

# Make a folder for plots of the analysis on laser off trials
os.mkdir("variational_HMM_plots/dig_in_%i/Categorical/laser_off" % taste)

# Go through the cleaned_results, and make plots for each state and each trial
for result in cleaned_results:
	# result is a (MAP model, variational model) pair; both were fit with the same number of states
	# Make a plotting directory for this number of states
	os.mkdir("variational_HMM_plots/dig_in_%i/Categorical/laser_off/states_%i" % (taste, result[1].num_states))

	# Make a group under categorical_vb_hmm_results for this number of states
	hf5.create_group('/spike_trains/dig_in_%i/categorical_vb_hmm_results/laser_off' % taste, 'states_%i' % (result[1].num_states))

	# Write the start, transition and emission parameters and the posterior probabilities of the states only from the variational solution
	# First get the posterior probabilities of the states by doing an E-step
	alpha, beta, scaling, expected_latent_state, expected_latent_state_pair = result[1].E_step()
	# NOTE(review): the paths below mix result[1].num_states and result[0].num_states;
	# they refer to the same group since both models share one state count - confirm.
	start_counts = hf5.create_array('/spike_trains/dig_in_%i/categorical_vb_hmm_results/laser_off/states_%i' % (taste, result[1].num_states), 'start_counts', result[1].start_counts)
	transition_counts = hf5.create_array('/spike_trains/dig_in_%i/categorical_vb_hmm_results/laser_off/states_%i' % (taste, result[0].num_states), 'transition_counts', result[1].transition_counts)
	emission_counts = hf5.create_array('/spike_trains/dig_in_%i/categorical_vb_hmm_results/laser_off/states_%i' % (taste, result[0].num_states), 'emission_counts', result[1].emission_counts)
	posterior_proba_VB = hf5.create_array('/spike_trains/dig_in_%i/categorical_vb_hmm_results/laser_off/states_%i' % (taste, result[0].num_states), 'posterior_proba_VB', expected_latent_state)
	# Also write the ELBO to file
	ELBO = hf5.create_array('/spike_trains/dig_in_%i/categorical_vb_hmm_results/laser_off/states_%i' % (taste, result[0].num_states), 'ELBO', result[1].ELBO[-1])
	hf5.flush()
	# Also write the posterior probabilities of the states from the MAP solution to file
	alpha, beta, scaling, expected_latent_state, expected_latent_state_pair = result[0].E_step()
	posterior_proba_MAP = hf5.create_array('/spike_trains/dig_in_%i/categorical_vb_hmm_results/laser_off/states_%i' % (taste, result[0].num_states), 'posterior_proba_MAP', expected_latent_state)
	hf5.flush()

	# Go through laser off trials and plot the trial-wise posterior probabilities and raster plots
	# First make a dictionary of colors for the rasters (keyed by unit type)
	raster_colors = {'regular_spiking': 'red', 'fast_spiking': 'blue', 'multi_unit': 'black'}
	for i in range(off_trials.shape[0]):
		# Plotting the variational solution first
		fig = plt.figure()
		# One probability trace per state, scaled by the unit count so it overlays the raster
		for j in range(posterior_proba_VB.shape[0]):
			plt.plot(time, len(chosen_units)*posterior_proba_VB[j, i, :])
		for unit in range(len(chosen_units)):
			# Determine the type of unit we are looking at - the color of the raster will depend on that
			if hf5.root.unit_descriptor[chosen_units[unit]]['regular_spiking'] == 1:
				unit_type = 'regular_spiking'
			elif hf5.root.unit_descriptor[chosen_units[unit]]['fast_spiking'] == 1:
				unit_type = 'fast_spiking'
			else:
				unit_type = 'multi_unit'
			# Tick marks at every ms where this unit fired, shifted to stimulus-relative time
			for j in range(spikes.shape[2]):
				if spikes[off_trials[i], unit, j] > 0:
					plt.vlines(j - pre_stim_hmm, unit, unit + 0.5, color = raster_colors[unit_type], linewidth = 0.5)
		plt.xlabel('Time post stimulus (ms)')
		plt.ylabel('Probability of HMM states')
		plt.title('VB_Trial %i, Dur: %ims, Lag:%ims' % (off_trials[i]+1, dig_in.laser_durations[off_trials[i]], dig_in.laser_onset_lag[off_trials[i]]) + '\n' + 'RSU: red, FS: blue, Multi: black')
		fig.savefig('variational_HMM_plots/dig_in_%i/Categorical/laser_off/states_%i/Trial_%i_VB.png' % (taste, result[1].num_states, off_trials[i] + 1))
		plt.close("all")

		# Now plotting the MAP solution (same layout as the variational plot above)
		fig = plt.figure()
		for j in range(posterior_proba_MAP.shape[0]):
			plt.plot(time, len(chosen_units)*posterior_proba_MAP[j, i, :])
		for unit in range(len(chosen_units)):
			# Determine the type of unit we are looking at - the color of the raster will depend on that
			if hf5.root.unit_descriptor[chosen_units[unit]]['regular_spiking'] == 1:
				unit_type = 'regular_spiking'
			elif hf5.root.unit_descriptor[chosen_units[unit]]['fast_spiking'] == 1:
				unit_type = 'fast_spiking'
			else:
				unit_type = 'multi_unit'
			for j in range(spikes.shape[2]):
				if spikes[off_trials[i], unit, j] > 0:
					plt.vlines(j - pre_stim_hmm, unit, unit + 0.5, color = raster_colors[unit_type], linewidth = 0.5)
		plt.xlabel('Time post stimulus (ms)')
		plt.ylabel('Probability of HMM states')
		plt.title('MAP_Trial %i, Dur: %ims, Lag:%ims' % (off_trials[i]+1, dig_in.laser_durations[off_trials[i]], dig_in.laser_onset_lag[off_trials[i]]) + '\n' + 'RSU: red, FS: blue, Multi: black')
		fig.savefig('variational_HMM_plots/dig_in_%i/Categorical/laser_off/states_%i/Trial_%i_MAP.png' % (taste, result[1].num_states, off_trials[i] + 1))
		plt.close("all")

	# Also pickle the model objects themselves to file in the plotting directory
	with open("variational_HMM_plots/dig_in_%i/Categorical/laser_off/states_%i/MAP_model.out" % (taste, result[0].num_states), "wb") as f:
		pickle.dump(result[0], f, pickle.HIGHEST_PROTOCOL)
	with open("variational_HMM_plots/dig_in_%i/Categorical/laser_off/states_%i/variational_model.out" % (taste, result[1].num_states), "wb") as f:
		pickle.dump(result[1], f, pickle.HIGHEST_PROTOCOL)

# Laser off trials done------------------------------------------------------------------------------------------------------------------------------------------


# Running laser on trials----------------------------------------------------------------------------------------------------------------------------------------

# Implement a variational categorical HMM for every state count from min_states to max_states
hmm_results = [implement_categorical(data = binned_spikes[on_trials, :], restarts = seeds, num_states = n_states, num_emissions = np.unique(binned_spikes).shape[0], n_cpu = n_cpu, max_iter = max_iterations, threshold = threshold)
	for n_states in range(min_states, max_states + 1)]

# Clean up the results from the HMM analysis: for each number of states, keep only
# the converged seed with the highest ELBO; drop state counts with no converged seed.
cleaned_results = []
for result in hmm_results:
	converged = [seed for seed in result if seed[1].converged]
	if converged:
		# max() returns the first maximum, matching np.argmax over the ELBO list
		cleaned_results.append(max(converged, key = lambda seed: seed[1].ELBO[-1]))

# Delete the laser_on node under /spike_trains/dig_in_(taste)/categorical_vb_hmm_results/ if it exists.
# Direct calls replace the original pointless exec() wrappers.
try:
	hf5.remove_node('/spike_trains/dig_in_%i/categorical_vb_hmm_results/laser_on' % taste, recursive = True)
except tables.NoSuchNodeError:
	# Node absent on a first run - nothing to delete
	pass

# Then create the laser_on node under the categorical_vb_hmm_results group
hf5.create_group('/spike_trains/dig_in_%i/categorical_vb_hmm_results' % taste, 'laser_on')
hf5.flush()

# Delete the laser_on folder within variational_HMM_plots/(taste)/Categorical if it exists.
# os.system returns a status code instead of raising, so no try/except is needed.
os.system("rm -r ./variational_HMM_plots/dig_in_%i/Categorical/laser_on" % taste)

# Make a folder for plots of the analysis on laser on trials
os.mkdir("variational_HMM_plots/dig_in_%i/Categorical/laser_on" % taste)

# Go through the cleaned_results, and make plots for each state and each trial
for result in cleaned_results:
	# result is a (MAP model, variational model) pair; both were fit with the same number of states
	# Make a plotting directory for this number of states
	os.mkdir("variational_HMM_plots/dig_in_%i/Categorical/laser_on/states_%i" % (taste, result[1].num_states))

	# Make a group under categorical_vb_hmm_results for this number of states
	hf5.create_group('/spike_trains/dig_in_%i/categorical_vb_hmm_results/laser_on' % taste, 'states_%i' % (result[1].num_states))

	# Write the start, transition and emission parameters and the posterior probabilities of the states only from the variational solution
	# First get the posterior probabilities of the states by doing an E-step
	alpha, beta, scaling, expected_latent_state, expected_latent_state_pair = result[1].E_step()
	# NOTE(review): the paths below mix result[1].num_states and result[0].num_states;
	# they refer to the same group since both models share one state count - confirm.
	start_counts = hf5.create_array('/spike_trains/dig_in_%i/categorical_vb_hmm_results/laser_on/states_%i' % (taste, result[1].num_states), 'start_counts', result[1].start_counts)
	transition_counts = hf5.create_array('/spike_trains/dig_in_%i/categorical_vb_hmm_results/laser_on/states_%i' % (taste, result[0].num_states), 'transition_counts', result[1].transition_counts)
	emission_counts = hf5.create_array('/spike_trains/dig_in_%i/categorical_vb_hmm_results/laser_on/states_%i' % (taste, result[0].num_states), 'emission_counts', result[1].emission_counts)
	posterior_proba_VB = hf5.create_array('/spike_trains/dig_in_%i/categorical_vb_hmm_results/laser_on/states_%i' % (taste, result[0].num_states), 'posterior_proba_VB', expected_latent_state)
	# Also write the ELBO to file
	ELBO = hf5.create_array('/spike_trains/dig_in_%i/categorical_vb_hmm_results/laser_on/states_%i' % (taste, result[0].num_states), 'ELBO', result[1].ELBO[-1])
	hf5.flush()
	# Also write the posterior probabilities of the states from the MAP solution to file
	alpha, beta, scaling, expected_latent_state, expected_latent_state_pair = result[0].E_step()
	posterior_proba_MAP = hf5.create_array('/spike_trains/dig_in_%i/categorical_vb_hmm_results/laser_on/states_%i' % (taste, result[0].num_states), 'posterior_proba_MAP', expected_latent_state)
	hf5.flush()

	# Go through laser on trials and plot the trial-wise posterior probabilities and raster plots
	# First make a dictionary of colors for the rasters (keyed by unit type)
	raster_colors = {'regular_spiking': 'red', 'fast_spiking': 'blue', 'multi_unit': 'black'}
	for i in range(on_trials.shape[0]):
		# Plotting the variational solution first
		fig = plt.figure()
		# One probability trace per state, scaled by the unit count so it overlays the raster
		for j in range(posterior_proba_VB.shape[0]):
			plt.plot(time, len(chosen_units)*posterior_proba_VB[j, i, :])
		for unit in range(len(chosen_units)):
			# Determine the type of unit we are looking at - the color of the raster will depend on that
			if hf5.root.unit_descriptor[chosen_units[unit]]['regular_spiking'] == 1:
				unit_type = 'regular_spiking'
			elif hf5.root.unit_descriptor[chosen_units[unit]]['fast_spiking'] == 1:
				unit_type = 'fast_spiking'
			else:
				unit_type = 'multi_unit'
			# Tick marks at every ms where this unit fired, shifted to stimulus-relative time
			for j in range(spikes.shape[2]):
				if spikes[on_trials[i], unit, j] > 0:
					plt.vlines(j - pre_stim_hmm, unit, unit + 0.5, color = raster_colors[unit_type], linewidth = 0.5)
		plt.xlabel('Time post stimulus (ms)')
		plt.ylabel('Probability of HMM states')
		plt.title('VB_Trial %i, Dur: %ims, Lag:%ims' % (on_trials[i]+1, dig_in.laser_durations[on_trials[i]], dig_in.laser_onset_lag[on_trials[i]]) + '\n' + 'RSU: red, FS: blue, Multi: black')
		fig.savefig('variational_HMM_plots/dig_in_%i/Categorical/laser_on/states_%i/Trial_%i_VB.png' % (taste, result[1].num_states, on_trials[i] + 1))
		plt.close("all")

		# Now plotting the MAP solution (same layout as the variational plot above)
		fig = plt.figure()
		for j in range(posterior_proba_MAP.shape[0]):
			plt.plot(time, len(chosen_units)*posterior_proba_MAP[j, i, :])
		for unit in range(len(chosen_units)):
			# Determine the type of unit we are looking at - the color of the raster will depend on that
			if hf5.root.unit_descriptor[chosen_units[unit]]['regular_spiking'] == 1:
				unit_type = 'regular_spiking'
			elif hf5.root.unit_descriptor[chosen_units[unit]]['fast_spiking'] == 1:
				unit_type = 'fast_spiking'
			else:
				unit_type = 'multi_unit'
			for j in range(spikes.shape[2]):
				if spikes[on_trials[i], unit, j] > 0:
					plt.vlines(j - pre_stim_hmm, unit, unit + 0.5, color = raster_colors[unit_type], linewidth = 0.5)
		plt.xlabel('Time post stimulus (ms)')
		plt.ylabel('Probability of HMM states')
		plt.title('MAP_Trial %i, Dur: %ims, Lag:%ims' % (on_trials[i]+1, dig_in.laser_durations[on_trials[i]], dig_in.laser_onset_lag[on_trials[i]]) + '\n' + 'RSU: red, FS: blue, Multi: black')
		fig.savefig('variational_HMM_plots/dig_in_%i/Categorical/laser_on/states_%i/Trial_%i_MAP.png' % (taste, result[1].num_states, on_trials[i] + 1))
		plt.close("all")

	# Also pickle the model objects themselves to file in the plotting directory
	with open("variational_HMM_plots/dig_in_%i/Categorical/laser_on/states_%i/MAP_model.out" % (taste, result[0].num_states), "wb") as f:
		pickle.dump(result[0], f, pickle.HIGHEST_PROTOCOL)
	with open("variational_HMM_plots/dig_in_%i/Categorical/laser_on/states_%i/variational_model.out" % (taste, result[1].num_states), "wb") as f:
		pickle.dump(result[1], f, pickle.HIGHEST_PROTOCOL)

# Laser on trials done-------------------------------------------------------------------------------------------------------------------------------------------

# Close the HDF5 file, flushing any remaining writes to disk
hf5.close()

back to top