Skip to main content
  • Home
  • Development
  • Documentation
  • Donate
  • Operational login
  • Browse the archive

swh logo
SoftwareHeritage
Software
Heritage
Archive
Features
  • Search

  • Downloads

  • Save code now

  • Add forge now

  • Help

https://doi.org/10.5281/zenodo.3597474
27 August 2025, 14:32:56 UTC
  • Code
  • Branches (0)
  • Releases (7)
  • Visits
    • Branches
    • Releases
      • 7
      • 7
      • 6
      • 5
      • 4
      • 3
      • 2
      • 1
    • b6d380b
    • /
    • MouseLand-Kilosort-8f396c7
    • /
    • kilosort
    • /
    • clustering_qr.py
    Raw File Download

    To reference or cite the objects present in the Software Heritage archive, permalinks based on SoftWare Hash IDentifiers (SWHIDs) must be used.
    Select below a type of object currently browsed in order to display its associated SWHID and permalink.

    • content
    • directory
    • snapshot
    • release
    origin badgecontent badge
    swh:1:cnt:75edba6ad82ee078adad66b9021ccdc4a03cf182
    origin badgedirectory badge
    swh:1:dir:5a94e6584e2d42b7e8e4bdb91a2e2f831f22c766
    origin badgesnapshot badge
    swh:1:snp:cde65436e19c7a8c490b366892685f110e246dcf
    origin badgerelease badge
    swh:1:rel:4aa70d8c6bf4aecf2297faaca70d7dba2d9e7b48

    This interface enables generating software citations, provided that the root directory of browsed objects contains a citation.cff or codemeta.json file.
    Select below a type of object currently browsed in order to generate citations for them.

    • content
    • directory
    • snapshot
    • release
    Generate software citation in BibTex format (requires biblatex-software package)
    Generating citation ...
    Generate software citation in BibTex format (requires biblatex-software package)
    Generating citation ...
    Generate software citation in BibTex format (requires biblatex-software package)
    Generating citation ...
    Generate software citation in BibTex format (requires biblatex-software package)
    Generating citation ...
    clustering_qr.py
    import gc
    import logging
    
    import numpy as np
    import torch
    from torch import sparse_coo_tensor as coo
    from scipy.sparse import csr_matrix
    from scipy.ndimage import gaussian_filter
    from scipy.signal import find_peaks
    from scipy.cluster.vq import kmeans
    import faiss
    from tqdm import tqdm 
    
    from kilosort import hierarchical, swarmsplitter
    from kilosort.utils import log_performance
    
    logger = logging.getLogger(__name__)
    
    
    def neigh_mat(Xd, nskip=1, n_neigh=10, max_sub=25000):
        """Build a spike-to-subset adjacency matrix via exact neighbor search.

        Xd is spikes-by-features in a local neighborhood. Every `nskip`-th
        spike forms a candidate subset (thinned further if it would exceed
        `max_sub` entries); for each spike we find its `n_neigh` nearest
        subset members with an exact (brute-force) L2 search, then encode
        those links as a sparse adjacency matrix with self-connections
        zeroed out.

        Returns
        -------
        kn : np.ndarray
            (n_spikes, n_neigh) integer indices into the subset.
        M : scipy.sparse.csr_matrix
            (n_spikes, n_subset) adjacency matrix with ones at neighbor
            positions and zeros on self-connections.
        """
        n_spikes, n_features = Xd.shape

        # Candidate subset: every nskip-th spike, optionally thinned further
        # so it never has more than max_sub members.
        candidates = Xd[::nskip]
        n_cand = candidates.shape[0]
        if max_sub is not None and n_cand > max_sub:
            keep, rev_idx = subsample_idx(n_cand, n_cand - max_sub)
            candidates = candidates[keep]
        else:
            rev_idx = None
        n_nodes = candidates.shape[0]

        # faiss searches much faster on contiguous arrays.
        Xd = np.ascontiguousarray(Xd)
        candidates = np.ascontiguousarray(candidates)

        # Exact ("brute force") L2 nearest-neighbor search; kn holds, for each
        # spike, the indices of its n_neigh nearest subset members.
        index = faiss.IndexFlatL2(n_features)
        index.add(candidates)
        _, kn = index.search(Xd, n_neigh)

        # Sparse adjacency matrix: row i has ones at the subset indices that
        # are neighbors of spike i.
        row_ids = np.repeat(np.arange(n_spikes), n_neigh)
        vals = np.ones(n_spikes * n_neigh, np.float32)
        M = csr_matrix((vals, (row_ids, kn.flatten())), (n_spikes, n_nodes))

        # Zero out self-connections (subset spikes matched to themselves).
        self_rows = np.arange(0, n_spikes, nskip)
        if rev_idx is not None:
            self_rows = self_rows[rev_idx]
        M[self_rows, np.arange(n_nodes)] = 0

        return kn, M
    
    
    def assign_iclust(rows_neigh, isub, kn, tones2, nclust, lam, m, ki, kj, device=torch.device('cuda')):
        """Reassign each spike to the cluster most represented among its neighbors.

        Builds a dense spikes-by-clusters count matrix from the neighbor graph,
        optionally subtracts a degree-based penalty when lam > 0, and returns
        the argmax cluster for each spike.
        """
        n_spikes = kn.shape[0]

        # Count, for every spike, how many of its neighbors fall in each cluster.
        indices = torch.vstack((rows_neigh.flatten(), isub[kn].flatten()))
        counts = coo(indices, tones2.flatten(), (n_spikes, nclust)).to_dense()

        if lam > 0:
            # Subtract lam/m * ki * (total degree per cluster), which penalizes
            # large clusters in proportion to each spike's own degree.
            ones_k = torch.ones(len(kj), device=device)
            zeros_k = torch.zeros(len(kj), device=device)
            cluster_deg = coo(torch.vstack((zeros_k, isub)), ones_k, (1, nclust))
            counts = counts - lam / m * (ki.unsqueeze(-1) * cluster_deg.to_dense())

        return torch.argmax(counts, 1)
    
    
    def assign_isub(iclust, kn, tones2, nclust, nsub, lam, m,ki,kj, device=torch.device('cuda')):
        """Reassign each subset spike to the cluster most represented among
        the spikes that list it as a neighbor (transpose of assign_iclust)."""
        n_neigh = kn.shape[1]
        # Each spike votes its current cluster onto every neighbor it has.
        cluster_cols = iclust.unsqueeze(-1).tile((1, n_neigh))
        indices = torch.vstack((kn.flatten(), cluster_cols.flatten()))

        counts = coo(indices, tones2.flatten(), (nsub, nclust)).to_dense()

        if lam > 0:
            # Degree-based penalty, mirroring the lam term in assign_iclust.
            ones_k = torch.ones(len(ki), device=device)
            zeros_k = torch.zeros(len(ki), device=device)
            cluster_deg = coo(torch.vstack((zeros_k, iclust)), ones_k, (1, nclust))
            counts = counts - lam / m * (kj.unsqueeze(-1) * cluster_deg.to_dense())

        return torch.argmax(counts, 1)
    
    
    def Mstats(M, device=torch.device('cuda')):
        """Compute degree statistics of the sparse adjacency matrix M.

        Returns the total edge weight m, plus row degrees (ki) and column
        degrees (kj), each rescaled so their sum equals m, as torch tensors
        on `device`.
        """
        m = M.sum()
        row_deg = np.array(M.sum(1)).flatten()
        col_deg = np.array(M.sum(0)).flatten()
        # Normalize each degree vector so its total equals m.
        row_deg = m * row_deg / row_deg.sum()
        col_deg = m * col_deg / col_deg.sum()

        ki = torch.from_numpy(row_deg).to(device)
        kj = torch.from_numpy(col_deg).to(device)

        return m, ki, kj
    
    
    def cluster(Xd, iclust=None, kn=None, nskip=1, n_neigh=10, max_sub=25000,
                nclust=200, seed=1, niter=200, lam=0, device=torch.device('cuda'),
                verbose=False):    
        """Cluster spikes by alternating label passing over the neighbor graph.

        Xd is spikes by PCA features. If `iclust` is None the initial
        assignment comes from `kmeans_plusplus` (run for `nclust` iterations,
        so at most `nclust` clusters). For `niter` rounds, labels are
        propagated from spikes to the neighbor-finding subset (`assign_isub`)
        and back (`assign_iclust`), optionally with a degree penalty weighted
        by `lam`.

        Returns
        -------
        iclust : np.ndarray
            Final cluster id per spike, relabeled to a contiguous range.
        isub : np.ndarray
            Final cluster id per subset spike.
        M : scipy.sparse matrix
            Spikes-by-subset adjacency matrix from `neigh_mat`.
        iclust_init : torch.Tensor
            The initial cluster assignment, for reference.
        """

        if kn is None:
            # kn: n_spikes by n_neigh with integer indices into the spike subset
            #     used for neighbor-finding determined by nskip.
            # M:  n_spikes by nsub, adjacency matrix representation of kn.
            kn, M = neigh_mat(Xd, nskip=nskip, n_neigh=n_neigh, max_sub=max_sub)
        # NOTE(review): if `kn` is passed in, `M` is never assigned and the
        # next line raises NameError — callers appear to always pass kn=None.
        # TODO confirm and either drop the parameter or accept M as well.
        m, ki, kj = Mstats(M, device=device)

        if verbose:
            logger.debug(f'ki: {ki.nbytes / (2**20):.2f} MB, shape: {ki.shape}')
            logger.debug(f'kj: {kj.nbytes / (2**20):.2f} MB, shape: {kj.shape}')
            log_performance(logger, header='clustering_qr.cluster, after Mstats')

        # Move features and neighbor indices to the compute device.
        Xg = Xd.to(device)
        kn = torch.from_numpy(kn).to(device)
        n_spikes, n_neigh = kn.shape
        nsub = M.shape[1]  # number of spikes in neighbor-finding subset

        # rows_neigh, tones2 are just used to build properly formatted indices for
        # use by assign_isub and assign_iclust
        rows_neigh = torch.arange(n_spikes, device=device).unsqueeze(-1).tile((1,n_neigh))
        tones2 = torch.ones((n_spikes, n_neigh), device=device)

        if verbose:
            logger.debug(f'Xg: {Xg.nbytes / (2**20):.2f} MB, shape: {Xg.shape}')
            logger.debug(f'kn: {kn.nbytes / (2**20):.2f} MB, shape: {kn.shape}')
            logger.debug(f'rows_neigh: {rows_neigh.nbytes / (2**20):.2f} MB')
            logger.debug(f'tones2: {tones2.nbytes / (2**20):.2f} MB')
            log_performance(logger, header='clustering_qr.cluster, after var init')

        if iclust is None:
            # Initialize cluster assignments with greedy k-means++-style seeding.
            iclust_init =  kmeans_plusplus(Xg, niter=nclust, seed=seed, 
                                           device=device, verbose=verbose)
            iclust = iclust_init.clone()
        else:
            iclust_init = iclust.clone()
            
        for t in range(niter):
            # given iclust, reassign isub
            isub = assign_isub(iclust, kn, tones2, nclust, nsub, lam, m,
                               ki, kj,device=device)
            # given mu and isub, reassign iclust
            iclust = assign_iclust(rows_neigh, isub, kn, tones2, nclust, lam, m,
                                   ki, kj, device=device)
            
        if verbose:
            logger.debug(f'isub: {isub.nbytes / (2**20):.2f} MB, shape: {isub.shape}')
            log_performance(logger, header='clustering_qr.cluster, after isub loop')
        
        # Relabel surviving clusters to a contiguous 0..nclust-1 range,
        # then recompute the subset assignment one final time.
        _, iclust = torch.unique(iclust, return_inverse=True)    
        nclust = iclust.max() + 1
        isub = assign_isub(iclust, kn, tones2, nclust , nsub, lam, m,ki,kj, device=device)

        iclust = iclust.cpu().numpy()
        isub = isub.cpu().numpy()

        return iclust, isub, M, iclust_init
    
    
    def kmeans_plusplus(Xg, niter=200, seed=1, device=torch.device('cuda'), verbose=False):
        """Greedy k-means++-style initialization of spike cluster assignments.

        Selects up to `niter` centroids, one per iteration, from the spikes
        themselves, preferring spikes whose variance is poorly explained by
        the centroids chosen so far. Each spike ends up assigned to the
        centroid that explains it best.

        Parameters
        ----------
        Xg : torch.Tensor
            Spikes by features, assumed to already be on `device`.
        niter : int
            Number of centroids to select (one per iteration).
        seed : int
            Seed for both torch and numpy RNGs (candidate sampling).
        device : torch.device
        verbose : bool
            Emit extra performance logging.

        Returns
        -------
        iclust : torch.Tensor of int
            Cluster assignment in [0, niter) for each spike.
        """
        # Xg is number of spikes by number of features.
        # We are finding cluster centroids and assigning each spike to a centroid.
        vtot = torch.norm(Xg, 2, dim=1)**2

        n1 = vtot.shape[0]
        if n1 > 2**24:
            # This subsampling step is just for the candidate spikes to be considered
            # as new centroids. Sometimes need to subsample v2 since
            # torch.multinomial doesn't allow more than 2**24 elements. We're just
            # using this to sample some spikes, so it's fine to not use all of them.
            n2 = n1 - 2**24   # number of spikes to remove before sampling
            idx, rev_idx = subsample_idx(n1, n2)
            rev_idx = torch.from_numpy(rev_idx).to(device)
            subsample = True
        else:
            subsample = False

        torch.manual_seed(seed)
        np.random.seed(seed)

        ntry = 100  # number of candidate cluster centroids to test on each iteration
        n_spikes, n_features = Xg.shape
        # Need to store the spike features used for each cluster centroid (mu),
        # best variance explained so far for each spike (vexp0),
        # and the cluster assignment for each spike (iclust).
        mu = torch.zeros((niter, n_features), device = device)
        vexp0 = torch.zeros(n_spikes, device = device)
        iclust = torch.zeros((n_spikes,), dtype = torch.int, device = device)

        if verbose:
            log_performance(logger, header='clustering_qr.kpp, after var init')

        # On every iteration we choose one new centroid to keep.
        # We track how well n centroids so far explain each spike.
        # We ask, if we were to add another centroid, which spikes would that
        # increase the explained variance for and by how much?
        # We use ntry candidates on each iteration.
        for j in range(niter):
            # v2 is the un-explained variance so far for each spike
            v2 = torch.relu(vtot - vexp0)

            # We sample ntry new candidate centroids based on how much un-explained variance they have
            # more unexplained variance makes it more likely to be selected
            # Only one of these candidates will be added this iteration. 
            if subsample:
                isamp = rev_idx[torch.multinomial(v2[idx], ntry)]
            else:
                isamp = torch.multinomial(v2, ntry)

            try:
                # The new centroids to be tested, sampled from the spikes in Xg.
                Xc = Xg[isamp]
                # Variance explained for each spike for the new centroids.
                vexp = 2 * Xg @ Xc.T - (Xc**2).sum(1)
                # Difference between variance explained for new centroids
                # and best explained variance so far across all iterations.
                # This gets relu-ed, since only the positive increases will actually
                # re-assign a spike to this new cluster
                dexp = torch.relu(vexp - vexp0.unsqueeze(1))
                # Sum all positive increases to determine additional explained variance
                # for each candidate centroid.
                vsum = dexp.sum(0)
                # Pick the candidate which increases explained variance the most 
                imax = torch.argmax(vsum)

                # For that centroid (Xc[imax]), determine which spikes actually get
                # more variance from it
                ix = dexp[:, imax] > 0

                iclust[ix] = j    # assign new cluster identity
                mu[j] = Xc[imax]  # spike features used as centroid for cluster j
                # Update variance explained for the spikes assigned to cluster j
                vexp0[ix] = vexp[ix, imax]

                # Delete large variables between iterations
                # to prevent excessive memory reservation.
                del(vexp)
                del(dexp)

            except torch.cuda.OutOfMemoryError:
                # Log the problem size before re-raising so OOMs are diagnosable.
                logger.debug(f"OOM in kmeans_plus_plus iter {j}, nsp: {Xg.shape[0]}, "
                             f"Xg size: {Xg.nbytes / (2**20):.2f} MB.")
                raise

        if verbose:
            log_performance(logger, header='clustering_qr.kpp, after loop')

        # NOTE: For very large datasets, we may end up needing to subsample Xg.
        # If the clustering above is done on a subset of Xg,
        # then we need to assign all Xgs here to get an iclust 
        # for ii in range((len(Xg)-1)//nblock +1):
        #     vexp = 2 * Xg[ii*nblock:(ii+1)*nblock] @ mu.T - (mu**2).sum(1)
        #     iclust[ii*nblock:(ii+1)*nblock] = torch.argmax(vexp, dim=-1)

        return iclust
    
    
    def subsample_idx(n1, n2):
        """Build a boolean mask that drops n2 evenly spaced indices out of n1.

        Parameters
        ----------
        n1 : int
            Size of the full index, assumed to be the sequence 0, 1, ... n1-1
            with no missing values.
        n2 : int
            Number of indices to remove. Removed positions are spread evenly
            across the full range, including both endpoints.

        Returns
        -------
        idx : np.ndarray
            Boolean mask of length n1, True for indices kept in the subset.
        rev_idx : np.ndarray
            For each kept element, its position in the original index, so
            subset position i corresponds to original index rev_idx[i].

        Examples
        --------
        >>> subsample_idx(6, 3)
        (array([False,  True, False,  True,  True, False]), array([1, 3, 4]))

        """
        drop = np.round(np.linspace(0, n1 - 1, n2)).astype(int)
        keep = np.ones(n1, dtype=bool)
        keep[drop] = False
        # Map positions in the subset back to indices into the full tensor.
        rev_idx = keep.nonzero()[0]

        return keep, rev_idx
    
    
    def xy_templates(ops):
        """Get x/y positions for templates, as a 2-by-n torch tensor, along
        with ops['iCC'] restricted to the channels in ops['iU']."""
        best_chans = ops['iU'].cpu().numpy()
        iC = ops['iCC'][:, ops['iU']]
        xy = torch.from_numpy(
            np.vstack((ops['xc'][best_chans], ops['yc'][best_chans]))
            )

        return xy, iC
    
    
    def xy_up(ops):
        """Get x/y positions of the upsampled grid (as a 2-by-n torch tensor)
        and the channel matrix ops['iC']."""
        positions = np.vstack((ops['xcup'], ops['ycup']))

        return torch.from_numpy(positions), ops['iC']
    
    
    def x_centers(ops):
        """Estimate x-positions of grouping-center columns via k-means.

        If ops['x_centers'] is set it is passed straight to k-means (either a
        number of centers or explicit initial guesses). Otherwise initial
        guesses come from peaks in a smoothed histogram of contact
        x-positions.
        NOTE: the automated guess does not work well for 2D array probes;
        specify `x_centers` manually in that case.
        """
        user_centers = ops.get('x_centers', None)
        if user_centers is not None:
            # Either a number of centers or explicit initial guesses.
            approx_centers = user_centers
        else:
            # Originally bin_width was set equal to `dminx`, but the behaviors
            # were deliberately decoupled. A 50 micron bin works well for NP1,
            # NP2, tetrodes, and 2D arrays; promote it to a parameter later if
            # it becomes a problem.
            bin_width = 50
            min_x = ops['xc'].min()
            max_x = ops['xc'].max()

            # Histogram of x-positions, padded by two bins on each side of the
            # probe so that peaks can be detected at the edges.
            num_bins = int((max_x - min_x) / bin_width) + 4
            bins = np.linspace(min_x - bin_width*2, max_x + bin_width*2, num_bins)
            hist, edges = np.histogram(ops['xc'], bins=bins)
            # Smooth before peak-finding to avoid spurious local maxima.
            peaks, _ = find_peaks(gaussian_filter(hist, sigma=0.5))
            # Peaks are bin indices; convert back to positions in microns.
            approx_centers = [edges[p] for p in peaks]

            # Use the guesses to seed k-means for more accurate centers.
            # With one or zero peaks, just look for a single centroid.
            if len(approx_centers) <= 1: approx_centers = 1

        centers, distortion = kmeans(ops['xc'], approx_centers, seed=5330)

        # TODO: Maybe use distortion to raise a warning if it seems too large?
        # distortion is the mean (non-squared) distance to the assigned
        # centroid; e.g. warn when it exceeds dminx*2, which most probes
        # should satisfy.

        return centers
    
    
    def y_centers(ops):
        """Get evenly spaced y-positions for grouping centers, spanning the
        upsampled probe positions with spacing 2*dmin."""
        ycup = ops['ycup']
        dmin = ops['dmin']
        # TODO: May want to add the -dmin/2 in the future to center these, but
        #       this changes the results for testing so we need to wait until
        #       we can check it with simulations.
        return np.arange(ycup.min() + dmin - 1, ycup.max() + dmin + 1, 2 * dmin)
    
    
    def get_nearest_centers(xy, xcent, ycent):
        """For each template position in xy (2 x n), find the flattened index
        of the closest (x, y) grouping center.

        Returns the per-template center index along with the flattened x and
        y coordinates of every grouping center.
        """
        # Positions of all grouping centers on the x/y grid.
        ycent_pos, xcent_pos = np.meshgrid(ycent, xcent)
        ycent_pos = torch.from_numpy(ycent_pos.flatten())
        xcent_pos = torch.from_numpy(xcent_pos.flatten())
        # Squared distance from each center (rows) to each template (cols).
        dx = xy[0, :] - xcent_pos.unsqueeze(-1)
        dy = xy[1, :] - ycent_pos.unsqueeze(-1)
        center_distance = dx**2 + dy**2
        # Tiny random jitter breaks ties between equidistant centers.
        center_distance += 1e-20*torch.rand(center_distance.shape)
        # Closest center (flattened grid index) for each template.
        minimum_distance = torch.min(center_distance, 0).indices

        return minimum_distance, xcent_pos, ycent_pos
    
    
    def run(ops, st, tF, mode='template', device=torch.device('cuda'),
            progress_bar=None, clear_cache=False, verbose=False):
        """Cluster all spikes, one spatial grouping center at a time.

        Templates (or upsampled positions, depending on `mode`) are assigned
        to their nearest (x, y) grouping center; the spikes for each center
        are clustered with `cluster`, refined with `hierarchical.maketree`
        and `swarmsplitter`, and the per-center results are concatenated.

        Parameters
        ----------
        ops : dict
            Must contain probe geometry ('dmin', 'dminx', 'xcup', 'ycup',
            'Nchan', 'fs') and 'settings' with 'cluster_downsampling',
            'cluster_neighbors', 'max_cluster_subset', and 'n_pcs'.
        st : np.ndarray
            Spike info; column 1 ('template' mode) or column 5 (otherwise)
            holds the template id per spike, column 0 the sample time
            (used in 'template' mode only).
        tF : torch.Tensor
            Spike features, indexed by spike (see `get_data_cpu`).
        mode : str
            'template' to use template positions, anything else to use the
            upsampled grid positions.
        device : torch.device
        progress_bar : optional
            Object with an .emit(int) method for progress updates.
        clear_cache : bool
            Run gc and empty the CUDA cache after each center's clustering.
        verbose : bool
            Extra performance logging.

        Returns
        -------
        clu : np.ndarray
            int32 cluster id for every spike.
        Wall : torch.Tensor
            (n_clusters, Nchan, n_pcs) mean feature vector per cluster.

        Raises
        ------
        ValueError
            If no center yields usable channels, or Wall comes out empty.
        """

        if mode == 'template':
            xy, iC = xy_templates(ops)
            iclust_template = st[:,1].astype('int32')
            xcup, ycup = ops['xcup'], ops['ycup']
        else:
            xy, iC = xy_up(ops)
            iclust_template = st[:,5].astype('int32')
            xcup, ycup = ops['xcup'], ops['ycup']

        dmin = ops['dmin']
        dminx = ops['dminx']
        nskip = ops['settings']['cluster_downsampling']
        n_neigh = ops['settings']['cluster_neighbors']
        max_sub = ops['settings']['max_cluster_subset']
        ycent = y_centers(ops)
        xcent = x_centers(ops)
        nsp = st.shape[0]
        # Assign each template to the nearest grouping center on the
        # xcent-by-ycent grid.
        nearest_center, _, _ = get_nearest_centers(xy, xcent, ycent)
        total_centers = np.unique(nearest_center).size
        
        clu = np.zeros(nsp, 'int32')
        Wall = torch.zeros((0, ops['Nchan'], ops['settings']['n_pcs']))
        Nfilt = None
        nearby_chans_empty = 0
        nmax = 0  # running total of cluster ids assigned so far
        prog = tqdm(np.arange(len(xcent)), miniters=20 if progress_bar else None,
                    mininterval=10 if progress_bar else None)
        t = 0  # count of centers actually processed (for periodic logging)
        v = False
        
        try:
            for jj in prog:
                for kk in np.arange(len(ycent)):
                    # Get data for all templates that were closest to this x,y center.
                    ii = kk + jj*ycent.size
                    if ii not in nearest_center:
                        # No templates are nearest to this center, skip it.
                        continue
                    else:
                        t += 1
                    ix = (nearest_center == ii)
                    ntemp = ix.sum()

                    v = False
                    if t % 10 == 0:
                        log_performance(
                            logger,
                            header=f'Cluster center: {ii} ({t}/{total_centers})'
                            )
                        if verbose:
                            v = True

                    Xd, igood, ichan = get_data_cpu(
                        ops, xy, iC, iclust_template, tF, ycent[kk], xcent[jj],
                        dmin=dmin, dminx=dminx, ix=ix,
                        )
                    if Xd is None:
                        # No spikes matched this center's channels.
                        nearby_chans_empty += 1
                        continue

                    logger.debug(f'Center {ii} | Xd shape: {Xd.shape} | ntemp: {ntemp}')
                    if verbose and Xd.nelement() > 10**8:
                        logger.info(f'Resetting cuda memory stats for Center {ii}')
                        if device == torch.device('cuda'):
                            torch.cuda.reset_peak_memory_stats(device)
                        v = True
                    if Xd.shape[0] < 1000:
                        # Too few spikes to cluster; lump them into one cluster.
                        iclust = torch.zeros((Xd.shape[0],))
                    else:
                        if mode == 'template':
                            # Spike times in seconds, used as metadata by the splitter.
                            st0 = st[igood,0]/ops['fs']
                        else:
                            st0 = None

                        # find new clusters
                        iclust, iclust0, M, _ = cluster(
                            Xd, nskip=nskip, n_neigh=n_neigh, max_sub=max_sub,
                            lam=1, seed=5, device=device, verbose=v
                            )

                        if clear_cache:
                            if v:
                                log_performance(logger, header='clustering_qr before gc')
                            gc.collect()
                            torch.cuda.empty_cache()
                            if v:
                                log_performance(logger, header='clustering_qr after gc')

                        xtree, tstat, my_clus = hierarchical.maketree(M, iclust, iclust0)

                        xtree, tstat = swarmsplitter.split(
                            Xd.numpy(), xtree, tstat,iclust, my_clus, meta=st0
                            )

                        iclust = swarmsplitter.new_clusters(iclust, my_clus, xtree, tstat)

                    if v:
                        log_performance(logger, header='clustering_qr.run, after iclust')

                    # Offset this center's cluster ids past all previous centers'.
                    clu[igood] = iclust + nmax
                    Nfilt = int(iclust.max() + 1)
                    nmax += Nfilt

                    # we need the new templates here: mean features per cluster,
                    # scattered back onto the channels this center covered.
                    W = torch.zeros((Nfilt, ops['Nchan'], ops['settings']['n_pcs']))
                    for j in range(Nfilt):
                        w = Xd[iclust==j].mean(0)
                        W[j, ichan, :] = torch.reshape(w, (-1, ops['settings']['n_pcs'])).cpu()
                    
                    Wall = torch.cat((Wall, W), 0)

                    if progress_bar is not None:
                        # NOTE(review): progress is computed from the inner
                        # (y) loop only, so it resets for every x-center —
                        # TODO confirm this is intended.
                        progress_bar.emit(int((kk+1) / len(ycent) * 100))
        except:
            # Bare except so any failure (including BaseException) is logged
            # with its context before being re-raised.
            logger.exception(f'Error in clustering_qr.run on center {ii}')
            logger.debug(f'Xd shape: {Xd.shape}')
            logger.debug(f'Nfilt: {Nfilt}')
            logger.debug(f'num spikes: {nsp}')
            try:
                logger.debug(f'iclust shape: {iclust.shape}')
            except UnboundLocalError:
                logger.debug('iclust not yet assigned')
                pass
            raise

        if nearby_chans_empty == total_centers:
            raise ValueError(
                f'`get_data_cpu` never found suitable channels in `clustering_qr.run`.'
                f'\ndmin, dminx, and xcenter are: {dmin, dminx, xcup.mean()}'
            )

        if Wall.sum() == 0:
            # Wall is empty, unspecified reason
            raise ValueError(
                'Wall is empty after `clustering_qr.run`, cannot continue clustering.'
            )

        return clu, Wall
    
    
    def get_data_cpu(ops, xy, iC, PID, tF, ycenter, xcenter, dmin=20, dminx=32,
                     ix=None, merge_dim=True):
        """Gather features for all spikes whose template lies near (xcenter, ycenter).

        Parameters
        ----------
        ops : dict
            Unused in the live code (kept for interface compatibility; see
            the commented-out lines below for its historical use).
        xy : torch.Tensor
            2-by-n_templates x/y positions (row 0 = x, row 1 = y).
        iC : tensor-like
            Channel indices per template; column j lists the channels for
            template j. Assumed shape (n_chan_neigh, n_templates) — TODO
            confirm against caller.
        PID : np.ndarray
            Template id for each spike.
        tF : torch.Tensor
            Spike features, (n_spikes, n_chan_neigh, n_features).
        ycenter, xcenter : float
            Center of the spatial selection window.
        dmin, dminx : float
            Window half-sizes in y and x, respectively.
        ix : torch.Tensor of bool, optional
            Precomputed template mask; computed from the window if None.
        merge_dim : bool
            If True, flatten the channel and feature dims of the output.

        Returns
        -------
        Xd : torch.Tensor or None
            Features of the selected spikes, scattered onto the union of
            their templates' channels; None if no spikes matched.
        igood : torch.Tensor or None
            Indices of the selected spikes.
        ichan : torch.Tensor or None
            Sorted unique channel ids covered by the selected templates.
        """
        PID =  torch.from_numpy(PID).long()

        #iU = ops['iU'].cpu().numpy()
        #iC = ops['iCC'][:, ops['iU']]    
        #xcup, ycup = ops['xc'][iU], ops['yc'][iU]
        #xy = np.vstack((xcup, ycup))
        #xy = torch.from_numpy(xy)
        
        y0 = ycenter # xy[1].mean() - ycenter
        x0 = xcenter #xy[0].mean() - xcenter

        # Templates within the dmin-by-dminx window around (x0, y0).
        if ix is None:
            ix = torch.logical_and(
                torch.abs(xy[1] - y0) < dmin,
                torch.abs(xy[0] - x0) < dminx
                )
        # Spikes whose template is in the window.
        igood = ix[PID].nonzero()[:,0]

        if len(igood) == 0:
            return None, None, None

        pid = PID[igood]
        data = tF[igood]
        nspikes, nchanraw, nfeatures = data.shape
        # Union of channels across selected templates; imap re-indexes each
        # template's channels into that union.
        ichan, imap = torch.unique(iC[:, ix], return_inverse=True)
        nchan = ichan.nelement()

        # Scatter each template's spike features onto its remapped channel
        # positions; channels a template doesn't cover stay zero.
        dd = torch.zeros((nspikes, nchan, nfeatures))
        for k,j in enumerate(ix.nonzero()[:,0]):
            ij = torch.nonzero(pid==j)[:, 0]
            dd[ij.unsqueeze(-1), imap[:,k]] = data[ij]

        if merge_dim:
            Xd = torch.reshape(dd, (nspikes, -1))
        else:
            # Keep channels and features separate
            Xd = dd

        return Xd, igood, ichan
    
    
    def assign_clust(rows_neigh, iclust, kn, tones2, nclust):    
        """Reassign each spike to the cluster most represented among its
        neighbors (no degree penalty term)."""
        n_spikes = len(iclust)

        # Count neighbor cluster memberships for every spike.
        indices = torch.vstack((rows_neigh.flatten(), iclust[kn].flatten()))
        counts = coo(indices, tones2.flatten(), (n_spikes, nclust)).to_dense()

        return torch.argmax(counts, 1)
    
    def assign_iclust0(Xg, mu):
        """Assign each spike in Xg to the centroid in mu that maximizes
        2*x.mu - |mu|^2 (i.e. explains the most of its variance)."""
        projections = Xg @ mu.T
        norms = (mu**2).sum(1)
        return torch.argmax(2*projections - norms, 1)
    

    back to top

    Software Heritage — Copyright (C) 2015–2025, The Software Heritage developers. License: GNU AGPLv3+.
    The source code of Software Heritage itself is available on our development forge.
    The source code files archived by Software Heritage are available under their own copyright and licenses.
    Terms of use: Archive access, API— Content policy— Contact— JavaScript license information— Web API