Revision 86d380144b3f85c8951923de873893583bd25edf authored by narendramukherjee on 01 October 2020, 17:54:51 UTC, committed by GitHub on 01 October 2020, 17:54:51 UTC
Change marking of trial start time
blech_post_process.py
import ast
import os
import subprocess

import easygui
import numpy as np
import pylab as plt
import tables
from sklearn.mixture import GaussianMixture

import blech_waveforms_datashader
# Get directory where the hdf5 file sits, and change to that directory
dir_name = easygui.diropenbox()
if dir_name is None:
    # diropenbox returns None when the dialog is cancelled; os.chdir(None)
    # would otherwise fail with an unhelpful TypeError
    raise SystemExit("No directory selected - exiting")
os.chdir(dir_name)
# Clean up the memory monitor files, pass if clean up has been done already.
# Each per-electrode file is folded into one line of memory_usage.txt
# and then deleted.
monitor_dir = './memory_monitor_clustering'
summary_path = monitor_dir + '/memory_usage.txt'
if not os.path.exists(summary_path):
    # List the directory before creating the summary so the summary file
    # itself is not among the files being processed
    file_list = os.listdir(monitor_dir)
    with open(summary_path, 'w') as f:
        for files in file_list:
            monitor_path = monitor_dir + '/' + files
            try:
                mem_usage = np.loadtxt(monitor_path)
                # files[:-4] drops the 4-character extension from the file name
                print('electrode'+files[:-4], '\t', str(mem_usage)+'MB', file=f)
                os.remove(monitor_path)
            except Exception as err:
                # Best-effort clean up: anything unreadable is left in place,
                # but say so instead of failing silently
                print('Skipping %s during memory monitor clean up: %s' % (monitor_path, err))
# Look for the hdf5 file in the directory (any name ending in 'h5'). If several
# match, the last one returned by os.listdir is used, as before.
file_list = os.listdir('./')
h5_candidates = [name for name in file_list if name.endswith('h5')]
if not h5_candidates:
    # Fail here with a clear message rather than letting tables.open_file('')
    # raise something obscure
    raise FileNotFoundError("No hdf5 file (name ending in 'h5') found in " + os.getcwd())
hdf5_name = h5_candidates[-1]
# Open the hdf5 file for reading and writing
hf5 = tables.open_file(hdf5_name, 'r+')
# Delete the raw node, if it exists in the hdf5 file, to cut down on file size
try:
    hf5.remove_node('/raw', recursive=True)
except tables.NoSuchNodeError:
    # Only a missing /raw node means the clean up was done on an earlier run
    print("Raw recordings have already been removed, so moving on ..")
else:
    # The node was removed: close the currently open hdf5 file and ptrepack it
    hf5.close()
    print("Raw recordings removed")
    repacked_name = hdf5_name[:-3] + "_repacked.h5"
    # Argument list (no shell), so file names with spaces are handled.
    # check=True raises if ptrepack fails, which stops the script BEFORE the
    # original file is deleted below.
    subprocess.run(
        ["ptrepack", "--chunkshape=auto", "--propindexes", "--complevel=9",
         "--complib=blosc", hdf5_name, repacked_name],
        check=True,
    )
    # Delete the old (raw and big) hdf5 file only after a successful repack
    os.remove(hdf5_name)
    # And open the new, repacked file
    hf5 = tables.open_file(repacked_name, 'r+')
    print("File repacked")
# Make the sorted_units group in the hdf5 file if it doesn't already exist.
# An explicit existence check means genuine errors from create_group are no
# longer swallowed.
if '/sorted_units' not in hf5:
    hf5.create_group('/', 'sorted_units')
# Define a unit_descriptor class to be used to add things (anything!) about the sorted units to a pytables table
class unit_descriptor(tables.IsDescription):
    # Row schema of the unit_descriptor table: one row is appended per saved unit.
    # All four columns are 32-bit integers; the last three are used as 0/1 flags.
    electrode_number = tables.Int32Col()  # electrode number the user entered for this unit
    single_unit = tables.Int32Col()  # 1 if the user marked the unit as a single unit, else 0
    regular_spiking = tables.Int32Col()  # 1 if the user chose 'regular_spiking', else 0
    fast_spiking = tables.Int32Col()  # 1 if the user chose 'fast_spiking', else 0
# Make a table at the root of the hdf5 file (/unit_descriptor) describing the sorted units.
# If the table already exists, just open it up in the variable table.
if '/unit_descriptor' in hf5:
    table = hf5.root.unit_descriptor
else:
    table = hf5.create_table('/', 'unit_descriptor', description=unit_descriptor)
# Run an infinite loop as long as the user wants to pick clusters from the electrodes.
# The helpers below hold the dialog/save steps shared by every branch of the loop.

SINGLE_UNIT_MSG = 'I am almost-SURE that this is a beautiful single unit (True = Yes, False = No)'
UNIT_TYPE_MSG = 'What type of unit is this (Regular spiking = Pyramidal cells, Fast spiking = PV+ interneurons)?'


def _ask_true_false(msg):
    """Show a True/False choice box and return the chosen value.

    Returns None when the dialog is cancelled or nothing is selected, so the
    caller can abandon the current selection instead of crashing on None[0].
    """
    answer = easygui.multchoicebox(msg=msg, choices=('True', 'False'))
    if not answer:
        return None
    return ast.literal_eval(answer[0])


def _ask_unit_quality(msg):
    """Ask whether the unit is a single unit and, if it is, its spiking type.

    Returns a (single_unit, unit_type) tuple. unit_type is 'regular_spiking'
    or 'fast_spiking' for a single unit and None otherwise. Returns None if
    either dialog is cancelled.
    """
    single_unit = _ask_true_false(msg)
    if single_unit is None:
        return None
    unit_type = None
    if single_unit:
        chosen_type = easygui.multchoicebox(msg=UNIT_TYPE_MSG, choices=('regular_spiking', 'fast_spiking'))
        if not chosen_type:
            return None
        unit_type = chosen_type[0]
    return single_unit, unit_type


def _store_unit(hf5, table, unit_number, electrode_num, unit_waveforms, unit_times, single_unit, unit_type):
    """Write one unit under /sorted_units/unitNNN and append its descriptor row."""
    unit_name = 'unit%03d' % unit_number
    hf5.create_group('/sorted_units', unit_name)
    hf5.create_array('/sorted_units/%s' % unit_name, 'waveforms', unit_waveforms)
    hf5.create_array('/sorted_units/%s' % unit_name, 'times', unit_times)
    # Get a new unit_descriptor table row for this new unit
    unit_description = table.row
    unit_description['electrode_number'] = electrode_num
    unit_description['single_unit'] = int(single_unit)
    unit_description['regular_spiking'] = 0
    unit_description['fast_spiking'] = 0
    if unit_type is not None:
        unit_description[unit_type] = 1
    unit_description.append()
    table.flush()
    hf5.flush()


while True:
    # Get electrode number from user
    electrode_num = easygui.multenterbox(msg='Which electrode do you want to choose? Hit cancel to exit', fields=['Electrode #'])
    # Break if wrong input/cancel command was given (cancel returns None)
    try:
        electrode_num = int(electrode_num[0])
    except (TypeError, ValueError):
        break

    # Get the number of clusters in the chosen solution
    num_clusters = easygui.multenterbox(msg='Which solution do you want to choose for electrode %i?' % electrode_num, fields=['Number of clusters in the solution'])
    try:
        num_clusters = int(num_clusters[0])
    except (TypeError, ValueError):
        print("No valid number of clusters given - going back to electrode selection")
        continue

    # Load data from the chosen electrode and solution
    try:
        spike_waveforms = np.load('./spike_waveforms/electrode%i/spike_waveforms.npy' % electrode_num)
        spike_times = np.load('./spike_times/electrode%i/spike_times.npy' % electrode_num)
        pca_slices = np.load('./spike_waveforms/electrode%i/pca_waveforms.npy' % electrode_num)
        energy = np.load('./spike_waveforms/electrode%i/energy.npy' % electrode_num)
        amplitudes = np.load('./spike_waveforms/electrode%i/spike_amplitudes.npy' % electrode_num)
        predictions = np.load('./clustering_results/electrode%i/clusters%i/predictions.npy' % (electrode_num, num_clusters))
    except OSError as err:
        # A mistyped electrode/solution number should not end the whole session
        print("Could not load the chosen electrode/solution: %s" % err)
        continue

    # Get cluster choices from the chosen solution
    clusters = easygui.multchoicebox(msg='Which clusters do you want to choose?', choices=tuple([str(i) for i in range(int(np.max(predictions) + 1))]))
    if not clusters:
        # Cancelled or nothing selected
        continue

    # Check if the user wants to merge clusters if more than 1 cluster was chosen.
    # Else ask if the user wants to split/re-cluster the chosen cluster
    merge = False
    re_cluster = False
    if len(clusters) > 1:
        merge = _ask_true_false('I want to merge these clusters into one unit (True = Yes, False = No)')
        if merge is None:
            continue
    else:
        re_cluster = _ask_true_false('I want to split this cluster (True = Yes, False = No)')
        if re_cluster is None:
            continue

    # If the user asked to split/re-cluster, ask them for the clustering parameters and perform clustering
    split_predictions = []
    chosen_split = 0
    if re_cluster:
        # Get clustering parameters from user
        clustering_params = easygui.multenterbox(msg='Fill in the parameters for re-clustering (using a GMM)', fields=['Number of clusters', 'Maximum number of iterations (1000 is more than enough)', 'Convergence criterion (usually 0.0001)', 'Number of random restarts for GMM (10 is more than enough)'])
        try:
            n_clusters = int(clustering_params[0])
            n_iter = int(clustering_params[1])
            thresh = float(clustering_params[2])
            n_restarts = int(clustering_params[3])
        except (TypeError, ValueError):
            print("Re-clustering parameters missing or invalid - going back to electrode selection")
            continue

        # Make data array to be put through the GMM - 5 components: 3 PCs, scaled energy, amplitude
        this_cluster = np.where(predictions == int(clusters[0]))[0]
        n_pc = 3
        data = np.zeros((len(this_cluster), n_pc + 2))
        data[:, 2:] = pca_slices[this_cluster, :n_pc]
        data[:, 0] = energy[this_cluster]/np.max(energy[this_cluster])
        data[:, 1] = np.abs(amplitudes[this_cluster])/np.max(np.abs(amplitudes[this_cluster]))

        # Cluster the data
        g = GaussianMixture(n_components=n_clusters, covariance_type='full', tol=thresh, max_iter=n_iter, n_init=n_restarts)
        g.fit(data)

        # Only show the cluster plots if the solution converged
        if not g.converged_:
            print("Solution did not converge - try again with higher number of iterations or lower convergence criterion")
            continue

        split_predictions = g.predict(data)
        x = np.arange(len(spike_waveforms[0])/10) + 1
        # Waveforms and times from the chosen cluster (same for every split)
        slices_dejittered = spike_waveforms[this_cluster, :]
        cluster_times = spike_times[this_cluster]
        for cluster in range(n_clusters):
            split_points = np.where(split_predictions == cluster)[0]
            # Times from the chosen split of the chosen cluster
            times_dejittered = cluster_times[split_points]
            # Dividing by 30.0 presumably converts samples to ms (the axis
            # label below states 30 samples per ms)
            ISIs = np.ediff1d(np.sort(times_dejittered))/30.0
            violations1 = 100.0*float(np.sum(ISIs < 1.0)/split_points.shape[0])
            violations2 = 100.0*float(np.sum(ISIs < 2.0)/split_points.shape[0])
            fig, ax = blech_waveforms_datashader.waveforms_datashader(slices_dejittered[split_points, :], x)
            ax.set_xlabel('Sample (30 samples per ms)')
            ax.set_ylabel('Voltage (microvolts)')
            ax.set_title("Split Cluster{:d}, 2ms ISI violations={:.1f} percent".format(cluster, violations2) + "\n" + "1ms ISI violations={:.1f}%, Number of waveforms={:d}".format(violations1, split_points.shape[0]))
        plt.show()

        # Ask the user for the split clusters they want to choose
        chosen_split = easygui.multchoicebox(msg='Which split cluster do you want to choose? Hit cancel to exit', choices=tuple([str(i) for i in range(n_clusters)]))
        try:
            chosen_split = int(chosen_split[0])
        except (TypeError, ValueError, IndexError):
            continue

    # Work out the number of the next unit from the existing nodes under
    # /sorted_units: 0 if there are none, otherwise the highest existing number + 1.
    # Counting from one "next_unit" value keeps numbering contiguous even when
    # several units are added to an empty file in one go.
    node_list = hf5.list_nodes('/sorted_units')
    unit_numbers = [int(node._v_pathname.split('/')[-1][-3:]) for node in node_list]
    next_unit = max(unit_numbers) + 1 if unit_numbers else 0

    # In every branch the questions are asked BEFORE anything is written, so a
    # cancelled dialog never leaves a unit in the file without a descriptor row.
    if re_cluster:
        # The user re-clustered/split: keep only the chosen split of the originally chosen cluster
        split_members = np.where(split_predictions == chosen_split)[0]
        unit_waveforms = spike_waveforms[this_cluster, :][split_members, :]
        unit_times = spike_times[this_cluster][split_members]
        quality = _ask_unit_quality(SINGLE_UNIT_MSG)
        if quality is None:
            continue
        _store_unit(hf5, table, next_unit, electrode_num, unit_waveforms, unit_times, *quality)

    elif len(clusters) == 1:
        # Only 1 cluster was chosen (and it wasn't split): add it as a new unit
        members = np.where(predictions == int(clusters[0]))[0]
        quality = _ask_unit_quality(SINGLE_UNIT_MSG)
        if quality is None:
            continue
        _store_unit(hf5, table, next_unit, electrode_num, spike_waveforms[members, :], spike_times[members], *quality)

    elif merge:
        # Merge the chosen clusters, concatenated in the order they were chosen
        member_sets = [np.where(predictions == int(cluster))[0] for cluster in clusters]
        unit_waveforms = np.concatenate([spike_waveforms[members, :] for members in member_sets])
        unit_times = np.concatenate([spike_times[members] for members in member_sets])

        # Show the merged cluster to the user, and ask if they still want to merge
        x = np.arange(len(unit_waveforms[0])/10) + 1
        fig, ax = blech_waveforms_datashader.waveforms_datashader(unit_waveforms, x)
        ax.set_xlabel('Sample (30 samples per ms)')
        ax.set_ylabel('Voltage (microvolts)')
        ax.set_title('Merged cluster, No. of waveforms={:d}'.format(unit_waveforms.shape[0]))
        plt.show()

        # Warn the user about the frequency of ISI violations in the merged unit
        ISIs = np.ediff1d(np.sort(unit_times))/30.0
        violations1 = 100.0*float(np.sum(ISIs < 1.0)/len(unit_times))
        violations2 = 100.0*float(np.sum(ISIs < 2.0)/len(unit_times))
        proceed = _ask_true_false('My merged cluster has %.1f percent (<2ms) and %.1f percent (<1ms) ISI violations out of %i total waveforms. I want to still merge these clusters into one unit (True = Yes, False = No)' % (violations2, violations1, len(unit_times)))

        # Create unit if the user agrees to proceed, else abort and go back to start of the loop
        if not proceed:
            continue
        quality = _ask_unit_quality(SINGLE_UNIT_MSG)
        if quality is None:
            continue
        _store_unit(hf5, table, next_unit, electrode_num, unit_waveforms, unit_times, *quality)

    else:
        # Otherwise include each cluster as a separate unit
        for cluster in clusters:
            members = np.where(predictions == int(cluster))[0]
            quality = _ask_unit_quality('I am almost-SURE that electrode: %i cluster: %i is a beautiful single unit (True = Yes, False = No)' % (electrode_num, int(cluster)))
            if quality is None:
                # Cancelled: keep the units saved so far, skip the rest
                break
            _store_unit(hf5, table, next_unit, electrode_num, spike_waveforms[members, :], spike_times[members], *quality)
            # Move on to the next free unit number
            next_unit += 1

# Close the hdf5 file
hf5.close()
Computing file changes ...