Revision 86d380144b3f85c8951923de873893583bd25edf authored by narendramukherjee on 01 October 2020, 17:54:51 UTC, committed by GitHub on 01 October 2020, 17:54:51 UTC
Change marking of trial start time
blech_units_similarity.py
# Import stuff!
import numpy as np
import tables
import easygui
import sys
import os
from numba import jit
# Count, for each of two spike trains, how many of its spikes have at least one spike of the other train within `window` (1 ms by default)
def _count_spikes_near(times, reference_sorted, window):
    """Return how many entries of `times` lie within `window` of any entry of `reference_sorted`.

    `reference_sorted` must be sorted in ascending order. Only the two entries
    that bracket each spike need to be inspected, because one of them is
    always the nearest neighbour.
    """
    n_reference = reference_sorted.size
    if times.size == 0 or n_reference == 0:
        return 0
    # Index of the first reference spike that is >= each spike in `times`
    insert_at = np.searchsorted(reference_sorted, times)
    # Clip so that spikes before the first / after the last reference spike
    # are compared against the only neighbour they have
    left = np.clip(insert_at - 1, 0, n_reference - 1)
    right = np.clip(insert_at, 0, n_reference - 1)
    nearest = np.minimum(np.abs(times - reference_sorted[left]),
                         np.abs(times - reference_sorted[right]))
    return int(np.count_nonzero(nearest <= window))


def unit_similarity(this_unit_times, other_unit_times, window = 1.0):
    """Count the spikes of each unit that have a spike of the other unit nearby.

    Parameters
    ----------
    this_unit_times, other_unit_times : array-like of spike times. Both must be
        in the same units; they do not need to be sorted.
    window : maximum absolute time difference (inclusive, same units as the
        spike times) for two spikes to count as coincident. Defaults to 1.0.

    Returns
    -------
    (this_unit_counter, other_unit_counter) : tuple of int
        this_unit_counter is the number of spikes in this_unit_times that have
        at least one spike of other_unit_times within `window`;
        other_unit_counter is the same with the roles swapped. Each spike is
        counted at most once, so neither counter can exceed the length of its
        own spike train. Both are 0 if either input is empty.
    """
    # Work in float64 so that unsigned integer sample indices cannot wrap around on subtraction
    this_sorted = np.sort(np.asarray(this_unit_times, dtype = np.float64))
    other_sorted = np.sort(np.asarray(other_unit_times, dtype = np.float64))
    # Sorting + binary search keeps memory linear in the number of spikes, unlike
    # building the full pairwise difference matrix, and needs no compiled double loop
    this_unit_counter = _count_spikes_near(this_sorted, other_sorted, window)
    other_unit_counter = _count_spikes_near(other_sorted, this_sorted, window)
    return this_unit_counter, other_unit_counter
# Ask for the directory where the hdf5 file sits, and change to that directory
dir_name = easygui.diropenbox()
if dir_name is None:
    # The dialog returns None when the user cancels it
    sys.exit("No directory selected - exiting")
os.chdir(dir_name)

# Ask the user for a cutoff percentage for unit similarities - pairs of units whose similarity exceeds the cutoff will be written to file
similarity_cutoff = easygui.multenterbox(msg = "Enter a cutoff percentage for unit similarities - violations will be written to file (0-100)", fields = ["Unit similarity cutoff percent"])
if similarity_cutoff is None:
    # The dialog returns None when the user cancels it
    sys.exit("No similarity cutoff entered - exiting")
similarity_cutoff = float(similarity_cutoff[0])

# Get the names of all files in the current directory, and find the hdf5 (.h5) file
# If several files match, the last one listed is used
h5_files = [file_name for file_name in os.listdir('./') if file_name.endswith('h5')]
if not h5_files:
    sys.exit("No hdf5 (.h5) file found in " + str(dir_name))
hdf5_name = h5_files[-1]

# Spike times are stored in samples; dividing by 30 converts them to milliseconds
samples_per_ms = 30.0

# Open up the hdf5 file, and a text file to write the unit similarity violations to - these units are likely the same and one of them will need to be removed from the HDF5 file
# Both files are closed automatically, even if an error occurs part-way through
with tables.open_file(hdf5_name, 'r+') as hf5, open("unit_similarity_violations.txt", "w") as unit_similarity_violations:
    print("Unit number 1" + "\t" + "Unit number 2", file = unit_similarity_violations)

    # Get all the units from the hdf5 file, and read the spike times of every unit once up front
    units = hf5.list_nodes('/sorted_units')
    num_units = len(units)
    all_unit_times = [(unit.times[:])/samples_per_ms for unit in units]

    # Now go through the units one by one, and get the pairwise similarities between them
    # Similarity is defined as the percentage of spikes of the reference unit that have a spike from the compared unit within 1 ms
    print("==================")
    print("Similarity calculation starting")
    unit_distances = np.zeros((num_units, num_units))
    for this_unit in range(num_units):
        this_unit_times = all_unit_times[this_unit]
        # Only pairs with other_unit >= this_unit are visited; each visit fills both [this, other] and [other, this]
        for other_unit in range(this_unit, num_units):
            other_unit_times = all_unit_times[other_unit]
            this_unit_counter, other_unit_counter = unit_similarity(this_unit_times, other_unit_times)
            # A unit without spikes keeps a similarity of 0 instead of causing a division by zero
            if len(this_unit_times) > 0:
                unit_distances[this_unit, other_unit] = 100.0*(float(this_unit_counter)/len(this_unit_times))
            if len(other_unit_times) > 0:
                unit_distances[other_unit, this_unit] = 100.0*(float(other_unit_counter)/len(other_unit_times))
            # If the similarity goes beyond the defined cutoff (in either direction), write these unit numbers to file
            exceeds_cutoff = unit_distances[this_unit, other_unit] > similarity_cutoff or unit_distances[other_unit, this_unit] > similarity_cutoff
            if this_unit != other_unit and exceeds_cutoff:
                print(str(this_unit) + "\t" + str(other_unit), file = unit_similarity_violations)
        # Print the progress to the window
        print("Unit %i of %i completed" % (this_unit+1, num_units))
    print("Similarity calculation complete, results being saved to file")
    print("==================")

    # NOTE: a fully vectorized version of the loop above (np.subtract.outer on each pair of spike trains, then np.min along each axis) was tried and abandoned - broadcasting the full pairwise difference matrix caused memory errors

    # Store the similarity matrix as /unit_distances at the root of the hdf5 file, replacing any array left over from a previous run
    if '/unit_distances' in hf5:
        hf5.remove_node('/unit_distances')
    hf5.create_array('/', 'unit_distances', unit_distances)
Computing file changes ...