https://doi.org/10.5281/zenodo.3597474
27 August 2025, 14:32:56 UTC
    b6d380b / MouseLand-Kilosort-8f396c7 / kilosort / template_matching.py

    To reference or cite the objects present in the Software Heritage archive, permalinks based on SoftWare Hash IDentifiers (SWHIDs) must be used.

    content: swh:1:cnt:dca96c595b9f98e9f9d277f1f882eeefb366a178
    directory: swh:1:dir:5a94e6584e2d42b7e8e4bdb91a2e2f831f22c766
    snapshot: swh:1:snp:cde65436e19c7a8c490b366892685f110e246dcf
    release: swh:1:rel:4aa70d8c6bf4aecf2297faaca70d7dba2d9e7b48

    template_matching.py
    import logging
    
    import numpy as np
    import torch 
    from torch.nn.functional import conv1d, max_pool2d, max_pool1d
    from tqdm import tqdm
    
    from kilosort import CCG
    from kilosort.utils import log_performance
    
    logger = logging.getLogger(__name__)
    
    
    def prepare_extract(xc, yc, U, nC, position_limit, device=torch.device('cuda')):
        """Identify desired channels based on distances and template norms.
        
        Parameters
        ----------
        xc : np.ndarray
            X-coordinates of contact positions on probe.
        yc : np.ndarray
            Y-coordinates of contact positions on probe.
        U : torch.Tensor
            Spatial PC features for each template, with shape
            (n_templates, n_pcs, n_channels).
        nC : int
            Number of nearest channels to use.
        position_limit : float
            Max distance (in microns) between channels that are used to estimate
            spike positions in `postprocessing.compute_spike_positions`.
    
        Returns
        -------
        iCC : np.ndarray
            For each channel, indices of nC nearest channels.
        iCC_mask : np.ndarray
            For each channel, True if the channel is within `position_limit`
            microns and False otherwise. Used to control spike position
            estimates in post-processing.
        iU : torch.Tensor
            For each template, index of channel with greatest norm.
        Ucc : torch.Tensor
            For each template, spatial PC features corresponding to iCC.
        
        """
        ds = (xc - xc[:, np.newaxis])**2 + (yc - yc[:, np.newaxis])**2
        iCC = np.argsort(ds, 0)[:nC]
        iCC = torch.from_numpy(iCC).to(device)
        iCC_mask = np.sort(ds, 0)[:nC]
        iCC_mask = iCC_mask < position_limit**2
        iCC_mask = torch.from_numpy(iCC_mask).to(device)
        iU = torch.argmax((U**2).sum(1), -1)
        Ucc = U[torch.arange(U.shape[0]),:,iCC[:,iU]]
    
        return iCC, iCC_mask, iU, Ucc
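
    The channel-selection logic above can be illustrated in isolation. The following is a minimal NumPy-only sketch (not part of Kilosort) of the pairwise-distance/argsort step, using a hypothetical 4-contact linear probe:

    ```python
    import numpy as np

    def nearest_channels(xc, yc, nC, position_limit):
        """Sketch of the channel-selection step in `prepare_extract`:
        for each channel, the indices of the nC nearest channels by squared
        Euclidean distance, plus a mask of which lie within
        `position_limit` microns."""
        ds = (xc - xc[:, np.newaxis])**2 + (yc - yc[:, np.newaxis])**2
        iCC = np.argsort(ds, 0)[:nC]                  # (nC, n_chan) indices
        iCC_mask = np.sort(ds, 0)[:nC] < position_limit**2
        return iCC, iCC_mask

    # Toy linear probe: 4 contacts spaced 20 um apart along y.
    xc = np.zeros(4)
    yc = np.array([0., 20., 40., 60.])
    iCC, mask = nearest_channels(xc, yc, nC=2, position_limit=30.)
    ```

    Each column of `iCC` starts with the channel itself (distance zero), followed by its nearest neighbors; the mask then gates which of those neighbors are close enough to contribute to position estimates.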
    
    
    def extract(ops, bfile, U, device=torch.device('cuda'), progress_bar=None):
        nC = ops['settings']['nearest_chans']
        position_limit = ops['settings']['position_limit']
        iCC, iCC_mask, iU, Ucc = prepare_extract(
            ops['xc'], ops['yc'], U, nC, position_limit, device=device
            )
        ops['iCC'] = iCC
        ops['iCC_mask'] = iCC_mask
        ops['iU'] = iU
        nt = ops['nt']
        
        tiwave = torch.arange(-(nt//2), nt//2+1, device=device) 
        ctc = prepare_matching(ops, U)
        st = np.zeros((10**6, 3), 'float64')
        tF  = torch.zeros((10**6, nC , ops['settings']['n_pcs']))
        k = 0
        prog = tqdm(
            np.arange(bfile.n_batches, dtype=np.int64),
            miniters=200 if progress_bar else None, 
            mininterval=60 if progress_bar else None
            )
        
        try:
            for ibatch in prog:
                if ibatch % 100 == 0:
                    log_performance(logger, 'debug', f'Batch {ibatch}')
    
                X = bfile.padded_batch_to_torch(ibatch, ops)
                stt, amps, th_amps, Xres = run_matching(ops, X, U, ctc, device=device)
                xfeat = Xres[iCC[:, iU[stt[:,1:2]]],stt[:,:1] + tiwave] @ ops['wPCA'].T
                xfeat += amps * Ucc[:,stt[:,1]]
    
                if ibatch == 0:
                    # Can sometimes get negative spike times for first batch since
                    # we're aligning to nt0min, not nt//2, but these should be discarded.
                    neg_spikes = (stt[:,0] - nt - nt//2 + ops['nt0min']) < 0
                    stt = stt[~neg_spikes,:]
                    xfeat = xfeat[:,~neg_spikes,:]
                    amps = amps[~neg_spikes,:]
                    th_amps = th_amps[~neg_spikes,:]
    
                nsp = len(stt) 
                if k+nsp>st.shape[0]:                     
                    st = np.concatenate((st, np.zeros_like(st)), 0)
                    tF  = torch.cat((tF,  torch.zeros_like(tF)), 0)
    
                stt = stt.double()
                st[k:k+nsp,0] = ((stt[:,0]-nt) + ibatch * (ops['batch_size'])).cpu().numpy() - nt//2 + ops['nt0min']
                st[k:k+nsp,1] = stt[:,1].cpu().numpy()
                st[k:k+nsp,2] = th_amps.cpu().numpy().squeeze()
                
                tF[k:k+nsp]  = xfeat.transpose(0,1).cpu()
    
                k+= nsp
                
                if progress_bar is not None:
                    progress_bar.emit(int((ibatch+1) / bfile.n_batches * 100))
        except Exception:
            logger.exception(f'Error in template_matching.extract on batch {ibatch}')
            logger.debug(f'X shape: {X.shape}')
            logger.debug(f'stt shape: {stt.shape}')
            raise
    
        log_performance(logger, 'debug', f'Batch {ibatch}')
    
        isort = np.argsort(st[:k,0])
        st = st[isort]
        tF = tF[isort]
    
        return st, tF, ops
    
    
    def align_U(U, ops, device=torch.device('cuda')):
        Uex = torch.einsum('xyz, zt -> xty', U.to(device), ops['wPCA'])
        X = Uex.reshape(-1, ops['Nchan']).T
        X = conv1d(X.unsqueeze(1), ops['wTEMP'].unsqueeze(1), padding=ops['nt']//2)
        Xmax = X.abs().max(0)[0].max(0)[0].reshape(-1, ops['nt'])
        imax = torch.argmax(Xmax, 1)
    
        Unew = Uex.clone() 
        for j in range(ops['nt']):
            ix = imax==j
            Unew[ix] = torch.roll(Unew[ix], ops['nt']//2 - j, -2)
        Unew = torch.einsum('xty, zt -> xzy', Unew, ops['wPCA'])#.transpose(1,2).cpu()
        return Unew, imax
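
    `align_U` locates each template's temporal peak (via a matched filter against `wTEMP`) and rolls the waveform so that the peak lands at `nt // 2`. A simplified NumPy sketch of that rolling idea for a single waveform, using the raw absolute maximum instead of the matched filter:

    ```python
    import numpy as np

    def center_peak(w, nt):
        """Roll a waveform so its largest-magnitude sample lands at nt // 2
        (simplified: peak found by abs max, not by matched filtering)."""
        imax = int(np.argmax(np.abs(w)))
        return np.roll(w, nt // 2 - imax)

    nt = 7
    w = np.array([0., 0., 0., 0., 0., -3., 1.])   # magnitude peak at index 5
    w_centered = center_peak(w, nt)
    ```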
    
    
    def postprocess_templates(Wall, ops, clu, st, tF, device=torch.device('cuda')):
        Wall2, _ = align_U(Wall, ops, device=device)
        #Wall3, _= remove_duplicates(ops, Wall2)
        Wall3, _, _, _, _ = merging_function(
            ops, Wall2.transpose(1,2), clu, st, tF,
            0.9, 'mu', check_dt=False, device=device
            )
        Wall3 = Wall3.transpose(1,2).to(device)
        return Wall3
    
    
    def prepare_matching(ops, U):
        nt = ops['nt']
        W = ops['wPCA'].contiguous()
        WtW = conv1d(W.reshape(-1, 1,nt), W.reshape(-1, 1 ,nt), padding = nt) 
        WtW = torch.flip(WtW, [2,])
    
        #mu = (U**2).sum(-1).sum(-1)**.5
        #U2 = U / mu.unsqueeze(-1).unsqueeze(-1)
    
        UtU = torch.einsum('ikl, jml -> ijkm',  U, U)
        ctc = torch.einsum('ijkm, kml -> ijl', UtU, WtW)
    
        return ctc
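
    The `WtW` term above holds the cross-correlation of every pair of temporal PC waveforms at every lag (the `conv1d` plus `torch.flip` combination). The same quantity for a single pair of toy waveforms can be sketched with `np.correlate` (this example is illustrative, not Kilosort code):

    ```python
    import numpy as np

    # Full-mode correlation returns one value per lag, -(nt-1)..(nt-1).
    # w1 is w2 delayed by one sample, so the correlation peaks at lag +1.
    w1 = np.array([0., 1., 0.])
    w2 = np.array([1., 0., 0.])
    xcorr = np.correlate(w1, w2, mode='full')
    ```

    In `prepare_matching` this temporal overlap is combined with the spatial overlaps `UtU` via einsum, so that `ctc[i, j, l]` gives the interaction between templates i and j at lag l, which `run_matching` later subtracts when peeling detected spikes out of `B`.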
    
    
    def run_matching(ops, X, U, ctc, device=torch.device('cuda')):
        Th = ops['Th_learned']
        nt = ops['nt']
        max_peels = ops['max_peels']
        W = ops['wPCA'].contiguous()
    
        nm = (U**2).sum(-1).sum(-1)
        #mu = nm**.5 
        #U2 = U / mu.unsqueeze(-1).unsqueeze(-1)
    
        B = conv1d(X.unsqueeze(1), W.unsqueeze(1), padding=nt//2)
        B = torch.einsum('ijk, kjl -> il', U, B)
    
        trange = torch.arange(-nt, nt+1, device=device) 
        tiwave = torch.arange(-(nt//2), nt//2+1, device=device) 
    
        st = torch.zeros((100000,2), dtype = torch.int64, device = device)
        amps = torch.zeros((100000,1), dtype = torch.float, device = device)
        th_amps = torch.zeros((100000,1), dtype = torch.float, device = device)
        k = 0
    
        Xres = X.clone()
        lam = 20
    
        for t in range(max_peels):
            # Cf = 2 * B - nm.unsqueeze(-1) 
            # Cf is shape (n_units, n_times)
            Cf = torch.relu(B)**2 /nm.unsqueeze(-1)
            #a = 1 + lam
            #b = torch.relu(B) + lam * mu.unsqueeze(-1)
            #Cf = b**2 / a - lam * mu.unsqueeze(-1)**2
    
            Cf[:, :nt] = 0
            Cf[:, -nt:] = 0
    
            Cfmax, imax = torch.max(Cf, 0)
            Cmax  = max_pool1d(Cfmax.unsqueeze(0).unsqueeze(0), (2*nt+1), stride=1, padding=(nt))
    
            cnd1 = Cmax[0,0] > Th**2
            cnd2 = torch.abs(Cmax[0,0] - Cfmax) < 1e-9
            xs = torch.nonzero(cnd1 * cnd2)
    
            
            if len(xs)==0:
                break
    
            iX = xs[:,:1]
            iY = imax[iX]
    
            nsp = len(iX)
            st[k:k+nsp, 0] = iX[:,0]
            st[k:k+nsp, 1] = iY[:,0]
            amps[k:k+nsp] = B[iY,iX] / nm[iY]
            amp = amps[k:k+nsp]
            th_amps[k:k+nsp] = Cmax[0, 0, iX[:,0], None]**.5
    
            k+= nsp
    
            n = 2
            for j in range(n):
                Xres[:, iX[j::n] + tiwave]  -= amp[j::n] * torch.einsum('ijk, jl -> kil', U[iY[j::n,0]], W)
                B[   :, iX[j::n] + trange]  -= amp[j::n] * ctc[:,iY[j::n,0],:]
    
        st = st[:k]
        amps = amps[:k]
        th_amps = th_amps[:k]
    
        return  st, amps, th_amps, Xres
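
    The spike-acceptance test in the loop above (`cnd1 * cnd2`) keeps a time point only if its score both clears the threshold and equals the sliding-window maximum computed by `max_pool1d`. A NumPy-only sketch of that peak-picking condition (hypothetical score trace, not Kilosort code):

    ```python
    import numpy as np
    from numpy.lib.stride_tricks import sliding_window_view

    def local_peaks(score, win, thresh):
        """Accept a time point when its score exceeds `thresh` AND equals the
        maximum of the surrounding window (the max_pool1d trick in NumPy)."""
        pad = win // 2
        padded = np.pad(score, pad, constant_values=-np.inf)
        wmax = sliding_window_view(padded, win).max(1)   # sliding max
        return np.nonzero((wmax > thresh) & (np.abs(wmax - score) < 1e-9))[0]

    score = np.array([0., 1., 5., 1., 0., 4., 6., 4., 0.])
    peaks = local_peaks(score, win=3, thresh=2.)
    ```

    This guarantees at most one detection per window, which is why `run_matching` can then subtract each matched template and iterate (`max_peels` rounds of peeling) without double-counting the same spike.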
    
    
    def merging_function(ops, Wall, clu, st, tF, r_thresh=0.5, mode='ccg', check_dt=True,
                         device=torch.device('cuda')):
        clu2 = clu.copy()
        clu_unq, ns = np.unique(clu2, return_counts = True)
    
        Ww = Wall.to(device)
        NN = len(Ww)
    
        isort = np.argsort(ns)[::-1]
    
        is_merged = np.zeros(NN, 'bool')
        is_good = np.zeros(NN,)
    
        acg_threshold = ops['settings']['acg_threshold']
        ccg_threshold = ops['settings']['ccg_threshold']
        if mode == 'ccg':
            is_ref, est_contam_rate = CCG.refract(clu, st[:,0]/ops['fs'],
                                                  acg_threshold=acg_threshold,
                                                  ccg_threshold=ccg_threshold)
    
        nt = ops['nt']
        W = ops['wPCA'].contiguous()
        WtW = conv1d(W.reshape(-1, 1,nt), W.reshape(-1, 1 ,nt), padding = nt) 
        WtW = torch.flip(WtW, [2,])
    
        t = 0
        nmerge = 0
        while t<NN:
    
            kk = clu_unq[isort[t]]
    
            if (mode == 'ccg') and is_ref[kk]==0:
                t += 1
                continue
    
            if is_merged[kk]:            
                t += 1
                continue
    
            mu = (Ww**2).sum((1,2), keepdims=True)**.5
            Wnorm = Ww / (1e-6 + mu)
    
            UtU = torch.einsum('lk, jlm -> jkm',  Wnorm[kk], Wnorm)
            ctc = torch.einsum('jkm, kml -> jl', UtU, WtW)
    
            cmax, imax = ctc.max(1)
            cmax[kk] = 0
    
            jsort = np.argsort(cmax.cpu().numpy())[::-1]
    
            if mode == 'ccg':
                st0 = st[:,0][clu2==kk] / ops['fs']
            
            is_ccg  = 0
            for j in range(NN):
                jj = jsort[j]
                if cmax[jj] < r_thresh:
                    break
                # compare with CCG
                if mode == 'ccg':
                    st1 = st[:,0][clu2==jj] / ops['fs']
                    _, is_ccg, _ = CCG.check_CCG(st0, st1, acg_threshold=acg_threshold,
                                                 ccg_threshold=ccg_threshold)        
                else:
                    dmu = 2 * (mu[kk] - mu[jj]) / (mu[kk] + mu[jj])
                    is_ccg = dmu.abs() < 0.2
    
                if is_ccg:
                    is_merged[jj] = 1
                    dt = (imax[kk] -imax[jj]).item()
                    if dt != 0 and check_dt:
                        # Get spike indices for cluster jj
                        idx = (clu2 == jj)
                        # Update tF and Wall with shifted features
                        tF, Wall = roll_features(W, tF, Ww, idx, jj, dt)
                        # Shift spike times
                        st[idx,0] -= dt
                    
                    Ww[kk] = ns[kk]/(ns[kk]+ns[jj]) * Ww[kk] + ns[jj]/(ns[kk]+ns[jj]) * Ww[jj]            
                    Ww[jj] = 0
                    ns[kk] += ns[jj]
                    ns[jj] = 0
                    clu2[clu2==jj] = kk            
    
                    break
    
            if is_ccg==0:            
                t +=1    
            else:                
                nmerge+=1
        
        imap = np.cumsum((~is_merged).astype('int32')) - 1
        if imap.size > 0:
            # Otherwise, everything has been merged into a single cluster
            clu2 = imap[clu2]
    
        Ww = Ww[~is_merged]
    
        if mode == 'ccg':
            is_ref = is_ref[~is_merged]
        else:
            is_ref = None
    
        sorted_idx = np.argsort(st[:,0])
        st = np.take_along_axis(st, sorted_idx[..., np.newaxis], axis=0)
        clu2 = clu2[sorted_idx]
        tensor_idx = torch.from_numpy(sorted_idx)
        tF = tF[tensor_idx]
    
        return Ww.cpu(), clu2, is_ref, st, tF
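
    The merge candidate test in `merging_function` is driven by `cmax`, the normalized inner product between templates (maximized over lags via `WtW`). A zero-lag-only NumPy sketch of that similarity score, with hypothetical 2x2 templates:

    ```python
    import numpy as np

    def template_similarity(Wk, Wj):
        """Cosine similarity between two flattened templates; pairs scoring
        above r_thresh are merge candidates (zero lag only in this sketch)."""
        a = Wk.ravel() / (1e-6 + np.linalg.norm(Wk))
        b = Wj.ravel() / (1e-6 + np.linalg.norm(Wj))
        return float(a @ b)

    Wk = np.array([[1., 0.], [0., 1.]])
    sim = template_similarity(Wk, 2 * Wk)   # scale-invariant: ~1.0
    ```

    Because both templates are normalized, the score is insensitive to amplitude; in `'mu'` mode the function separately checks that the amplitudes themselves differ by less than 20% before merging.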
    
    
    def roll_features(wPCA, tF, Wall, spike_idx, clust_idx, dt):
        W = wPCA.cpu()
        # Project from PC space back to sample time, shift by dt
        feats = torch.roll(tF[spike_idx] @ W, shifts=dt, dims=2)
        temps = torch.roll(Wall[clust_idx:clust_idx+1] @ wPCA, shifts=dt, dims=2)
    
        # For values that "rolled over the edge," set equal to next closest bin
        if dt > 0:
            feats[:,:,:dt] = feats[:,:,dt].unsqueeze(-1)
            temps[:,:,:dt] = temps[:,:,dt].unsqueeze(-1)
        elif dt < 0:
            feats[:,:,dt:] = feats[:,:,dt-1].unsqueeze(-1)
            temps[:,:,dt:] = temps[:,:,dt-1].unsqueeze(-1)
    
        # Project back to PC space and update tF
        tF[spike_idx] = feats @ W.T
        Wall[clust_idx] = temps @ wPCA.T
    
        return tF, Wall
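
    The edge handling in `roll_features` (overwriting samples that wrapped around with the nearest valid value) can be shown on a plain 1-D array. A NumPy sketch, not part of Kilosort:

    ```python
    import numpy as np

    def roll_fill(x, dt):
        """Roll along time by dt, then overwrite the samples that wrapped
        around the edge with the nearest valid value."""
        y = np.roll(x, dt)
        if dt > 0:
            y[:dt] = y[dt]
        elif dt < 0:
            y[dt:] = y[dt - 1]
        return y

    x = np.array([1., 2., 3., 4., 5.])
    shifted_fwd = roll_fill(x, 2)
    shifted_back = roll_fill(x, -1)
    ```

    Without the fill step, `np.roll` would move trailing samples to the front of the waveform, introducing a spurious discontinuity at the window edge.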
    

    Software Heritage — Copyright (C) 2015–2025, The Software Heritage developers. License: GNU AGPLv3+.