import logging

import numpy as np
import torch
from scipy.ndimage import gaussian_filter
from scipy.sparse import coo_matrix

from kilosort import spikedetect

logger = logging.getLogger(__name__)


def bin_spikes(ops, st):
    """Bin each batch's spikes into a 2D (depth x amplitude) fingerprint.

    Parameters
    ----------
    ops : dict
        Settings dictionary. Reads 'yc' (channel y positions, um),
        'binning_depth' (vertical bin width, um), 'Nbatches', and
        'Th_universal' (detection threshold, the lower edge of the
        logarithmic amplitude binning).
    st : np.ndarray
        Spike matrix; as used here, column 1 is spike depth (um),
        column 2 is amplitude, and column 4 is the batch index.
        (Other columns are not read — confirm layout against the caller.)

    Returns
    -------
    F : np.ndarray, shape (Nbatches, dmax, 20)
        log2(1 + count) histogram per batch, over dmax depth bins and
        20 logarithmic amplitude bins.
    ysamp : np.ndarray, shape (dmax,)
        Vertical coordinate associated with each depth bin.
    """
    # the bin edges are based on min and max of channel y positions
    ymin = ops['yc'].min()
    ymax = ops['yc'].max()
    dd = ops['binning_depth']  # binning width in depth

    # start 1um below the lowest channel
    dmin = ymin - 1
    # dmax is how many depth bins to use
    dmax = 1 + np.ceil((ymax - dmin) / dd).astype('int32')

    Nbatches = ops['Nbatches']
    batch_id = st[:, 4].copy()

    # always use 20 bins for amplitude binning
    F = np.zeros((Nbatches, dmax, 20))
    for t in range(Nbatches):
        # consider only spikes from this batch
        ix = (batch_id == t).nonzero()[0]
        sst = st[ix]

        # their depth relative to the minimum
        dep = sst[:, 1] - dmin

        # the amplitude binning is logarithmic, from Th_universal to 100;
        # amplitudes are clipped at 99 so the normalized value stays < 1
        amp = np.log10(np.minimum(99, sst[:, 2])) - np.log10(ops['Th_universal'])
        # amplitudes get normalized from 0 to 1
        amp = amp / (np.log10(100) - np.log10(ops['Th_universal']))

        # rows are divided by the vertical binning depth
        rows = (dep / dd).astype('int32')
        # columns are from 0 to 20 (1e-5 guards the float->int truncation)
        cols = (1e-5 + amp * 20).astype('int32')

        # for efficient binning, use sparse matrix construction in scipy
        # (coo_matrix sums duplicate (row, col) entries on conversion)
        cou = np.ones(len(ix))
        M = coo_matrix((cou, (rows, cols)), (dmax, 20))

        # the 2D histogram counts are transformed to logarithm
        F[t] = np.log2(1 + M.todense())

    # vertical coordinate of each sampling bin
    # NOTE(review): rows index bin floor((depth - dmin)/dd), whose geometric
    # center is dmin + (t + 0.5)*dd; the original uses dmin + t*dd - dd/2,
    # a consistent half/full-bin offset — kept as-is, confirm upstream
    # before "fixing".
    ysamp = dmin + dd * np.arange(dmax) - dd / 2
    return F, ysamp


def align_block2(F, ysamp, ops, device=torch.device('cuda')):
    """Iteratively align per-batch fingerprints to a template fingerprint.

    NOTE(review): this function is TRUNCATED in the copy under review —
    everything after ``if iter`` vanished (apparently an unclosed ``<``
    eaten by markup stripping). The alignment/update step, the nonrigid
    block handling, and the return value must be restored from the
    upstream source; only the visible prefix is reproduced below.
    """
    Nbatches = ops['Nbatches']
    # n is the maximum vertical shift allowed, in units of bins
    n = 15
    dc = np.zeros((2 * n + 1, Nbatches))
    dt = np.arange(-n, n + 1, 1)

    # batch fingerprints are mean subtracted along depth
    Fg = torch.from_numpy(F).to(device).float()
    Fg = Fg - Fg.mean(1).unsqueeze(1)

    # the template fingerprint is initialized with batch 300 if that exists
    F0 = Fg[np.minimum(300, Nbatches // 2)]

    niter = 10
    dall = np.zeros((niter, Nbatches))

    # at each iteration, align each batch to the template fingerprint;
    # Fg is incrementally modified, and cumulative shifts are accumulated
    # over iterations
    for iter in range(niter):
        # for each vertical shift in the range -n to n, compute the dot product
        for t in range(len(dt)):
            Fs = torch.roll(Fg, dt[t], 1)
            dc[t] = (Fs * F0).mean(-1).mean(-1).cpu().numpy()

        # for all but the last iteration, align the batches
        # NOTE(review): the condition below is a hedged reconstruction of
        # the lost text ("if iter" was the last surviving token); the body
        # of this branch and everything after it is missing.
        if iter < niter - 1:
            raise NotImplementedError(
                "align_block2 is truncated in this copy of the file; "
                "restore the remainder from the upstream source"
            )
""" if ops['nblocks']<1: ops['dshift'] = None logger.info('nblocks = 0, skipping drift correction') return ops, None # the first step is to extract all spikes using the universal templates st, _, ops = spikedetect.run( ops, bfile, device=device, progress_bar=progress_bar, clear_cache=clear_cache, verbose=verbose ) # spikes are binned by amplitude and y-position to construct a "fingerprint" for each batch F, ysamp = bin_spikes(ops, st) # the fingerprints are iteratively aligned to each other vertically imin, yblk, _, _ = align_block2(F, ysamp, ops, device=device) # imin contains the shifts for each batch, in units of discrete bins # multiply back with binning_depth for microns dshift = imin * ops['binning_depth'] # we save the variables needed for drift correction during the data preprocessing step ops['yblk'] = yblk ops['dshift'] = dshift xp = np.vstack((ops['xc'],ops['yc'])).T # for interpolation, we precompute a radial kernel based on distances between sites Kxx = kernel2D(xp, xp, ops['sig_interp']) Kxx = torch.from_numpy(Kxx).to(device) # a small constant is added to the diagonal for stability of the matrix inversion ops['iKxx'] = torch.linalg.inv(Kxx + 0.01 * torch.eye(Kxx.shape[0], device=device)) return ops, st