Skip to main content
  • Home
  • Development
  • Documentation
  • Donate
  • Operational login
  • Browse the archive

swh logo
SoftwareHeritage
Software
Heritage
Archive
Features
  • Search

  • Downloads

  • Save code now

  • Add forge now

  • Help

https://doi.org/10.5281/zenodo.3597474
27 August 2025, 14:32:56 UTC
  • Code
  • Branches (0)
  • Releases (7)
  • Visits
    • Branches
    • Releases
      • 7
      • 7
      • 6
      • 5
      • 4
      • 3
      • 2
      • 1
    • b6d380b
    • /
    • MouseLand-Kilosort-8f396c7
    • /
    • kilosort
    • /
    • spikedetect.py
    Raw File Download

    To reference or cite the objects present in the Software Heritage archive, permalinks based on SoftWare Hash IDentifiers (SWHIDs) must be used.
    Select below a type of object currently browsed in order to display its associated SWHID and permalink.

    • content
    • directory
    • snapshot
    • release
    origin badgecontent badge
    swh:1:cnt:3efc1afeb4e694439bcd0c689d492abd42cacf5f
    origin badgedirectory badge
    swh:1:dir:5a94e6584e2d42b7e8e4bdb91a2e2f831f22c766
    origin badgesnapshot badge
    swh:1:snp:cde65436e19c7a8c490b366892685f110e246dcf
    origin badgerelease badge
    swh:1:rel:4aa70d8c6bf4aecf2297faaca70d7dba2d9e7b48

    This interface enables you to generate software citations, provided that the root directory of browsed objects contains a citation.cff or codemeta.json file.
    Select below a type of object currently browsed in order to generate citations for them.

    • content
    • directory
    • snapshot
    • release
    Generate software citation in BibTeX format (requires biblatex-software package)
    Generating citation ...
    Generate software citation in BibTeX format (requires biblatex-software package)
    Generating citation ...
    Generate software citation in BibTeX format (requires biblatex-software package)
    Generating citation ...
    Generate software citation in BibTeX format (requires biblatex-software package)
    Generating citation ...
    spikedetect.py
    import os
    import gc
    import logging
    import warnings
    logger = logging.getLogger(__name__)
    
    from torch.nn.functional import max_pool2d, avg_pool2d, conv1d, max_pool1d
    import numpy as np
    import torch
    from sklearn.cluster import KMeans
    from sklearn.decomposition import TruncatedSVD
    from tqdm import tqdm
    
    from kilosort.utils import template_path, log_performance
    
    
    def my_max2d(X, dt):
        """Sliding-window maximum over a 2D tensor.

        Applies a (2*dt[0]+1) x (2*dt[1]+1) max filter with stride 1 and
        symmetric zero-offset padding, so the output has the same shape as X.
        """
        kernel = [2 * dt[0] + 1, 2 * dt[1] + 1]
        pad = [dt[0], dt[1]]
        # max_pool2d expects a leading batch dim; add it, pool, then drop it.
        pooled = max_pool2d(X.unsqueeze(0), kernel, stride=[1, 1], padding=pad)
        return pooled.squeeze(0)
    
    def my_sum2d(X, dt):
        """Sliding-window sum over a 2D tensor.

        Computes the sum over a (2*dt[0]+1) x (2*dt[1]+1) window centered on
        each element (zero-padded at the borders), output shape matches X.
        Implemented as an average pool rescaled by the window element count.
        """
        kernel = [2 * dt[0] + 1, 2 * dt[1] + 1]
        pad = [dt[0], dt[1]]
        pooled = avg_pool2d(X.unsqueeze(0), kernel, stride=[1, 1], padding=pad)
        # avg_pool2d divides by the full kernel size (pads count as zeros),
        # so multiplying back by the element count recovers the windowed sum.
        n_elem = kernel[0] * kernel[1]
        pooled *= n_elem
        return pooled.squeeze(0)
    
    def extract_snippets(X, nt, twav_min, Th_single_ch, loc_range=[4,5],
                         long_range=[6,30], device=torch.device('cuda')):
        """Cut out single-channel waveform snippets around isolated peaks.

        A sample is a peak when its absolute value is the maximum of its
        local (loc_range) neighborhood and exceeds Th_single_ch; it is kept
        only if it is the sole peak within the larger long_range window.
        Returns a (n_peaks, nt) tensor of snippets starting twav_min samples
        before each peak.
        """
        amp = X.abs()
        # Local maxima above the single-channel threshold.
        local_max = max_pool2d(
            amp.unsqueeze(0), [2*loc_range[0]+1, 2*loc_range[1]+1],
            stride=[1, 1], padding=[loc_range[0], loc_range[1]])[0]
        peaks = torch.logical_and(local_max == amp, amp > Th_single_ch).float()

        # Count peaks in the wider neighborhood (avg pool rescaled to a sum);
        # a count of exactly 1 means the peak is isolated.
        n_window = (2*long_range[0]+1) * (2*long_range[1]+1)
        nearby = avg_pool2d(
            peaks.unsqueeze(0), [2*long_range[0]+1, 2*long_range[1]+1],
            stride=[1, 1], padding=[long_range[0], long_range[1]])[0]
        nearby = nearby * n_window
        isolated = (nearby == 1) * (peaks == 1)

        # Drop peaks too close to the batch edges to cut a full nt-long snippet.
        isolated[:, :nt] = 0
        isolated[:, -nt:] = 0

        peak_locs = isolated.nonzero()
        t_window = peak_locs[:, 1:2] - twav_min + torch.arange(nt, device=device)
        return X[peak_locs[:, :1], t_window]
    
    def extract_wPCA_wTEMP(ops, bfile, nt=61, twav_min=20, Th_single_ch=6, nskip=25,
                           device=torch.device('cuda')):
        """Learn PC waveforms (wPCA) and universal templates (wTEMP) from data.

        Collects up to 500,000 isolated single-channel snippets from every
        nskip-th batch, normalizes them to unit energy, then fits a truncated
        SVD (for the PC basis) and KMeans (for the template waveforms).
        Returns (wPCA, wTEMP) as float tensors on ``device``.
        """
        max_clips = 500000
        clips = np.zeros((max_clips, nt), 'float32')
        n_found = 0
        for j in range(0, bfile.n_batches, nskip):
            X = bfile.padded_batch_to_torch(j, ops)
            new_clips = extract_snippets(X, nt=nt, twav_min=twav_min,
                                         Th_single_ch=Th_single_ch, device=device)
            count = len(new_clips)
            # Stop once the buffer would overflow (the last batch is dropped).
            if n_found + count > clips.shape[0]:
                break
            clips[n_found:n_found + count] = new_clips.cpu().numpy()
            n_found += count

        clips = clips[:n_found]
        # Normalize each snippet to unit energy.
        clips /= (clips**2).sum(1, keepdims=True)**.5

        model = TruncatedSVD(n_components=ops['settings']['n_pcs']).fit(clips)
        wPCA = torch.from_numpy(model.components_).to(device).float()

        with warnings.catch_warnings():
            msg = 'KMeans is known to have a memory leak on Windows with MKL'
            warnings.filterwarnings("ignore", message=msg)
            # Prevents memory leak for KMeans when using MKL on Windows;
            # remember the previous setting so it can be restored afterwards.
            prev_threads = os.environ.get('OMP_NUM_THREADS')
            os.environ['OMP_NUM_THREADS'] = '7'
            model = KMeans(n_clusters=ops['settings']['n_templates'], n_init=10).fit(clips)
            wTEMP = torch.from_numpy(model.cluster_centers_).to(device).float()
            # Normalize each template to unit energy.
            wTEMP = wTEMP / (wTEMP**2).sum(1).unsqueeze(1)**.5
            if prev_threads is None:
                os.environ.pop('OMP_NUM_THREADS')
            else:
                os.environ['OMP_NUM_THREADS'] = prev_threads

        return wPCA, wTEMP
    
    def get_waves(ops, device=torch.device('cuda')):
        """Load the pre-computed PC waveforms and universal templates.

        Reads the 'wPCA' and 'wTEMP' arrays from the bundled template file
        and returns them as tensors on ``device``. ``ops`` is accepted for
        signature parity with `extract_wPCA_wTEMP` but is not used here.
        """
        data = np.load(template_path())
        wPCA = torch.from_numpy(data['wPCA']).to(device)
        wTEMP = torch.from_numpy(data['wTEMP']).to(device)
        return wPCA, wTEMP
    
    def template_centers(ops):
        """Compute candidate template center positions for each shank.

        Reads contact positions ('xc', 'yc', 'kcoords') and spacing settings
        ('dmin', 'dminx') from ``ops``. If 'dmin' is None, it is estimated as
        the median spacing between distinct contact rows (or 1 if all
        contacts share a single y position). Template centers are placed on a
        grid with half-spacing steps, per shank, and stored (deduplicated and
        sorted) in ops['yup'] / ops['xup']. Also writes ops['dmin'] and
        ops['dminx']. Returns the updated ``ops``.
        """
        shank_idx = ops['kcoords']
        xc = ops['xc']
        yc = ops['yc']
        dmin = ops['settings']['dmin']
        if dmin is None:
            # Try to determine a good value automatically based on contact positions.
            y_uniq = np.unique(yc)
            if y_uniq.size == 1:
                # All contacts on one row: no vertical spacing to measure.
                dmin = 1
            else:
                # y_uniq is already unique and sorted, so take the median gap
                # directly (the original wrapped it in a redundant np.unique).
                dmin = np.median(np.diff(y_uniq))
        ops['dmin'] = dmin
        ops['dminx'] = dminx = ops['settings']['dminx']

        # Iteratively determine template placement for each shank separately.
        yup = np.array([])
        xup = np.array([])
        for i in np.unique(shank_idx):
            xc_i = xc[shank_idx == i]
            yc_i = yc[shank_idx == i]
            xmin, xmax, ymin, ymax = xc_i.min(), xc_i.max(), yc_i.min(), yc_i.max()

            # Vertical grid at dmin/2 steps (epsilon so ymax is included);
            # horizontal grid with roughly dminx/2 spacing across the shank.
            yup = np.concatenate([yup, np.arange(ymin, ymax+.00001, dmin/2)])
            nx = np.round((xmax - xmin) / (dminx/2)) + 1
            xup = np.concatenate([xup, np.linspace(xmin, xmax, int(nx))])

        ops['yup'] = np.unique(yup)
        ops['xup'] = np.unique(xup)

        return ops
    
    
    def template_match(X, ops, iC, iC2, weigh, device=torch.device('cuda')):
        """Detect spikes in one batch by matching universal templates.

        Parameters
        ----------
        X : torch.Tensor
            Preprocessed data batch; the last dimension is time and the first
            is treated as channels (each channel is convolved independently).
        ops : dict
            Reads 'nt', 'wTEMP', 'Th_universal' and settings 'nt0min',
            'n_templates'.
        iC : torch.Tensor
            Nearest-channel indices per template position (n_chans_kept x Nfilt).
        iC2 : torch.Tensor
            Nearest template-position indices per template position, used for
            spatial non-max suppression.
        weigh : torch.Tensor
            Spatial channel weights per template size, combined with the
            temporal responses via einsum below.

        Returns
        -------
        xy : (n_spikes, 2) tensor of (template index, time sample) per spike.
        imax : per-spike index of the best (size, waveform) combination.
        amp : per-spike matched amplitude.
        adist : sign-corrected responses on each spike's nearest channels,
            used downstream for position estimation.
        """
        nt = ops['nt']
        nt0 = ops['settings']['nt0min']
        nk = ops['settings']['n_templates']
        NT = X.shape[-1]
        Nfilt = iC.shape[1]
        # Process the batch in niter time chunks of nb samples each
        # (nb = ceil(NT / niter)) to bound peak memory of the einsum below.
        niter = 40
        nb = (NT-1)//niter+1

        # Temporal convolution of every channel with each of the nk template
        # waveforms; B has shape (n_channels, nk, NT).
        W = ops['wTEMP'].unsqueeze(1)
        B = conv1d(X.unsqueeze(1), W, padding=nt//2)
        As    = torch.zeros((Nfilt, NT), device=device)
        Amaxs = torch.zeros((Nfilt, NT), device=device)
        imaxs = torch.zeros((Nfilt, NT), dtype = torch.int64, device=device)
        ti = torch.arange(Nfilt, device = device)
        tj = torch.arange(nb, device = device)

        for t in range(niter):
            # Combine each template's nearest-channel responses with the
            # spatial weights, then flatten (size, waveform) into one axis.
            A = torch.einsum('ijk, jklm-> iklm', weigh, B[iC,:, nb*t:nb*(t+1)])        
            A = A.transpose(1,2)
            A = A.reshape(-1, Nfilt, A.shape[-1])
            
            #Aa, imax = torch.max(A, 0) 
            # Best absolute response over all (size, waveform) combinations;
            # encode the sign of the winning response into imax by storing
            # sign * (imax + 1) (the +1 keeps index 0 distinguishable).
            Aa, imax = torch.max(A.abs(), 0)
            imax = (1+imax) * A[imax, ti.unsqueeze(-1), tj[:A.shape[-1]]].sign()

            As[:, nb*t:nb*(t+1)] = Aa
            imaxs[:, nb*t:nb*(t+1)] = imax
            # Spatial max over neighboring template positions for NMS.
            Amax = torch.max(Aa[iC2], 0)[0]
            Amaxs[:, nb*t:nb*(t+1)] = Amax

        # Suppress detections within nt samples of the batch edges, then take
        # a temporal max over a +/- nt0 window for temporal NMS.
        Amaxs[:,:nt] = 0
        Amaxs[:,-nt:] = 0
        Amaxs  = max_pool1d(Amaxs.unsqueeze(0), (2*nt0+1), stride = 1, padding = nt0).squeeze(0)
        # Keep (template, time) points that are local maxima in both space
        # and time and exceed the universal threshold.
        xy = torch.logical_and(Amaxs==As, As > ops['Th_universal']).nonzero()
        imax = imaxs[xy[:,0], xy[:,1]]
        amp = As[xy[:,0], xy[:,1]]

        # Decode the sign and the flat (size, waveform) index; imax % nk
        # recovers the waveform index into B.
        ssign = imax.sign()
        imax = imax.abs()-1
        adist = B[iC[:, xy[:,0]], imax%nk, xy[:,1]] * ssign

        #adist = B[iC[:, xy[:,0]], imax%nk, xy[:,1]] 
        
        #xy[:,1] -= nt
        return xy, imax, amp, adist
    
    
    def nearest_chans(ys, yc, xs, xc, nC, device=torch.device('cuda')):
        """Find the nC closest positions in (xc, yc) for each point in (xs, ys).

        Returns a tuple (iC, ds): iC is an (nC, n_points) tensor of indices
        into (xc, yc) sorted by increasing distance, and ds is the matching
        (nC, n_points) array of squared distances.
        """
        # Squared Euclidean distance matrix: rows index (xc, yc), columns (xs, ys).
        sq_dist = (ys - yc[:, np.newaxis])**2 + (xs - xc[:, np.newaxis])**2
        order = np.argsort(sq_dist, axis=0)
        iC = torch.from_numpy(order[:nC]).to(device)
        ds = np.take_along_axis(sq_dist, order, axis=0)[:nC]
        return iC, ds
    
    
    def yweighted(yc, iC, adist, xy, device=torch.device('cuda')):    
        """Estimate a vertical (y) position for each detected spike.

        Each spike's position is the weighted mean of its nearest channels'
        y coordinates, weighted by the rectified response amplitudes in
        ``adist`` (normalized per spike). ``xy[:, 0]`` selects the spike's
        template column in ``iC``.
        """
        # y coordinate of every nearest channel, per template position.
        chan_y = torch.from_numpy(yc).to(device)[iC]
        # Rectify responses and normalize so each spike's weights sum to 1.
        weights = torch.nn.functional.relu(adist)
        weights = weights / weights.sum(0)
        return (weights * chan_y[:, xy[:, 0]]).sum(0)
    
    def run(ops, bfile, device=torch.device('cuda'), progress_bar=None,
            clear_cache=False, verbose=False):        
        """Detect spikes across all batches using universal templates.

        Obtains (or recomputes) universal templates, lays out candidate
        template positions over the probe, then template-matches every batch
        and accumulates spike info and PC features.

        Parameters
        ----------
        ops : dict
            Settings/state dictionary; read extensively and updated in place
            (adds 'wPCA', 'wTEMP', 'yup', 'xup', 'ycup', 'xcup', 'iC', 'iC2',
            'weigh', among others).
        bfile : object
            Batched data source providing ``n_batches`` and
            ``padded_batch_to_torch``.
        progress_bar : object or None
            If given, its ``emit`` method receives percent-complete updates.
        clear_cache : bool
            If True, run gc and empty the CUDA cache after each batch.
        verbose : bool
            Currently unused in this function.

        Returns
        -------
        st : (n_spikes, 6) float64 array with columns: time in seconds,
            y position, amplitude, best waveform index, batch index,
            template-position index.
        tF : (n_spikes, nC, n_pcs) float32 array of PC features per spike.
        ops : the updated ops dict.
        """
        sig = ops['settings']['min_template_size']
        nsizes = ops['settings']['template_sizes']
        nb = ops['Nbatches']

        if ops['settings']['templates_from_data']:
            logger.info('Re-computing universal templates from data.')
            # Determine templates and PC features from data.
            ops['wPCA'], ops['wTEMP'] = extract_wPCA_wTEMP(
                ops, bfile, nt=ops['nt'], twav_min=ops['nt0min'], 
                Th_single_ch=ops['settings']['Th_single_ch'], nskip=25,
                device=device
                )
        else:
            logger.info('Using built-in universal templates.')
            # Use pre-computed templates.
            ops['wPCA'], ops['wTEMP'] = get_waves(ops, device=device)

        # Grid of candidate template positions over the probe.
        ops = template_centers(ops)
        [ys, xs] = np.meshgrid(ops['yup'], ops['xup'])
        ys, xs = ys.flatten(), xs.flatten()
        logger.info(f'Number of universal templates: {ys.size}')
        xc, yc = ops['xc'], ops['yc']

        nC = ops['settings']['nearest_chans']
        nC2 = ops['settings']['nearest_templates']
        iC, ds = nearest_chans(ys, yc, xs, xc, nC, device=device)

        # Don't use templates that are too far away from nearest channel
        # (use square of max distance since ds are squared distances)
        igood = ds[0,:] <= ops['max_channel_distance']**2
        iC = iC[:,igood]
        ds = ds[:,igood]
        ys = ys[igood]
        xs = xs[igood]
        ops['ycup'], ops['xcup'] = ys, xs

        # Nearest template positions to each template position (for spatial
        # non-max suppression inside template_match).
        iC2, _ = nearest_chans(ys, ys, xs, xs, nC2, device=device)

        # Spatial weights: exponential falloff of squared distance for each
        # of nsizes template sizes, normalized over channels; permuted to
        # (nsizes, nC, Nfilt) as expected by template_match's einsum.
        ds_torch = torch.from_numpy(ds).to(device).float()
        template_sizes = sig * (1+torch.arange(nsizes, device=device))
        weigh = torch.exp(-ds_torch.unsqueeze(-1) / template_sizes**2)
        weigh = torch.permute(weigh, (2, 0, 1)).contiguous()
        weigh = weigh / (weigh**2).sum(1).unsqueeze(1)**.5

        # Output buffers, grown by doubling when full (trimmed at the end).
        st = np.zeros((10**6, 6), 'float64')
        tF = np.zeros((10**6, nC , ops['settings']['n_pcs']), 'float32')

        k = 0
        nt = ops['nt']
        # Sample offsets for cutting a waveform window centered on each spike.
        tarange = torch.arange(-(nt//2),nt//2+1, device = device)
        logger.info('Detecting spikes...')
        prog = tqdm(np.arange(bfile.n_batches), miniters=200 if progress_bar else None, 
                    mininterval=60 if progress_bar else None)
        # repeat performance log after every 10 minutes of data
        log_skip = int(600 / (ops['batch_size'] / ops['fs']))
        try:
            for ibatch in prog:
                if ibatch % log_skip == 0:
                    log_performance(logger, 'debug', f'Batch {ibatch} of {nb-1} ({100*(ibatch/nb):.1f}%)')

                X = bfile.padded_batch_to_torch(ibatch, ops)
                xy, imax, amp, adist = template_match(X, ops, iC, iC2, weigh, device=device)
                # Amplitude-weighted y position per spike.
                yct = yweighted(yc, iC, adist, xy, device=device)
                nsp = len(xy)

                # Double the buffers if this batch would overflow them.
                if k+nsp>st.shape[0]:
                    st = np.concatenate((st, np.zeros_like(st)), 0)
                    tF = np.concatenate((tF, np.zeros_like(tF)), 0)

                # Project each spike's nearest-channel waveforms onto the PC basis.
                xsub = X[iC[:,xy[:,:1]], xy[:,1:2] + tarange]
                xfeat = xsub @ ops['wPCA'].T
                tF[k:k+nsp] = xfeat.transpose(0,1).cpu().numpy()

                # st columns: 0 = spike time in seconds (the -nt corrects for
                # the batch padding), 1 = y position, 2 = amplitude,
                # 3 = best waveform index, 4 = batch index, 5 = template index.
                st[k:k+nsp,0] = ((xy[:,1].cpu().numpy()-nt)/ops['fs'] + ibatch * (ops['batch_size']/ops['fs']))
                st[k:k+nsp,1] = yct.cpu().numpy()
                st[k:k+nsp,2] = amp.cpu().numpy()
                st[k:k+nsp,3] = imax.cpu().numpy()
                st[k:k+nsp,4] = ibatch
                st[k:k+nsp,5] = xy[:,0].cpu().numpy()

                k = k + nsp
                if clear_cache:
                    gc.collect()
                    torch.cuda.empty_cache()

                if progress_bar is not None:
                    progress_bar.emit(int((ibatch+1) / bfile.n_batches * 100))
        except:
            # Bare except is deliberate here: log diagnostics for any failure
            # (including non-Exception interrupts), then always re-raise.
            logger.exception(f'Error in spikedetect.run on batch {ibatch}')
            try:
                logger.debug(f'X shape: {X.shape}')
                logger.debug(f'xy shape: {xy.shape}')
            except UnboundLocalError:
                # Error happened before one or both of these was assigned,
                # no need to raise an additional error for this.
                pass
            raise
                
        log_performance(logger, 'debug', f'Batch {ibatch} of {nb-1} ({100*(ibatch/nb):.1f}%)')

        # Trim buffers to the spikes actually found and stash channel/template
        # geometry for downstream steps.
        st = st[:k]
        tF = tF[:k]
        ops['iC'] = iC
        ops['iC2'] = iC2
        ops['weigh'] = weigh
        return st, tF, ops
    

    back to top

    Software Heritage — Copyright (C) 2015–2025, The Software Heritage developers. License: GNU AGPLv3+.
    The source code of Software Heritage itself is available on our development forge.
    The source code files archived by Software Heritage are available under their own copyright and licenses.
    Terms of use: Archive access, API— Content policy— Contact— JavaScript license information— Web API