Skip to main content
  • Home
  • Development
  • Documentation
  • Donate
  • Operational login
  • Browse the archive

swh logo
SoftwareHeritage
Software
Heritage
Archive
Features
  • Search

  • Downloads

  • Save code now

  • Add forge now

  • Help

https://doi.org/10.5281/zenodo.3597474
27 August 2025, 14:32:56 UTC
  • Code
  • Branches (0)
  • Releases (7)
  • Visits
    • Branches
    • Releases
      • 7
      • 7
      • 6
      • 5
      • 4
      • 3
      • 2
      • 1
    • b6d380b
    • /
    • MouseLand-Kilosort-8f396c7
    • /
    • kilosort
    • /
    • bench.py
    Raw File Download

    To reference or cite the objects present in the Software Heritage archive, permalinks based on SoftWare Hash IDentifiers (SWHIDs) must be used.
    Select below a type of object currently browsed in order to display its associated SWHID and permalink.

    • content
    • directory
    • snapshot
    • release
    origin badgecontent badge
    swh:1:cnt:3f445d914d8c53c28ee00edf0b7b6c9dcb1228a1
    origin badgedirectory badge
    swh:1:dir:5a94e6584e2d42b7e8e4bdb91a2e2f831f22c766
    origin badgesnapshot badge
    swh:1:snp:cde65436e19c7a8c490b366892685f110e246dcf
    origin badgerelease badge
    swh:1:rel:4aa70d8c6bf4aecf2297faaca70d7dba2d9e7b48

    This interface makes it possible to generate software citations, provided that the root directory of the browsed objects contains a citation.cff or codemeta.json file.
    Select below a type of object currently browsed in order to generate citations for them.

    • content
    • directory
    • snapshot
    • release
    Generate software citation in BibTex format (requires biblatex-software package)
    Generating citation ...
    Generate software citation in BibTex format (requires biblatex-software package)
    Generating citation ...
    Generate software citation in BibTex format (requires biblatex-software package)
    Generating citation ...
    Generate software citation in BibTex format (requires biblatex-software package)
    Generating citation ...
    bench.py
    import os
    
    import torch
    import numpy as np
    from torch.fft import fft, ifft, fftshift
    from scipy.interpolate import interp1d
    from tqdm import trange
    
    from kilosort import preprocessing
    from kilosort.io import BinaryFiltered
    
    
    # Compute device shared by every function in this module: the first CUDA GPU
    # when torch reports one as available, otherwise the CPU.
    dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    
    
    def get_drift_matrix(ops, dshift):
        """Build the linear interpolation matrix for the drift `dshift`.

        `dshift` holds one drift value per entry of `ops['yblk']`.  It is linearly
        interpolated (and extrapolated outside the sampled range) to the `yc`
        position of every channel in `ops['probe']`.  The second column (`yc`)
        of the channel coordinates is then displaced by that amount, and a 2D
        kernel between the displaced and the original coordinates is multiplied
        with the precomputed matrix `ops['iKxx']`.

        Returns the matrix `Kyx @ ops['iKxx']`.
        """
        # Per-block drift -> drift at every channel's y position.
        drift_fn = interp1d(ops['yblk'], dshift, kind='linear',
                            fill_value="extrapolate")
        probe = ops['probe']
        chan_drift = drift_fn(probe['yc'])

        # One (x, y) row per channel; the target coordinates are the same
        # positions with the drift subtracted from y.
        chan_xy = np.vstack((probe['xc'], probe['yc'])).T
        target_xy = chan_xy.copy()
        target_xy[:, 1] -= chan_drift

        chan_xy = torch.from_numpy(chan_xy).to(dev)
        target_xy = torch.from_numpy(target_xy).to(dev)

        # Kernel between target and original coordinates, then combine it with
        # the precomputed kernel matrix of the original channels.
        kernel = preprocessing.kernel2D_torch(
            target_xy, chan_xy, ops['settings']['sig_interp']
        )
        return kernel @ ops['iKxx']
    
    
    def avg_wav(filename, Wsub, nn, ops, ibatch, st_i, clu, Nfilt):
        """Add one batch's spike snippets to the per-cluster running sums.

        Loads batch `ibatch` of `filename` through `BinaryFiltered`, picks the
        spikes of `st_i` that fall strictly inside that batch, projects their
        `nt`-sample snippets with `ops['wPCA']`, and accumulates them per cluster.

        `Wsub` and `nn` are updated in place and also returned.  As a side effect,
        `ops['wPCA']` is replaced by a torch tensor on `dev` if it was a numpy
        array.

        Parameters
        ----------
        filename : path handed to `BinaryFiltered`.
        Wsub : running per-cluster sum of projected snippets (modified in place).
        nn : running per-cluster spike count (modified in place).
        ops : settings dictionary; the keys read here are visible in the body.
        ibatch : index of the batch to load.
        st_i : numpy array of spike times in samples.
        clu : numpy array of cluster ids, aligned with `st_i`.
        Nfilt : number of clusters (rows of the assignment matrix).
        """
        NT = ops['batch_size']
        nt = ops['nt']

        # Convert the PCA basis once and cache the tensor in ops.
        if isinstance(ops['wPCA'], np.ndarray):
            ops['wPCA'] = torch.from_numpy(ops['wPCA']).to(dev)

        reader = BinaryFiltered(
            filename, n_chan_bin=ops['n_chan_bin'], fs=ops['fs'], NT=NT, nt=nt,
            hp_filter=ops['fwav'], whiten_mat=ops['Wrot'],
            dshift=ops['dshift'], chan_map=ops['chanMap'],
        )
        batch = reader.padded_batch_to_torch(ibatch, ops)

        # A spike is inside the batch when it lies strictly between the batch's
        # first sample and the next batch's first sample: the two signed
        # distances then have opposite signs, so their product is negative.
        dist_to_start = st_i - NT * ibatch
        dist_to_end = st_i - NT * (1 + ibatch)
        in_batch = dist_to_start * dist_to_end < 0

        # Batch-relative spike positions, offset by nt.
        # NOTE(review): the +nt presumably accounts for padding added by
        # padded_batch_to_torch -- confirm against that method.
        spike_pos = torch.from_numpy(st_i[in_batch] + nt - NT * ibatch).to(dev)

        # Gather nt consecutive samples per spike and project them with wPCA.
        sample_offsets = torch.arange(nt, device=dev)
        feats = batch[:, spike_pos.unsqueeze(-1) + sample_offsets] @ ops['wPCA'].T
        n_spikes = feats.shape[1]

        # One-hot (cluster, spike) assignment matrix.
        assign = torch.zeros((Nfilt, n_spikes), dtype=torch.float, device=dev)
        assign[clu[in_batch], torch.arange(n_spikes)] = 1

        # Sum the features over the spikes of each cluster; count the spikes.
        Wsub += torch.einsum('ijk, lj -> lki', feats, assign)
        nn += assign.sum(1)

        return Wsub, nn
    
    
    def clu_ypos(filename, ops, st_i, clu):
        """Estimate each cluster's mean waveform and its y position.

        Waveform sums are accumulated with `avg_wav` over every 10th batch and
        divided by the per-cluster spike counts.  Each cluster is then assigned
        the `ops['yc']` value of the channel whose squared waveform, summed over
        axis 1, is largest.

        Returns
        -------
        yclu : `ops['yc']` indexed by the best channel of every cluster.
        Wsub : numpy array with the mean waveform of every cluster, shaped
            (Nfilt, ops['n_pcs'], ops['Nchan']).
        """
        Nfilt = clu.max() + 1
        wav_sum = torch.zeros((Nfilt, ops['n_pcs'], ops['Nchan']), device=dev)
        spike_count = torch.zeros((Nfilt,), device=dev)

        # Only every 10th batch is visited.
        for ibatch in range(0, ops['Nbatches'], 10):
            wav_sum, spike_count = avg_wav(
                filename, wav_sum, spike_count, ops, ibatch, st_i, clu, Nfilt
            )

        # Sum -> mean; the count is broadcast over the two trailing axes.
        Wsub = (wav_sum / spike_count[:, None, None]).cpu().numpy()

        # Channel carrying the most energy for each cluster.
        energy = (Wsub**2).sum(1)
        best_chan = energy.argmax(-1)
        yclu = ops['yc'][best_chan]

        return yclu, Wsub
    
    
    def nmatch(ss0, ss, dt=6):
        """Count the spikes of `ss` that have a spike of `ss0` within `dt` samples.

        A single forward-only pointer walks through `ss0` while `ss` is scanned,
        so both spike trains are expected to be sorted in ascending order.  For
        each spike in `ss` the pointer is advanced to the last `ss0` spike that is
        not later than `spike + dt`, and only that spike is tested.  Several
        spikes of `ss` may match the same spike of `ss0`, so the returned count
        can exceed `len(ss0)`.

        Parameters
        ----------
        ss0 : array of reference spike times.
        ss : array of spike times to be matched against `ss0`.
        dt : maximum absolute time difference (same units as the spike times)
            for two spikes to count as a match.

        Returns
        -------
        n0 : int, number of spikes in `ss` that were matched.
        is_matched : bool array of length `len(ss)`, True for matched spikes.
        is_matched0 : bool array of length `len(ss0)`, True for the reference
            spikes that were hit by at least one match.
        """
        n_ref = len(ss0)
        is_matched = np.zeros(len(ss), 'bool')
        is_matched0 = np.zeros(n_ref, 'bool')

        # With no reference spikes nothing can match; returning here avoids
        # indexing into an empty array below.
        if n_ref == 0:
            return 0, is_matched, is_matched0

        n0 = 0
        j = 0
        for i, spike in enumerate(ss):
            # Advance to the last reference spike at or before spike + dt.
            while j + 1 < n_ref and ss0[j + 1] <= spike + dt:
                j += 1

            if np.abs(ss0[j] - spike) <= dt:
                n0 += 1
                is_matched[i] = True
                is_matched0[j] = True

        return n0, is_matched, is_matched0
    
    
    def match_neuron(kk, clu, yclu, st_i, clu0, yclu0, st0_i, n_check=20, dt=6):
        """Find the cluster of a second sorting that best matches cluster `kk`.

        The clusters of the second sorting (`clu0`, `yclu0`, `st0_i`) are ranked
        by the absolute difference between their y position and `yclu[kk]`, and
        the closest `n_check` of them are compared spike-by-spike with cluster
        `kk` using `nmatch`.  A candidate's score is
        `1 - miss - fpos`, where `miss` is the fraction of `kk`'s spikes left
        unmatched and `fpos` is the fraction of the candidate's spikes in excess
        of the matches (both clipped at zero).

        Returns
        -------
        fmax : best score found (0 if no candidate scored above 0).
        miss, fpos : the two fractions for the best candidate.
        best_ind : id of the best candidate; the nearest candidate in y if no
            candidate scored above 0.
        matched_all : length-`n_check` array of match counts per candidate,
            -1 where no comparison was made.
        top_inds : length-`n_check` int array of candidate ids, -1 padded.
        """
        ss = st_i[clu == kk]

        # Candidate ids, nearest in y first.
        isort = np.argsort(np.abs(yclu[kk] - yclu0))
        ntest = min(len(isort), n_check)

        matched_all = np.full((n_check,), -1.0)
        top_inds = np.full((n_check,), -1, dtype="int")
        top_inds[:ntest] = isort[:ntest]

        fmax, miss, fpos = 0, 0, 0
        best_ind = isort[0]

        for j, cand in enumerate(isort[:ntest]):
            ss0 = st0_i[clu0 == cand]

            # A candidate without spikes cannot be scored.
            if len(ss0) == 0:
                continue

            n0, _, _ = nmatch(ss0, ss, dt=dt)

            cand_miss = np.maximum(0, 1 - n0/len(ss))
            cand_fpos = np.maximum(0, 1 - n0/len(ss0))
            score = 1 - cand_miss - cand_fpos
            matched_all[j] = n0

            if score > fmax:
                fmax, miss, fpos, best_ind = score, cand_miss, cand_fpos, cand

        return fmax, miss, fpos, best_ind, matched_all, top_inds
    
    
    def compare_recordings(st_gt, clu_gt, yclu_gt, st_new, clu_new, yclu_new,
                           n_check=20, dt=6):
        """Match every ground-truth unit against the units of a new sorting.

        `match_neuron` is called once per ground-truth unit (one per entry of
        `yclu_gt`) and its outputs are stacked.

        Parameters
        ----------
        st_gt, clu_gt, yclu_gt : spike times, cluster ids and per-unit y positions
            of the ground truth.
        st_new, clu_new, yclu_new : the same three arrays for the sorting under
            test.
        n_check : number of nearest candidate units compared per ground-truth
            unit.  Defaults to 20, the value that used to be hard-coded here.
        dt : matching tolerance forwarded to `match_neuron`.  Defaults to 6,
            which is `match_neuron`'s own default.

        Returns
        -------
        fmax, fmiss, fpos : float arrays of length NN with the best score and
            the corresponding miss / false-positive fractions.
        best_ind : int array of length NN with the best-matching unit ids.
        matched_all : (NN, n_check) array of match counts per candidate.
        top_inds : (NN, n_check) int array of the candidate ids compared.
        """
        NN = len(yclu_gt)

        fmax = np.zeros(NN,)
        fmiss = np.zeros(NN,)
        fpos = np.zeros(NN,)
        best_ind = np.zeros(NN, "int")
        matched_all = np.zeros((NN, n_check))
        top_inds = np.zeros((NN, n_check), "int")

        for kk in trange(NN):
            (fmax[kk], fmiss[kk], fpos[kk], best_ind[kk],
             matched_all[kk], top_inds[kk]) = match_neuron(
                kk, clu_gt, yclu_gt, st_gt, clu_new, yclu_new, st_new,
                n_check=n_check, dt=dt,
            )

        return fmax, fmiss, fpos, best_ind, matched_all, top_inds
    
    
    def load_GT(filename, ops, gt_path, toff=20, nmax=600, tmax=None):
        """Load ground-truth spikes and compute per-unit waveforms and positions.

        Parameters
        ----------
        filename : binary data file handed to `clu_ypos`.
        ops : settings dictionary; `ops['fs']` is read here when `tmax` is given,
            and the whole dictionary is forwarded to `clu_ypos`.
        gt_path : path of an archive readable by `np.load` that contains the
            arrays `'st'` (spike times) and `'cl'` (cluster ids).
        toff : number of samples subtracted from the spike times before the
            waveforms are extracted.
        nmax : only units with an id below `nmax` are kept.
        tmax : if not None, only spikes earlier than `tmax * ops['fs']` samples
            are kept.

        Returns
        -------
        st_gt, clu_gt : int64 spike times and cluster ids after filtering.
        yclu_gt : per-unit y position from `clu_ypos`.
        mu_gt : per-unit norm of the mean waveform.
        Wsub : per-unit mean waveform from `clu_ypos`.
        nsp : spike count of every unit id present in `clu_gt`.
        """
        # NOTE(security): allow_pickle=True lets np.load unpickle arbitrary
        # objects -- only use this on trusted ground-truth files.
        # The archive is opened in a `with` block so its file handle is closed
        # as soon as the two arrays have been copied out.
        with np.load(gt_path, allow_pickle=True) as dd:
            st_gt = dd['st'].astype('int64')
            clu_gt = dd['cl'].astype('int64')

        if tmax is not None:
            imax = int(tmax * ops['fs'])
            keep = st_gt < imax
            st_gt = st_gt[keep]
            clu_gt = clu_gt[keep]

        keep = clu_gt < nmax
        st_gt = st_gt[keep]
        clu_gt = clu_gt[keep]

        yclu_gt, Wsub = clu_ypos(filename, ops, st_gt - toff, clu_gt)
        # Norm of each unit's mean waveform over its two trailing axes.
        mu_gt = (Wsub**2).sum((1, 2))**.5

        unq_clu, nsp = np.unique(clu_gt, return_counts=True)
        # Consecutive unit ids differ by exactly 1; anything else means a gap.
        if np.any(np.diff(unq_clu) != 1):
            print('error, some ground truth units are missing')

        return st_gt, clu_gt, yclu_gt, mu_gt, Wsub, nsp
    
    
    def convert_ks_output(ops, st, clu, toff = 20, filename=None):
        """Convert sorter output into spike times, cluster ids and unit positions.

        Parameters
        ----------
        ops : settings dictionary forwarded to `clu_ypos`.
        st : 2D array whose first column holds the spike times in samples.
        clu : array of cluster ids, aligned with the rows of `st`.
        toff : number of samples subtracted from the spike times before the
            waveforms are extracted.
        filename : binary data file handed to `clu_ypos`.  `clu_ypos` requires
            it as its first argument; the previous call omitted it and therefore
            always failed.  When left as None, `ops['filename']` is used if that
            key exists.
            NOTE(review): whether `ops` carries a 'filename' entry is an
            assumption -- confirm against the code that builds `ops`.

        Returns
        -------
        st : int64 spike times (first column of the input).
        clu : the cluster ids, unchanged.
        yclu, Wsub : per-unit y position and mean waveform from `clu_ypos`.

        Raises
        ------
        ValueError : if no filename is given and `ops` has no 'filename' entry.
        """
        if filename is None:
            filename = ops.get('filename')
        if filename is None:
            raise ValueError(
                "convert_ks_output needs the binary data file: pass `filename` "
                "or provide ops['filename']"
            )

        st = st[:,0].astype('int64')
        yclu, Wsub = clu_ypos(filename, ops, st - toff, clu)

        return st, clu, yclu, Wsub
    
    
    def load_phy(filename, fpath, ops, toff=20):
        """Load spike times and cluster ids from a folder of .npy files.

        Spike times come from "spike_times.npy".  Cluster ids are read from
        "cluster_times.npy" when that file exists and from "spike_clusters.npy"
        otherwise.  2D arrays are reduced to their first column.

        Parameters
        ----------
        filename : binary data file handed to `clu_ypos`.
        fpath : folder containing the .npy files.
        ops : settings dictionary forwarded to `clu_ypos`.
        toff : number of samples subtracted from the spike times before the
            waveforms are extracted.  Defaults to 20, the value that used to be
            hard-coded here.

        Returns
        -------
        st_new, clu_new : int64 spike times and cluster ids.
        yclu_new, Wsub : per-unit y position and mean waveform from `clu_ypos`.
        """
        st_new = np.load(os.path.join(fpath, "spike_times.npy")).astype('int64')

        # Fall back to the second file name only when the first one is missing;
        # any other failure (e.g. an unreadable file) is allowed to propagate
        # instead of being silently swallowed.
        try:
            clu_new = np.load(os.path.join(fpath, "cluster_times.npy")).astype('int64')
        except FileNotFoundError:
            clu_new = np.load(os.path.join(fpath, "spike_clusters.npy")).astype('int64')

        if st_new.ndim == 2:
            st_new = st_new[:, 0]
        if clu_new.ndim == 2:
            clu_new = clu_new[:, 0]

        yclu_new, Wsub = clu_ypos(filename, ops, st_new - toff, clu_new)

        return st_new, clu_new, yclu_new, Wsub
    

    back to top

    Software Heritage — Copyright (C) 2015–2025, The Software Heritage developers. License: GNU AGPLv3+.
    The source code of Software Heritage itself is available on our development forge.
    The source code files archived by Software Heritage are available under their own copyright and licenses.
    Terms of use: Archive access, API— Content policy— Contact— JavaScript license information— Web API