    # kilosort/io.py
    import json
    from pathlib import Path
    from typing import Tuple, Union
    import os, shutil
    import warnings
    from concurrent.futures import ThreadPoolExecutor
    import time
    import logging
    logger = logging.getLogger(__name__)
    
    from scipy.io import loadmat
    import numpy as np
    import torch
    from torch.fft import fft, ifft, fftshift
    
    from kilosort import CCG
    from kilosort.preprocessing import get_drift_matrix, fft_highpass
    from kilosort.postprocessing import (
        remove_duplicates, compute_spike_positions, make_pc_features
        )
    from kilosort.utils import log_performance
    
    _torch_warning = ".*PyTorch does not support non-writable tensors"
    
    
    def find_binary(data_dir: Union[str, os.PathLike]) -> Path:
        """Find binary file in `data_dir`."""
    
        data_dir = Path(data_dir)
        filenames = list(data_dir.glob('*.bin')) + list(data_dir.glob('*.bat')) \
                    + list(data_dir.glob('*.dat')) + list(data_dir.glob('*.raw'))
        if len(filenames) == 0:
            raise FileNotFoundError(
                f'No binary file found in {data_dir}. Expected extensions are:\n'
                '*.bin, *.bat, *.dat, or *.raw.'
                )
    
        # If there are multiple binary files, find one with "ap" tag
        if len(filenames) > 1:
            filenames = [f for f in filenames if 'ap.bin' in f.as_posix()]
    
        # If there is still more than one, raise an error, user needs to specify
        # full path.
        if len(filenames) > 1:
            raise ValueError('Multiple binary files in folder with "ap" tag, '
                             'please specify filename')
    
        return filenames[0]
    
    
    def load_probe(probe_path):
        """Load a .mat probe file from Kilosort2, or a PRB file and returns a dictionary
        
        adapted from https://github.com/MouseLand/pykilosort/blob/5712cfd2722a20554fa5077dd8699f68508d1b1a/pykilosort/utils.py#L592
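
        Examples
        --------
        A minimal usage sketch; the file path is hypothetical. Every returned
        dictionary contains at least 'chanMap', 'xc', 'yc', and 'n_chan'.

        >>> probe = load_probe('my_probe.json')                    # doctest: +SKIP
        >>> probe['n_chan'], probe['chanMap'].dtype                # doctest: +SKIP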
    
        """
        probe = {}
        probe_path = Path(probe_path).resolve()
        required_keys = ['chanMap', 'yc', 'xc', 'n_chan']
    
        if probe_path.suffix == '.prb':
            # Support for PRB files.
            # !DOES NOT WORK FOR PHASE3A PROBES WITH DISCONNECTED CHANNELS!
            # Also does not remove reference channel in PHASE3B probes
            contents = probe_path.read_text()
            metadata = {'np': np}
            exec(contents, {}, metadata)
            # TODO: Figure out an alternative for exec, it's a security risk.
    
            probe['chanMap'] = []
            probe['xc'] = []
            probe['yc'] = []
            probe['kcoords'] = []
            probe['n_chan'] = 0 
            for cg in sorted(metadata['channel_groups']):
                d = metadata['channel_groups'][cg]
                ch = d['channels']
                pos = d.get('geometry', {})
                probe['chanMap'].append(ch)
                probe['n_chan'] += len(ch)
                probe['xc'].append([pos[c][0] for c in ch])
                probe['yc'].append([pos[c][1] for c in ch])
                probe['kcoords'].append([cg for c in ch])
            probe['chanMap'] = np.concatenate(probe['chanMap']).ravel().astype(np.int32)
            probe['xc'] = np.concatenate(probe['xc']).astype('float32')
            probe['yc'] = np.concatenate(probe['yc']).astype('float32')
            probe['kcoords'] = np.concatenate(probe['kcoords']).astype('float32')
    
        elif probe_path.suffix == '.mat':
            mat = loadmat(probe_path)
            connected = mat['connected'].ravel().astype('bool')
            probe['xc'] = mat['xcoords'].ravel().astype(np.float32)[connected]
            nc = len(probe['xc'])
            probe['yc'] = mat['ycoords'].ravel().astype(np.float32)[connected]
            kc = mat.get('kcoords', None)
            if kc is None:
                kc = np.zeros(nc).ravel().astype(np.float32)
            else:
                kc = kc.ravel().astype(np.float32)[connected]
            probe['kcoords'] = kc
            probe['chanMap'] = (mat['chanMap'] - 1).ravel().astype(np.int32)[connected]  # NOTE: 0-indexing in Python
            probe['n_chan'] = (mat['chanMap'] - 1).ravel().astype(np.int32).shape[0]  # NOTE: should match the # of columns in the raw data
    
        elif probe_path.suffix == '.json':
            with open(probe_path, 'r') as f:
                probe = json.load(f)
            for k in list(probe.keys()):
                # Convert lists back to arrays
                v = probe[k]
                if isinstance(v, list):
                    dtype = np.int32 if k == 'chanMap' else np.float32
                    probe[k] = np.array(v, dtype=dtype)
    
        else:
            raise ValueError(
                f"Unrecognized extension for probe: {probe_path.name}\n"
                "Recognized extensions are .mat, .prb, or .json."
                )
    
        for n in required_keys:
            if n not in probe.keys():
                raise ValueError(f'Missing required key: {n} after loading probe.')
    
        # Verify that all arrays have the same size.
        size = None
        for k, v in probe.items():
            if isinstance(v, np.ndarray):
                if size is None:
                    size = v.size
                elif size != v.size:
                    raise ValueError(
                        f"All probe variables must have the same length."
                    )
    
        return probe
    
      
    def save_probe(probe_dict, filepath):
        """Save a probe dictionary to a .json text file.
    
        Parameters
        ----------
        probe_dict : dict
            A dictionary containing probe information in the format expected by
            Kilosort4, with keys 'chanMap', 'xc', 'yc', and 'kcoords'.
        filepath : str or pathlib.Path
            Location where .json file should be stored.
    
        Raises
        ------
        RuntimeWarning
            If filepath does not end in '.json'
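
        Examples
        --------
        A minimal sketch; the probe dictionary and output path are hypothetical
        but follow the format described above.

        >>> probe = {'chanMap': np.arange(4), 'xc': np.zeros(4), 'yc': 20.*np.arange(4),
        ...          'kcoords': np.zeros(4), 'n_chan': 4}
        >>> save_probe(probe, 'my_probe.json')                     # doctest: +SKIP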
        
        """
    
        if Path(filepath).suffix != '.json':
            raise RuntimeWarning(
                'Saving json probe to a file whose suffix is not .json. '
                'kilosort.io.load_probe will not recognize this file.' 
            )
    
        d = probe_dict.copy()
        # Convert arrays to lists, since arrays aren't json-able.
        for k in list(d.keys()):
            v = d[k]
            if isinstance(v, np.ndarray):
                d[k] = v.tolist()
    
        # Verify that all lists are the same length.
        length = None
        for k, v in d.items():
            if isinstance(v, list):
                if length is None:
                    length = len(v)
                elif length != len(v):
                    raise ValueError(
                        f"All probe variables must have the same length."
                    )
    
        # Create parent directories if they do not exist, then save probe.
        Path(filepath).parent.mkdir(parents=True, exist_ok=True)
        with open(filepath, 'w') as f:
            f.write(json.dumps(d))
    
    
    def remove_bad_channels(probe, bad_channels):
        """Create a probe dictionary with listed channels (data rows) removed.
        
        Parameters
        ----------
        probe : dict.
            A Kilosort4 probe dictionary, as returned by `kilosort.io.load_probe`.
        bad_channels : list.
            A list of channel indices (rows in the binary file) that should not be
            included in sorting. Listing channels here is equivalent to excluding
            them from the probe dictionary.
    
        Returns
        -------
        probe : dict.
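
        Examples
        --------
        A minimal sketch; the channel numbers are hypothetical and refer to
        entries of `probe['chanMap']`, i.e. rows of the binary file.

        >>> probe = load_probe('my_probe.json')                    # doctest: +SKIP
        >>> probe = remove_bad_channels(probe, [7, 191])           # doctest: +SKIP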
        
        """
        probe = probe.copy()
    
        bad_idx = np.empty_like(bad_channels, dtype=int)
        for i, k in enumerate(bad_channels):
            try:
                idx = np.where(probe['chanMap'] == k)[0][0]
            except IndexError:
                raise IndexError(f"Channel '{k}' was not in probe['chanMap']")
            bad_idx[i] = idx
        
        for k in ['xc', 'yc', 'chanMap', 'kcoords']:
            probe[k] = np.delete(probe[k], bad_idx)
        probe['n_chan'] = probe['n_chan'] - bad_idx.size
    
        return probe
    
    
    def select_shank(probe, shank_idx):
        """Create a probe dictionary containing channels from shank `shank_idx` only.
        
        Parameters
        ----------
        probe : dict.
            A Kilosort4 probe dictionary, as returned by `kilosort.io.load_probe`.
        shank_idx : int or float.
            The shank index to keep, matched against `probe['kcoords']`. Channels
            from all other shanks are removed.
    
        Returns
        -------
        probe : dict.
    
        """
        probe = probe.copy()
        keep = (probe['kcoords'] == shank_idx)
        n = keep.sum()
        if n == 0:
            raise ValueError(f"Selected empty shank index: {shank_idx}, can't sort.")
        else:
            probe['n_chan'] = n
            for k in ['xc', 'yc', 'chanMap', 'kcoords']:
                probe[k] = probe[k][keep]
    
        return probe
    
    
    def save_to_phy(st, clu, tF, Wall, probe, ops, imin, results_dir=None,
                    data_dtype=None, save_extra_vars=False,
                    save_preprocessed_copy=False, skip_dat_path=False):
        """Save sorting results to disk in a format readable by Phy.
    
        Parameters
        ----------
        st : np.ndarray
            3-column array of peak time (in samples), template, and threshold
            amplitude for each spike.
        clu : np.ndarray
            1D vector of cluster ids indicating which spike came from which cluster,
            same shape as `st[:,0]`.
        tF : torch.Tensor
            PC features for each spike, with shape
            (n_spikes, nearest_chans, n_pcs)
        Wall : torch.Tensor
            PC feature representation of spike waveforms for each cluster, with shape
            (n_clusters, n_channels, n_pcs).
        probe : dict.
            A Kilosort4 probe dictionary, as returned by `kilosort.io.load_probe`.
        ops : dict
            Dictionary storing settings and results for all algorithmic steps.
        imin : int
            Minimum sample index used by BinaryRWFile, exported spike times will
            be shifted forward by this number.
        results_dir : pathlib.Path; optional.
            Directory where results should be saved.
        data_dtype : str or type; optional.
            dtype of data in binary file, like `'int32'` or `np.uint16`. By default,
            dtype is assumed to be `'int16'`.
        save_extra_vars : bool; default=False.
            If True, save tF and Wall to disk along with copies of st, clu and
            amplitudes with no postprocessing applied.
        save_preprocessed_copy : bool; default=False.
            If True, save a pre-processed copy of the data (including drift
            correction) to `temp_wh.dat` in the results directory and format Phy
            output to use that copy of the data.
        skip_dat_path : bool; default=False.
            If True, will save `dat_path = 'no_path.bin'` in `params.py` in place
            of a real filename. This is done to prevent an error in Phy when filename
            has an unexpected format, like when using a `file_object` loaded from
            an external data format through SpikeInterface. The full filename(s) will
            still be included in `params.py` for reference, but will be commented out.
    
        
        Returns
        -------
        results_dir : pathlib.Path.
            Directory where results are saved.
        similar_templates : np.ndarray.
            Similarity score between each pair of clusters, computed as correlation
            between clusters. Shape (n_clusters, n_clusters).
        is_ref : np.ndarray.
            1D boolean array with shape (n_clusters,) indicating whether each
            cluster is refractory.
        est_contam_rate : np.ndarray.
            Contamination rate for each cluster, computed as fraction of refractory
            period violations relative to expectation based on a Poisson process.
            Shape (n_clusters,).
        kept_spikes : np.ndarray.
            Boolean mask with shape (n_spikes,) that is False for spikes that were
            removed by `kilosort.postprocessing.remove_duplicate_spikes`
            and True otherwise.
    
        Notes
        -----
        For documentation of the saved files, see
        https://kilosort.readthedocs.io/en/latest/export_files.html
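
        Examples
        --------
        A minimal sketch of loading some of the saved arrays afterwards. The
        results path is hypothetical; `kept_spikes.npy` is the duplicate-spike
        mask described above, and `full_amp.npy` is only written when
        `save_extra_vars=True`.

        >>> results_dir = Path('/data/kilosort4')                     # doctest: +SKIP
        >>> spike_times = np.load(results_dir / 'spike_times.npy')    # doctest: +SKIP
        >>> kept = np.load(results_dir / 'kept_spikes.npy')           # doctest: +SKIP
        >>> full_amp = np.load(results_dir / 'full_amp.npy')          # doctest: +SKIP
        >>> amplitudes = full_amp[kept]                               # doctest: +SKIP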
    
        """
    
        if results_dir is None:
            results_dir = ops['data_dir'].joinpath('kilosort4')
        results_dir = Path(results_dir)
        results_dir.mkdir(exist_ok=True)
    
        # probe properties
        chan_map = probe['chanMap']
        channel_positions = np.stack((probe['xc'], probe['yc']), axis=-1)
        np.save((results_dir / 'channel_map.npy'), chan_map)
        np.save((results_dir / 'channel_positions.npy'), channel_positions)
        np.save((results_dir / 'channel_shanks.npy'), probe['kcoords'])
    
        # whitening matrix
        whitening_mat = ops['Wrot']
        np.save((results_dir / 'whitening_mat_dat.npy'), whitening_mat.cpu())
        # NOTE: commented out for reference, this was different in KS 2.5 because
        #       the binary file was already whitened.
        # whitening_mat = 0.005 * np.eye(len(chan_map), dtype='float32')
        whitening_mat_inv = torch.inverse(
            whitening_mat
            + 1e-5 * torch.eye(whitening_mat.shape[0]).to(whitening_mat.device)
            )
        np.save((results_dir / 'whitening_mat.npy'), whitening_mat.cpu())
        np.save((results_dir / 'whitening_mat_inv.npy'), whitening_mat_inv.cpu())
        log_performance(logger, level='debug', header='save_to_phy, whitening',
                        reset=True)
    
        # spike properties
        spike_times = st[:,0].astype('int64') + imin  # shift by minimum sample index
        spike_templates = st[:,1].astype('int32')
        spike_clusters = clu
        xs, ys = compute_spike_positions(st, tF, ops)
        spike_positions = np.vstack([xs, ys]).T
        amplitudes = torch.norm(tF, dim=[-2,-1]).cpu().numpy()
        log_performance(logger, level='debug', header='save_to_phy, spike positions',
                        reset=True)
    
        # remove duplicate (artifact) spikes
        spike_times, spike_clusters, kept_spikes = remove_duplicates(
            spike_times, spike_clusters, dt=ops['duplicate_spike_bins']
        )
        amp = amplitudes[kept_spikes]
        spike_templates = spike_templates[kept_spikes]
        spike_positions = spike_positions[kept_spikes]
        np.save((results_dir / 'spike_times.npy'), spike_times)
        np.save((results_dir / 'spike_templates.npy'), spike_clusters)
        np.save((results_dir / 'spike_clusters.npy'), spike_clusters)
        np.save((results_dir / 'spike_positions.npy'), spike_positions)
        np.save((results_dir / 'spike_detection_templates.npy'), spike_templates)
        np.save((results_dir / 'amplitudes.npy'), amp)
        # Save spike mask so that it can be applied to other variables if needed
        # when loading results.
        np.save((results_dir / 'kept_spikes.npy'), kept_spikes)
        log_performance(logger, level='debug', header='save_to_phy, remove duplicates',
                        reset=True)
    
        # template properties
        similar_templates = CCG.similarity(Wall, ops['wPCA'].contiguous(), nt=ops['nt'])
        template_amplitudes = torch.norm(Wall, dim=[-2,-1]).cpu().numpy()
        templates = (Wall.unsqueeze(-1).cpu() * ops['wPCA'].cpu()).sum(axis=-2).numpy()
        templates = templates.transpose(0,2,1)
        templates_ind = np.tile(np.arange(Wall.shape[1])[np.newaxis, :], (templates.shape[0],1))
        np.save((results_dir / 'similar_templates.npy'), similar_templates)
        np.save((results_dir / 'templates.npy'), templates)
        np.save((results_dir / 'templates_ind.npy'), templates_ind)
        log_performance(logger, level='debug', header='save_to_phy, template props',
                        reset=True)
        
        # pc features
        if save_extra_vars:
            # Save tF first since it gets updated in-place
            np.save(results_dir / 'tF.npy', tF.cpu().numpy())
        # This will momentarily copy tF which is pretty large, but it's on CPU
        # so the extra memory hopefully won't be an issue.
        tF = tF[kept_spikes]
        pc_features, pc_feature_ind = make_pc_features(
            ops, spike_templates, spike_clusters, tF
            )
        np.save(results_dir / 'pc_features.npy', pc_features)
        np.save(results_dir / 'pc_feature_ind.npy', pc_feature_ind)
        log_performance(logger, level='debug', header='save_to_phy, pc features',
                        reset=True)
    
        # contamination ratio
        acg_threshold = ops['settings']['acg_threshold']
        ccg_threshold = ops['settings']['ccg_threshold']
        is_ref, est_contam_rate = CCG.refract(spike_clusters, spike_times / ops['fs'],
                                              acg_threshold=acg_threshold,
                                              ccg_threshold=ccg_threshold)
        log_performance(logger, level='debug', header='save_to_phy, est contam',
                        reset=True)
    
        # write properties to *.tsv
        stypes = ['ContamPct', 'Amplitude', 'KSLabel']
        ks_labels = [['mua', 'good'][int(r)] for r in is_ref]
        props = [est_contam_rate*100, template_amplitudes, ks_labels]
        for stype, prop in zip(stypes, props):
            with open((results_dir / f'cluster_{stype}.tsv'), 'w') as f:
                f.write(f'cluster_id\t{stype}\n')
                for i,p in enumerate(prop):
                    if stype != 'KSLabel':
                        f.write(f'{i}\t{p:.1f}\n')
                    else:
                        f.write(f'{i}\t{p}\n')
            if stype == 'KSLabel':
                shutil.copyfile((results_dir / f'cluster_{stype}.tsv'), 
                                (results_dir / f'cluster_group.tsv'))
    
        # params.py
        dtype = "'int16'" if data_dtype is None else f"'{data_dtype}'"
        params = {
            'n_channels_dat': ops['settings']['n_chan_bin'],
            'offset': 0,
            'sample_rate': ops['settings']['fs']
            }
        if save_preprocessed_copy:
            dat_path = results_dir / 'temp_wh.dat'
            params['dtype'] = "'int16'"
            params['hp_filtered'] = True
            params['dat_path'] = f"'{dat_path.resolve().as_posix()}'"
        else:
            f = ops['settings']['filename']
            if not isinstance(f, list): f = [f]
            dat_path = [Path(p).resolve().as_posix() for p in f]
            params['dtype'] = dtype
            params['hp_filtered'] = False
            params['dat_path'] = dat_path
    
        with open((results_dir / 'params.py'), 'w') as f: 
            for key in params.keys():
                if (key == 'dat_path') and skip_dat_path:
                    f.write(f'dat_path = "no_path.bin"\n')
                    f.write('# ')
                f.write(f'{key} = {params[key]}\n')
    
        if save_extra_vars:
            # Also save Wall, for easier debugging/analysis
            np.save(results_dir / 'Wall.npy', Wall.cpu().numpy())
            # And full st, clu, amp arrays with no spikes removed
            np.save(results_dir / 'full_st.npy', st)
            np.save(results_dir / 'full_clu.npy', clu)
            np.save(results_dir / 'full_amp.npy', amplitudes)
    
        # Remove cached .phy results if present from running Phy on a previous
        # version of results in the same directory.
        phy_cache_path = Path(results_dir / '.phy')
        if phy_cache_path.is_dir():
            shutil.rmtree(phy_cache_path)
    
        return results_dir, similar_templates, is_ref, est_contam_rate, kept_spikes
    
    
    def save_ops(ops, results_dir=None):
        """Save intermediate `ops` dictionary to `results_dir/ops.npy`."""
    
        if results_dir is None:
            results_dir = Path(ops['data_dir']) / 'kilosort4'
        else:
            results_dir = Path(results_dir)
        results_dir.mkdir(exist_ok=True)
    
        ops = ops.copy()
        # Convert paths to strings before saving, otherwise ops can only be loaded
        # on the system that originally ran the code (causes problems for tests).
        ops['settings']['results_dir'] = str(results_dir)
        # TODO: why do these get saved twice?
        ops['filename'] = str(ops['filename'])
        ops['data_dir'] = str(ops['data_dir'])
        ops['settings']['filename'] = str(ops['settings']['filename'])
        ops['settings']['data_dir'] = str(ops['settings']['data_dir'])
    
        # Convert pytorch tensors to numpy arrays before saving, otherwise loading
        # ops on a different system may not work (if saved from GPU, but loaded
        # on a system with only CPU).
        ops['is_tensor'] = []
        for k, v in ops.items():
            if isinstance(v, torch.Tensor):
                ops[k] = v.cpu().numpy()
                ops['is_tensor'].append(k)
        ops['preprocessing'] = {k: v.cpu().numpy()
                                for k, v in ops['preprocessing'].items()}
    
        np.save(results_dir / 'ops.npy', np.array(ops))
    
    
    def load_ops(ops_path, device=None):
        """Load a saved `ops` dictionary and convert some arrays to tensors."""
        if device is None:
            if torch.cuda.is_available():
                device = torch.device('cuda')
            else:
                device = torch.device('cpu')
    
        ops = np.load(ops_path, allow_pickle=True).item()
        for k, v in ops.items():
            if k in ops['is_tensor']:
                ops[k] = torch.from_numpy(v).to(device)
        # TODO: Why do we have one copy of this saved as numpy, one as tensor,
        #       at different levels?
        ops['preprocessing'] = {k: torch.from_numpy(v).to(device)
                                for k,v in ops['preprocessing'].items()}
    
        return ops
    
    
    def bfile_from_ops(ops=None, ops_path=None, filename=None, device=None):
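        """Construct a BinaryFiltered reader from a saved `ops` dictionary.

        One of `ops` (an in-memory dictionary) or `ops_path` (path to a saved
        `ops.npy`) must be provided. If `filename` is given it overrides
        `ops['filename']`, e.g. to apply the same settings to the same binary
        file at a new path. Preprocessing parameters (high-pass filter,
        whitening matrix, drift correction) are taken from `ops`.
        """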
        if device is None:
            if torch.cuda.is_available():
                device = torch.device('cuda')
            else:
                device = torch.device('cpu')
    
        if ops is None:
            if ops_path is not None:
                ops = load_ops(ops_path, device=device)
            else:
                raise ValueError('Must specify either `ops` or `ops_path`.')
        
        if filename is None:
            # Separate so that the same settings can be applied to the same binary
            # file at a new path, like when running pytests.
            filename = ops['filename']
        
        bfile = BinaryFiltered(
            filename, ops['n_chan_bin'], fs=ops['fs'], NT=ops['batch_size'],
            nt=ops['nt'], nt0min=ops['nt0min'], chan_map=ops['probe']['chanMap'],
            hp_filter=ops['fwav'], whiten_mat=ops['Wrot'], dshift=ops['dshift'],
            device=device, do_CAR=ops['do_CAR'], artifact_threshold=ops['artifact_threshold'],
            invert_sign=ops['invert_sign'], dtype=ops['data_dtype'], tmin=ops['tmin'],
            tmax=ops['tmax'], shift=ops['shift'], scale=ops['scale']
            )
    
        return bfile
    
    
    class BinaryRWFile:
    
        supported_dtypes = ['int16', 'uint16', 'int32', 'float32']
    
        def __init__(self, filename, n_chan_bin: int, fs: int = 30000, 
                     NT: int = 60000, nt: int = 61, nt0min: int = 20,
                     device: torch.device = None, write: bool = False,
                     dtype: str = None, tmin: float = 0.0, tmax: float = np.inf,
                     shift=None, scale=None, file_object=None):
            """
            Create or open a binary file for reading and/or writing; the object
            behaves like a numpy array.

            If `dtype` is not specified, the file is assumed to contain int16 data.
    
            adapted from https://github.com/MouseLand/suite2p/blob/main/suite2p/io/binary.py
            
            Parameters
            ----------
            filename : Path-like or list of Path-likes.
                The filename of the file(s) to read from or write to
            n_chan_bin : int
                number of channels
            file_object : array-like file object; optional.
                Must have 'shape' and 'dtype' attributes and support array-like
                indexing (e.g. [:100,:], [5, 7:10], etc). For example, a numpy
                array or memmap.
                Note: `filename` is effectively ignored if `file_object` is not None.
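
            Examples
            --------
            A minimal sketch; the file path and channel count are hypothetical.

            >>> bfile = BinaryRWFile('/data/rec_ap.bin', n_chan_bin=385)        # doctest: +SKIP
            >>> chunk = bfile[:30000, :]   # first second of data at fs=30000   # doctest: +SKIP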
    
            """
            # Extra casting statements to ensure that 64-bit values are used
            # for time samples on very long recordings.
            self.fs = np.float64(fs)
            self.n_chan_bin = n_chan_bin
            self.filename = filename
            self.NT = np.int64(NT) 
            self.nt = np.int64(nt) 
            self.nt0min = np.int64(nt0min)
            self.shift = shift
            self.scale = scale
            tmin = np.float64(tmin)
            tmax = np.float64(tmax)
    
            if device is None:
                if torch.cuda.is_available():
                    device = torch.device('cuda')
                else:
                    device = torch.device('cpu')
            self.device = device
            self.uint_set_warning = True
            self.writable = write
            self.mode = 'w+' if write else 'r'
    
            if file_object is not None:
                dtype = file_object.dtype
            if dtype is None:
                dtype = 'int16'
            self.dtype = dtype
    
            if isinstance(filename, list) and len(filename) > 1:
                if file_object is None:
                    file_object = BinaryFileGroup(
                        filenames=filename, n_channels=n_chan_bin, dtype=dtype
                        )
                else:
                    logger.info(
                        "`file_object is not None`, "
                        "it will be used instead of loading from filename."
                        )
            self.file_object = file_object
    
            if str(self.dtype) not in self.supported_dtypes:
                message = f"""
                    {self.dtype} is not supported and may result in unexpected
                    behavior or errors. Supported types are:\n
                    {self.supported_dtypes}
                    """
                warnings.warn(message, RuntimeWarning)
    
            # Must come after dtype since dtype is necessary for nbytesread
            if file_object is None:
                self.total_samples = get_total_samples(filename, n_chan_bin, dtype)
            else:
                n, c = file_object.shape
                assert c == n_chan_bin
                self.total_samples = n
    
            self.imin = max(np.int64(tmin*self.fs), 0)
            self.imax = self.total_samples if tmax==np.inf \
                        else min(np.int64(tmax*self.fs), self.total_samples)
            self.n_batches = np.int64(np.ceil(self.n_samples / self.NT))
    
            # Check if last batch is too small. If so, drop those samples.
            a, b = self.get_batch_edges(self.n_batches-1)
            batch_size = b - a - self.nt
            if batch_size < self.nt:
                self.n_batches -= 1
                self.imax -= batch_size
    
    
        @property
        def n_samples(self) -> int:
            """total number of samples in the file."""
            return self.imax - self.imin
    
        @property
        def shape(self) -> Tuple[int, int]:
            """
            The dimensions of the data in the file
            Returns
            -------
            n_samples: int
                number of samples
            n_chan_bin: int
                number of channels
            """
            return self.n_samples, self.n_chan_bin
    
        @property
        def size(self) -> int:
            """
            Returns the total number of data points
    
            Returns
            -------
            size: int
            """
            return np.prod(np.array(self.shape).astype(np.int64))
        
        @property
        def file(self):
            """Get reference to in-memory file object or open a memmap to file."""
            if self.file_object is not None:
                file = self.file_object
            else:
                f = self.filename
                if isinstance(f, list):
                    f = f[0]
                file = np.memmap(f, mode=self.mode, dtype=self.dtype,
                                 shape=(self.total_samples, self.n_chan_bin))
            return file
            
        def __enter__(self):
            return self
    
        def __setitem__(self, *items):
            if not self.writable:
                raise ValueError('Binary file was loaded as read-only.')
    
            idx, data = items
            # Shift indices by minimum sample index
            sample_indices = self._get_shifted_indices(idx)
            # Shift data to pos-only
            if self.dtype == 'uint16':
                data = data + 2**15
                if self.uint_set_warning:
                    # Inform user of shift to hopefully avoid confusion, but only
                    # do this once per bfile.
                    logger.info(
                        "NOTE: When setting new values for uint16 data, 2**15 will "
                        "be added to the given values before writing to file."
                        )
                    self.uint_set_warning = False
            # Convert back from float to file dtype
            data = data.astype(self.dtype)
    
            self.file[sample_indices] = data
            if self.file_object is None:
                self.file.flush()
    
    
        def __getitem__(self, *items):
            idx, *crop = items
            # Shift indices by minimum sample index.
            sample_indices = self._get_shifted_indices(idx)
            samples = self.file[sample_indices]
            
            if self.dtype == 'uint16':
                # Shift data to +/- 2**15
                samples = samples.astype('float32')
                samples = samples - 2**15
    
            # Typically only need to be used with float32 data
            if self.scale is not None:
                samples = samples * self.scale
            if self.shift is not None:
                samples = samples + self.shift
    
            return samples
        
        
        def _get_shifted_indices(self, idx):
            if not isinstance(idx, tuple): idx = tuple([idx])
            new_idx = []
    
            i = idx[0]
            if isinstance(i, slice):
                # Time dimension
                start = self.imin if i.start is None else i.start + self.imin
                stop = self.imax if i.stop is None else min(i.stop + self.imin, self.imax)
                new_idx.append(slice(start, stop, i.step))
            else:
                new_idx.append(i)
    
            if len(idx) == 2:
                # Channel dimension, should be no others after this.
                # No adjustments needed.
                new_idx.append(idx[1])
    
            return tuple(new_idx)
    
        def get_batch_edges(self, ibatch):
            if ibatch==0:
                bstart = self.imin
                bend = self.imin + self.NT + self.nt
            else:
                # Casting to int64 is to prevent overflow for long recordings.
                # It's done multiple times because python is stubborn about
                # switching things back to default types.
                ibatch = np.int64(ibatch)
                bstart = np.int64(self.imin + (ibatch * self.NT) - self.nt)
                bend = min(self.imax, np.int64(bstart + self.NT + 2*self.nt))
    
            return bstart, bend
    
        def padded_batch_to_torch(self, ibatch, return_inds=False):
            """ read batches from file """
    
            bstart, bend = self.get_batch_edges(ibatch)
            data = self.file[bstart : bend]
            data = data.T
    
            if self.dtype == 'uint16':
                # Shift data to +/- 2**15
                data = data.astype('float32')
                data = data - 2**15
    
            # Typically only need to be used with float32 data
            if self.scale is not None:
                data = data * self.scale
            if self.shift is not None:
                data = data + self.shift
    
            nsamp = data.shape[-1]
            X = torch.zeros((self.n_chan_bin, self.NT + 2*self.nt), device=self.device)
    
            with warnings.catch_warnings():
                # Don't need this, we know about the warning and it doesn't cause
                # any problems. Doing this the "correct" way is much slower.
                warnings.filterwarnings("ignore", message=_torch_warning)
    
                # fix the data at the edges for the first and last batch
                if ibatch == 0:
                    X[:, self.nt : self.nt+nsamp] = torch.from_numpy(data).to(self.device).float()
                    X[:, :self.nt] = X[:, self.nt : self.nt+1]
                    bstart = self.imin - self.nt
                elif ibatch == self.n_batches-1:
                    X[:, :nsamp] = torch.from_numpy(data).to(self.device).float()
                    X[:, nsamp:] = X[:, nsamp-1:nsamp]
                    bend += self.nt
                else:
                    X[:] = torch.from_numpy(data).to(self.device).float()
    
            inds = [bstart, bend]
            if return_inds:
                return X, inds
            else:
                return X
            
    
    def get_total_samples(filename, n_channels, dtype=np.int16):
        """Count samples in binary file given dtype and number of channels."""
        if isinstance(filename, list):
            filename = filename[0]
        bytes_per_value = np.dtype(dtype).itemsize
        bytes_per_sample = np.int64(bytes_per_value * n_channels)
        total_bytes = os.path.getsize(filename)
        samples = np.float64(total_bytes / bytes_per_sample)
    
        if samples%1 != 0:
            raise ValueError(
                "Bytes in binary file did not divide evenly, "
                "incorrect n_chan_bin ('number of channels' in GUI)."
            )
        else:
            return np.int64(samples)
    
    
    class BinaryFileGroup:
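        """Array-like wrapper that presents multiple binary files as one recording.

        Files are treated as temporally concatenated in list order. Indexing with
        `group[i:j, channels]` reads from whichever underlying file(s) contain the
        requested samples. A minimal sketch with hypothetical paths:

            group = BinaryFileGroup(filenames=['rec1.bin', 'rec2.bin'],
                                    n_channels=385, dtype='int16')
            data = group[:1000, :]
        """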
        def __init__(self, file_objects=None, filenames=None,
                     n_channels=None, dtype=None):
            # NOTE: Assumes list order of files or objects matches temporal order
            #       for concatenation.
            self._file_objects = None
            self._filenames = None
            self.n_files = 0
    
            if file_objects is not None:
                self._file_objects = file_objects
                self.dtype = file_objects[0].dtype
                self.n_chans = file_objects[0].shape[1]
                self.n_files = len(file_objects)
                for f in file_objects[1:]:
                    assert f.dtype == self.dtype, 'All files must have the same dtype'
                    assert f.shape[1] == self.n_chans, \
                        'All files must have the same number of channels'
            elif filenames is not None:
                # Must specify n_channels, dtype
                if n_channels is None or dtype is None:
                    raise ValueError("BinaryFileGroup requires n_channels and dtype "
                                     "when using filenames option.")
                self._filenames = filenames
                self.dtype = dtype
                self.n_chans = n_channels
                self.n_files = len(filenames)
            else:
                raise ValueError("Must specify either file_objects or filenames")
    
            # Track indices that represent boundary between files. Each entry
            # is the starting index of the subsequent file.
            i = 0
            self.split_indices = []
            for j in range(self.n_files):
                i += self.get_file_size(j)
                self.split_indices.append(i)
    
        def get_file(self, i):
            if self._file_objects is not None:
                file = self._file_objects[i]
            else:
                name = self._filenames[i]
                n_samples = get_total_samples(name, self.n_chans, self.dtype)
                file = np.memmap(name, mode='r', dtype=self.dtype,
                                    shape=(n_samples, self.n_chans))
    
            return file
        
        def get_file_size(self, i):
            if self._file_objects is not None:
                size = self._file_objects[i].shape[0]
            else:
                size = get_total_samples(self._filenames[i], self.n_chans, self.dtype)
    
            return size
    
        def __getitem__(self, *items):
            # Index into appropriate individual object based on index.
            # For indices that span multiple files, index all then concatenate result.
            idx, *crop = items
            if not isinstance(idx, tuple): idx = tuple([idx])
            channel_idx = idx[1] if len(idx) > 1 else slice(None)
            time_idx = idx[0]
    
            # To simplify for loop logic, convert integer index to slice and
            # convert NoneTypes to boundaries of data, and convert negative
            # indices to positive indices.
            if not isinstance(time_idx, slice):
                time_idx = slice(time_idx, time_idx+1)
            i = time_idx.start
            j = time_idx.stop
    
            if i is None: i = 0
            if i < 0: i = self.shape[0] + i
            if j is None: j = self.shape[0]
            if j < 0: j = self.shape[0] + j
            time_idx = slice(i, j)
    
            data = []
            shift = 0
            for idx, k in enumerate(self.split_indices):
                f = self.get_file(idx)
                n_samples = f.shape[0]
                if time_idx.start < k:
                    # At least part of the data is in this file
                    ii = max(int(time_idx.start - shift), 0)
                    jj = max(int(time_idx.stop - shift), 0)
                    t = slice(ii, jj)
                    data.append(f[t, channel_idx])
                    if time_idx.stop <= k:
                        # This is the end of the data to be retrieved
                        break
                shift += n_samples
    
            if len(data) == 0:
                d = None
            elif len(data) == 1:
                d = data[0]
            else:
                d = np.concatenate(data, axis=0)
        
            return d
    
        @property
        def shape(self):
            return self.split_indices[-1], self.n_chans
        
        @staticmethod
        def from_filenames(filenames, n_channels, dtype=np.int16):
            # NOTE: This is just an alias now for the below command, retained
            #       for backwards compatibility.
            return BinaryFileGroup(filenames=filenames, n_channels=n_channels,
                                   dtype=dtype)
    
    
    class BinaryFiltered(BinaryRWFile):
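        """BinaryRWFile subclass that applies preprocessing when data is read.

        Batches returned by `__getitem__` and `padded_batch_to_torch` are passed
        through `filter`: channel selection via `chan_map`, optional sign
        inversion, per-channel mean subtraction, optional common average
        referencing, Fourier-domain high-pass filtering, artifact rejection
        (batches exceeding `artifact_threshold` are zeroed), and whitening with
        optional drift correction.
        """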
        def __init__(self, filename: str, n_chan_bin: int, fs: int = 30000, 
                     NT: int = 60000, nt: int = 61, nt0min: int = 20,
                     chan_map: np.ndarray = None, hp_filter: torch.Tensor = None,
                     whiten_mat: torch.Tensor = None, dshift: torch.Tensor = None,
                     device: torch.device = None, do_CAR: bool = True,
                     artifact_threshold: float = np.inf, invert_sign: bool = False,
                     dtype=None, tmin: float = 0.0, tmax: float = np.inf,
                     shift=None, scale=None, file_object=None):
    
            super().__init__(filename, n_chan_bin, fs, NT, nt, nt0min, device,
                             dtype=dtype, tmin=tmin, tmax=tmax, shift=shift,
                             scale=scale, file_object=file_object) 
            self.chan_map = chan_map
            self.whiten_mat = whiten_mat
            self.hp_filter = hp_filter
            self.dshift = dshift
            self.do_CAR = do_CAR
            self.invert_sign=invert_sign
            self.artifact_threshold = artifact_threshold
    
        def filter(self, X, ops=None, ibatch=None, skip_preproc=False):
            # pick only the channels specified in the chanMap
            if self.chan_map is not None:
                X = X[self.chan_map]
    
            if self.invert_sign:
                X = X * -1
    
            X = X - X.mean(1).unsqueeze(1)
            if self.do_CAR:
                # remove the mean of each channel, and the median across channels
                X = X - torch.median(X, 0)[0]
        
            if skip_preproc:
                return X
    
            # high-pass filtering in the Fourier domain (much faster than filtfilt etc)
            if self.hp_filter is not None:
                fwav = fft_highpass(self.hp_filter, NT=X.shape[1])
                X = torch.real(ifft(fft(X) * torch.conj(fwav)))
                X = fftshift(X, dim = -1)
    
            if self.artifact_threshold < np.inf:
                if torch.any(torch.abs(X) >= self.artifact_threshold):
                    # Assume the batch contains a recording artifact.
                    # Skip subsequent preprocessing, zero-out the batch.
                    return torch.zeros_like(X)
    
            # whitening, with optional drift correction
            if self.whiten_mat is not None:
                if self.dshift is not None and ops is not None and ibatch is not None:
                    M = get_drift_matrix(ops, self.dshift[ibatch], device=self.device)
                    #logger.info(M.dtype, X.dtype, self.whiten_mat.dtype)
                    X = (M @ self.whiten_mat) @ X
                else:
                    X = self.whiten_mat @ X
            return X
    
        def __getitem__(self, *items):
            samples = super().__getitem__(*items)
            with warnings.catch_warnings():
                # Don't need this, we know about the warning and it doesn't cause
                # any problems. Doing this the "correct" way is much slower.
                warnings.filterwarnings("ignore", message=_torch_warning)
                X = torch.from_numpy(samples.T).to(self.device).float()
            return self.filter(X)
            
        def padded_batch_to_torch(self, ibatch, ops=None, return_inds=False,
                                  skip_preproc=False):
            if return_inds:
                X, inds = super().padded_batch_to_torch(ibatch, return_inds=return_inds)
                return self.filter(X, ops, ibatch, skip_preproc=skip_preproc), inds
            else:
                X = super().padded_batch_to_torch(ibatch)
                return self.filter(X, ops, ibatch, skip_preproc=skip_preproc)
    
    
    def save_preprocessing(filename, ops, bfile=None, bfile_path=None, verbose=False):
        """Save a preprocessed copy of data, including drift correction.
    
        Parameters
        ----------
        filename : str or Path-like.
            Path where new file should be saved.
        ops : dict.
            Settings and state variables used in sorting. See `kilosort.run_kilosort`.
        bfile : BinaryFiltered; optional.
            Binary file loaded as a BinaryFiltered instance, including all variables
            needed for preprocessing (whitening, filtering, and drift correction).
            If specified, `bfile_path` will not be used.
            One of `bfile` or `bfile_path` must be provided.
        bfile_path : str or Path-like.
            Path where raw binary data should be loaded from. If `bfile` is given,
            this parameter will not be used.
            One of `bfile` or `bfile_path` must be provided.
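
        Examples
        --------
        A minimal sketch; the paths are hypothetical and `ops` would normally
        come from `load_ops` or from running the sorter.

        >>> ops = load_ops('/data/kilosort4/ops.npy')                  # doctest: +SKIP
        >>> save_preprocessing('/data/temp_wh.dat', ops,
        ...                    bfile_path='/data/rec_ap.bin')          # doctest: +SKIP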
    
        """
    
        n_batches = ops['Nbatches']
        nt = ops['nt']
        NT = ops['batch_size']
        whiten_mat = ops['preprocessing']['whiten_mat']
        hp_filter = ops['preprocessing']['hp_filter']
        dshift = ops['dshift']
        chan_map = ops['chanMap']
        dtype = ops['data_dtype']
        n_chans = ops['n_chan_bin']
    
        if bfile is None:
            if bfile_path is None:
                raise ValueError("Must specify either `bfile` or `bfile_path`.")
            bfile = BinaryFiltered(
                filename=bfile_path, n_chan_bin=n_chans, chan_map=chan_map, nt=nt,
                NT=NT, hp_filter=hp_filter, whiten_mat=whiten_mat, dshift=dshift,
                dtype=dtype
                )
    
        # Need weights to linearly smooth the overlapping portions of batches
        # after drift correction. I.e. first replaced sample is mostly weighted
        # for first batch, middle sample is 50/50, last sample is mostly weighted
        # for second batch.
        n_chans_used = len(chan_map)
        weights = np.linspace(1, 0, 2*nt+2)[1:-1]
        W = np.vstack([weights, np.flip(weights)])
        W = np.tile(W, n_chans_used).reshape(2, n_chans_used, 2*nt)
    
        logger.info(' ')
        logger.info('='*40)
        logger.info(f'Saving drift-corrected copy of data to: {filename}...')
        for i in range(n_batches):
            if i % 100 == 0:
                logger.info(f'Writing batch {i}/{n_batches}...')
            if i > 0 and i % 500 == 0:
                log_performance(logger, level='debug', header=None)
    
            if i == 0:
                # Initialize with first batch
                batch1 = bfile.padded_batch_to_torch(i, ops=ops)
            else:
                # Re-use batch2 from previous iteration
                batch1 = batch2
    
            # NOTE: dtype for new file is always int16, float32 data returned by
            #       preproc steps is scaled by 200 and then converted.
            mode = 'w+' if i == 0 else 'r+'
            z = np.memmap(filename, dtype='int16', mode=mode,
                          shape=(NT*n_batches, n_chans))
    
            if i == n_batches-1:
                # Skip first 2*nt of real data, it was added in previous iter.
                # Nothing to interpolate on last batch.
                y = batch1[:, 2*nt:-nt].cpu().numpy().T
                z[(i*NT)+nt:, chan_map] = (y*200).astype('int16')
            else:
                batch2 = bfile.padded_batch_to_torch(i+1, ops=ops)
    
                # Get interpolated values to replace inter-batch padding, there are
                # 2*nt samples overlapping at the batch edges.
                x1 = batch1[:, (NT-1) + 1:].cpu().numpy()
                x2 = batch2[:, :2*nt].cpu().numpy()
                X = np.vstack([x1[np.newaxis,...], x2[np.newaxis,...]])
                y2 = (X*W).sum(axis=0).T
                
                # Write raw data, leaving out padding and first nt values
                y1 = batch1[:, nt*2:-nt].cpu().numpy().T
                z[(i*NT)+(nt) : ((i+1)*NT), chan_map] = (y1*200).astype('int16')
                if i == 0:
                    # Also need to write first nt values of first batch
                    y0 = batch1[:, nt:2*nt].cpu().numpy().T
                    z[:nt, chan_map] = (y0*200).astype('int16')
    
                # Write interpolated data afterward, to replace the last nt values of
                # first batch and first nt values of the next batch in loop.
                z[((i+1)*NT)-nt : ((i+1)*NT)+nt, chan_map] = (y2*200).astype('int16')
    
            z.flush()
            del(z)
    
        logger.info('='*40)
        logger.info('Copying finished.')
        logger.info(' ')
    
    
    def spikeinterface_to_binary(recording, filepath, data_name='data.bin',
                                 dtype=np.int16, chunksize=300000, export_probe=True,
                                 probe_name='probe.prb', max_workers=None):
        """Save data from a SpikeInterface RecordingExtractor to a binary file.
    
        This function is provided to assist with converting data from other file
        formats to the raw binary format supported by Kilosort4. We do not explicitly
        support any other file format, nor is SpikeInterface a required package for
        using Kilosort4, so future updates to SpikeInterface may not be reflected here.
        If you run into errors and would like help loading your data into Kilosort4,
        please post an issue on the Kilosort4 github.
    
        Parameters
        ----------
        recording : RecordingExtractor.
            A SpikeInterface object containing the recording to be copied.
        filepath : str or Path-like.
            Path to the directory where the binary file (and possibly prb file)
            should be saved.
        data_name : str; default='data.bin'
            Name for the new binary file.
        dtype : type; default=np.int16
            Data type for the new binary file.
        chunksize : int; default=300000.
            Number of samples to copy on each loop. A higher number may speed up
            copying, but will use more memory.
        export_probe : bool; default=True.
            If True, attempt to extract probe information from the recording and
            write it to a .prb file.
        probe_name : str; default='probe.prb'
            Name for the new probe file.
        max_workers : int; optional.
            Maximum number of threads used to execute file i/o.
            Default: min(32, (os.process_cpu_count() or 1) + 4)
            (https://github.com/python/cpython/blob/main/Lib/concurrent/futures/thread.py)
            
        Notes
        -----
        SpikeInterface has its own `recording.save()` method for this purpose
        that supports parallelization. This simpler utility is provided for
        better control over filepath structure, minimal output, and fewer
        dependencies on file format details. However, for very large files, you
        may want to investigate `recording.save` to speed up data copying.
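
        Examples
        --------
        A minimal sketch with hypothetical paths; any SpikeInterface
        RecordingExtractor should work the same way.

        >>> from spikeinterface.extractors import read_nwb_recording     # doctest: +SKIP
        >>> rec = read_nwb_recording('/data/my_recording.nwb')           # doctest: +SKIP
        >>> binary_file, N, c, s, fs, probe_file = spikeinterface_to_binary(
        ...     rec, '/data/converted')                                  # doctest: +SKIP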
        
        """
    
        filepath = Path(filepath)
        filepath.mkdir(exist_ok=True, parents=True)
        binary_filename = filepath / f'{data_name}'
        probe_filename = filepath / f'{probe_name}'
    
        logger.info('='*40)
        logger.info('Loading recording with SpikeInterface...')
    
        # Using actual data shape is less fragile than relying on .get_num_channels()
        N = recording.get_total_samples()
        logger.info(f'number of samples: {N}')
        c = recording.get_traces(start_frame=0, end_frame=1, segment_index=0).shape[1]
        logger.info(f'number of channels: {c}')
        s = recording.get_num_segments()
        logger.info(f'number of segments: {s}')
        fs = recording.get_sampling_frequency()
        logger.info(f'sampling rate: {fs}')
        dtype = recording.get_dtype()
        logger.info(f'dtype: {dtype}')
    
        # Determine start/end indices for each segment
        indices = []
        for k in range(s):
            n = recording.get_num_samples(segment_index=k)
            i = 0 + k*chunksize
            while i < n:
                j = i + chunksize if (i + chunksize) < n else n
                indices.append((i, j, k))
                i += chunksize
    
        # Copy each chunk of data to memory-mapped binary file,
        # use multithreading to speed it up.
        def copy_chunk(memmap, i, j, k):
            t = recording.get_traces(start_frame=i, end_frame=j, segment_index=k)
            memmap[i:j,:] = t
            memmap.flush()
            del(t)
    
        y = np.memmap(binary_filename, dtype=dtype, mode='w+', shape=(N,c))
        total_chunks = len(indices)
        logger.info('='*40)
        logger.info(
            f'Converting {total_chunks} data chunks '
            f'with a chunksize of {chunksize} samples...'
            )
        with ThreadPoolExecutor(max_workers=max_workers) as exe:
            futures = [exe.submit(copy_chunk, y, i, j, k) for i, j, k in indices]
            # TODO: every some-amount-of-time, check list of futures
            #       for completion
            while exe._work_queue.qsize() > 0:
                time.sleep(10)
                logger.info(
                    f'{total_chunks - exe._work_queue.qsize()} of {total_chunks}'
                    ' chunks converted...'
                    )
        logger.info(f'Data conversion finished.')
        logger.info('='*40)
    
        del(y)  # Close memmap after copying
    
        if export_probe:
            try:
                from probeinterface import write_prb
                try:
                    pg = recording.get_probegroup()
                    write_prb(probe_filename, pg)
                except ValueError:
                    logger.info(
                        'SpikeInterface recording contains no probe information,\n'
                        'could not write .prb file.'
                    )
                    probe_filename = None
            except ModuleNotFoundError:
                raise ModuleNotFoundError(
                    'Could not import `write_prb` from probeinterface when exporting '
                    'probe, please run `pip install spikeinterface`.'
                )
        else:
            probe_filename = None
    
        return binary_filename, N, c, s, fs, probe_filename
    
    
    class RecordingExtractorAsArray:
    
        def __init__(self, recording_extractor):
            """An array-like wrapper for a RecordingExtractor.
    
            This class is provided to assist with loading data from other file 
            formats besides the raw binary format supported by Kilosort4. We do not
            explicitly support any other file format, nor is SpikeInterface a
            required package for using Kilosort4, so future updates to SpikeInterface
            may not be reflected here. If you run into errors and would like help
            loading your data into Kilosort4, please post an issue on the Kilosort4
            github.
    
            Parameters
            ----------
            recording_extractor : RecordingExtractor
                A SpikeInterface recording extractor. When the wrapper object is
                indexed, `recording_extractor.get_traces()` will be invoked to
                retrieve data from disk.
    
            Attributes
            ----------
            shape
            dtype
    
            Examples
            --------
            >>> from spikeinterface.extractors import read_nwb_recording
            >>> recording = read_nwb_recording('/home/my_file_path/arange.nwb')
            >>> as_array = RecordingExtractorAsArray(recording)
            >>> as_array[:5, 2]
            array([[0],
                   [1],
                   [2],
                   [3],
                   [4]])
    
            """
    
            if recording_extractor.get_num_segments() > 1:
                try:
                    import spikeinterface as si
                    recording_extractor = si.concatenate_recordings([recording_extractor])
                    logger.info(
                        'SpikeInterface recording contains more than one segment, '
                        'segments will be concatenated as if contiguous.'
                        )
                except ModuleNotFoundError:
                    raise ModuleNotFoundError(
                        'SpikeInterface could not be imported, but is needed for '
                        'loading multi-segment SpikeInterface recordings. Please '
                        'run `pip install spikeinterface`.'
                    )
    
            logger.info('='*40)
            logger.info('Loading recording with SpikeInterface...')
            self.recording = recording_extractor
            self.N = self.recording.get_total_samples()
            logger.info(f'number of samples: {self.N}')
            self.c = self.recording.get_traces(start_frame=0, end_frame=1, segment_index=0).shape[1]
            logger.info(f'number of channels: {self.c}')
            self.s = self.recording.get_num_segments()
            logger.info(f'number of segments: {self.s}')
            self.fs = self.recording.get_sampling_frequency()
            logger.info(f'sampling rate: {self.fs}')
            self.dtype = self.recording.get_dtype()
            logger.info(f'dtype: {self.dtype}')
            self.shape = (self.N, self.c)
            logger.info('='*40)
        
    
        def __getitem__(self, *items):
            idx, *crop = items
            if not isinstance(idx, tuple): idx = tuple([idx])
            sample_idx = idx[0]
            channel_ids = None if len(idx) == 1 else idx[1]
    
            # Convert integer index to slice and convert NoneTypes to boundaries of
            # data, and convert negative indices to positive indices.
            if not isinstance(sample_idx, slice):
                sample_idx = slice(sample_idx, sample_idx+1)
            i = sample_idx.start
            j = sample_idx.stop
    
            if i is None: i = 0
            if i < 0: i = self.N + i
            if j is None: j = self.N
            if j < 0: j = self.N + j
    
            # Convert channel slice to list of indices starting from 0
            if isinstance(channel_ids, slice):
                c = channel_ids.start
                d = channel_ids.stop
                if c is None: c = 0
                if d is None: d = self.shape[1]
                channel_ids = list(range(c, d))
            elif channel_ids is not None:
                assert isinstance(channel_ids, int)
                channel_ids = [channel_ids]
            else:
                channel_ids = np.arange(self.shape[1])
            # Index into actual channel ids from recording, which do not have to 
            # be sequential or start from 0
            channel_ids = self.recording.channel_ids[channel_ids]
    
            samples = self.recording.get_traces(start_frame=i, end_frame=j,
                                                channel_ids=channel_ids)
            
            return samples
    
        def __setitem__(self, *items):
            raise ValueError('RecordingExtractorAsArray is read-only.')
    
