Skip to main content
  • Home
  • Development
  • Documentation
  • Donate
  • Operational login
  • Browse the archive

swh logo
SoftwareHeritage
Software
Heritage
Archive
Features
  • Search

  • Downloads

  • Save code now

  • Add forge now

  • Help

  • 6fe46d5
  • /
  • tests
  • /
  • test_io.py
Raw File Download

To reference or cite the objects present in the Software Heritage archive, permalinks based on SoftWare Hash IDentifiers (SWHIDs) must be used.
Select below a type of object currently browsed in order to display its associated SWHID and permalink.

  • content
  • directory
content badge
swh:1:cnt:492bff1e6ff890324de07ee041651d8b8315dcc0
directory badge
swh:1:dir:8f6b71cce4bd4ec2c0cf114e5047b7cb070569b0

This interface enables to generate software citations, provided that the root directory of browsed objects contains a citation.cff or codemeta.json file.
Select below a type of object currently browsed in order to generate citations for them.

  • content
  • directory
Generate software citation in BibTex format (requires biblatex-software package)
Generating citation ...
Generate software citation in BibTex format (requires biblatex-software package)
Generating citation ...
test_io.py
import pytest
import tempfile
from pathlib import Path

import numpy as np
import torch

from kilosort import io


def test_probe_io():
    # Create one-column probe with 5 contacts, spaced 1um apart.
    json_probe = {
        'chanMap': np.arange(5),
        'xc': np.ones(5),
        'yc': np.arange(5),
        'kcoords': np.zeros(5),
        'n_chan': 5
    }
    # Repeat in .prb format
    prb_probe = """
channel_groups = {
    0: {
            'channels' : [0,1,2,3,4],
            'geometry': {
                0: [1, 0],
                1: [1, 1],
                2: [1, 2],
                3: [1, 3],
                4: [1, 4]
            }
    }
}
"""
    
    # Save both to file
    with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
        json_file = Path(f.name)
        print(json_file)
        io.save_probe(json_probe, json_file)

    with tempfile.NamedTemporaryFile(mode='w', suffix='.prb', delete=False) as f:
        f.write(prb_probe)
        prb_file = Path(f.name)

    # Load both with kilosort.io
    probe1 = io.load_probe(json_file)
    probe2 = io.load_probe(prb_file)

    print('probe1:')
    print(probe1)
    print('probe2:')
    print(probe2)

    try:
        # Verify that loaded probes contain the same information
        for k in ['chanMap', 'xc', 'yc', 'kcoords', 'n_chan']:
            print(f'testing key {k}')
            assert (k in probe1) and (k in probe2)
            assert np.all(probe1[k] == probe2[k])
    finally:
        # Remove temporary files
        json_file.unlink()
        prb_file.unlink()


def test_bad_channels():
    probe = {
        'xc': np.zeros(5), 'yc': np.arange(5)*10, 'kcoords': np.zeros(5),
        'chanMap': np.arange(5), 'n_chan': 5
        }
    bad_channels = [2, 4]
    probe2 = io.remove_bad_channels(probe, bad_channels)
    assert probe2['n_chan'] == 3
    for k in ['xc', 'yc', 'kcoords']:
        assert probe2[k].size == 3
    assert np.all(probe2['chanMap'] == [0, 1, 3])
    assert np.all(probe['chanMap'] == [0, 1, 2, 3, 4])  # original unchanged

    with pytest.raises(IndexError):
        # These are not in the channel map.
        _ = io.remove_bad_channels(probe, [5, 6])


def test_bat_extension(torch_device, data_directory):
    # Create memmap, write to file, close the file again.
    path = data_directory / 'binary_test' / 'temp_memmap.bat'
    path.parent.mkdir(parents=True, exist_ok=True)
    N, C = (1000, 10)
    r = np.random.rand(N,C)*2 - 1    # scale to (-1,1)
    r = (r*(2**14)).astype(np.int16)     # scale up

    try:
        a = np.memmap(path, mode='w+', shape=(N,C), dtype=np.int16)
        a[:] = r[:]
        a.flush()
        del(a)

        directory = path.parent
        filename = io.find_binary(directory)
        assert filename == path
        bfile = io.BinaryFiltered(filename, C, device=torch_device)
        x = bfile[0:100]  # Test data retrieval

    finally:
        # Delete memmap file and re-raise exception
        path.unlink()


def test_dat_extension(torch_device, data_directory):
    # Create memmap, write to file, close the file again.
    path = data_directory / 'binary_test' / 'temp_memmap.dat'
    path.parent.mkdir(parents=True, exist_ok=True)
    N, C = (1000, 10)
    r = np.random.rand(N,C)*2 - 1    # scale to (-1,1)
    r = (r*(2**14)).astype(np.int16)     # scale up

    try:
        a = np.memmap(path, mode='w+', shape=(N,C), dtype=np.int16)
        a[:] = r[:]
        a.flush()
        del(a)

        directory = path.parent
        filename = io.find_binary(directory)
        assert filename == path
        bfile = io.BinaryFiltered(filename, C, device=torch_device)
        x = bfile[0:100]  # Test data retrieval

    finally:
        # Delete memmap file and re-raise exception
        path.unlink()


def test_tmin_tmax(torch_device, data_directory):
    N, C = (1000, 10)
    NT = 300
    nt = 61
    data = np.repeat(np.arange(N)[...,np.newaxis], repeats=C, axis=1)
    fs = 10
    path = data_directory / 'time_interval_test' / 'temp_memmap.dat'
    path.parent.mkdir(parents=True, exist_ok=True)

    try:
        a = np.memmap(path, mode='w+', shape=(N,C), dtype=np.int16)
        a[:] = data[:]
        a.flush()
        del(a)

        bfile = io.BinaryRWFile(path, n_chan_bin=C, fs=fs, device=torch_device,
                                tmin=10, tmax=85, NT=NT, nt=nt)

        assert bfile.imin == 100
        assert bfile.imax == 850
        assert bfile.n_samples == 750
        assert bfile[0:50].min() == 100
        assert bfile[700:750].max() == 849
        assert bfile.n_batches == 3

        X0 = bfile.padded_batch_to_torch(ibatch=0)
        assert X0.min() == 100
        assert X0.max() == 100 + NT + nt - 1

        X1 = bfile.padded_batch_to_torch(ibatch=1)
        assert X1.min() == 100 + NT - nt
        assert X1.max() == 100 + 2*NT + nt - 1

        X2 = bfile.padded_batch_to_torch(ibatch=2)
        assert X2.min() == 100 + 2*NT - nt
        assert X2.max() == 849

    finally:
        # Delete memmap file and re-raise exception
        path.unlink()


def test_tmin_only(torch_device, data_directory):
    N, C = (1000, 10)
    NT = 300
    nt = 61
    data = np.repeat(np.arange(N)[...,np.newaxis], repeats=C, axis=1)
    fs = 10
    path = data_directory / 'time_interval_test' / 'temp_memmap2.dat'
    path.parent.mkdir(parents=True, exist_ok=True)

    try:
        a = np.memmap(path, mode='w+', shape=(N,C), dtype=np.int16)
        a[:] = data[:]
        a.flush()
        del(a)

        bfile = io.BinaryRWFile(path, n_chan_bin=C, fs=fs, device=torch_device,
                                tmin=43, NT=NT, nt=nt)

        assert bfile.imin == 430
        assert bfile.imax == 1000
        assert bfile.n_samples == 570
        assert bfile[0:10].min() == 430
        assert bfile[400:].max() == 999
        assert bfile.n_batches == 2

    finally:
        # Delete memmap file and re-raise exception
        path.unlink()


def test_tmax_only(torch_device, data_directory):
    N, C = (1000, 10)
    NT = 300
    nt = 61
    data = np.repeat(np.arange(N)[...,np.newaxis], repeats=C, axis=1)
    fs = 10
    path = data_directory / 'time_interval_test' / 'temp_memmap3.dat'
    path.parent.mkdir(parents=True, exist_ok=True)

    try:
        a = np.memmap(path, mode='w+', shape=(N,C), dtype=np.int16)
        a[:] = data[:]
        a.flush()
        del(a)

        bfile = io.BinaryRWFile(path, n_chan_bin=C, fs=fs, device=torch_device,
                                tmax=78, NT=NT, nt=nt)

        assert bfile.imin == 0
        assert bfile.imax == 780
        assert bfile.n_samples == 780
        assert bfile[0:100].min() == 0
        assert bfile[700:].max() == 779
        assert bfile.n_batches == 3

    finally:
        # Delete memmap file and re-raise exception
        path.unlink()


def test_file_group(torch_device, data_directory, bfile):
    file = data_directory / 'ZFM-02370_mini.imec0.ap.short.bin'
    fs = bfile.fs
    n_chans = bfile.n_chan_bin

    # Test with file_objects option
    # Load as three 15-second files instead of one 45-second file.
    objs = [np.memmap(file, dtype='int16', shape=bfile.shape, mode='r')
            for _ in range(3)]
    objs[0] = objs[0][:int(15*fs),:]
    objs[1] = objs[1][int(15*fs):int(30*fs),:]
    objs[2] = objs[2][int(30*fs):,:]
    bfg = io.BinaryFileGroup(file_objects=objs)
    bfile2 = io.BinaryFiltered(
        filename='test', n_chan_bin=n_chans, fs=fs, chan_map=bfile.chan_map,
        device=torch_device, file_object=bfg, dtype='int16'
        )

    # First batch, overlapping batch, and last batch
    # (assumes 45s test dataset with 2s batch size)
    for i in [0, 7, 22]:
        b1 = bfile.padded_batch_to_torch(i, skip_preproc=True)
        b2 = bfile2.padded_batch_to_torch(i)
        assert torch.allclose(b1, b2)

    # Test with filenames option
    files = [file]*3  # Load the same data three times
    bfile3 = io.BinaryFiltered(
        filename=files, n_chan_bin=n_chans, fs=fs, chan_map=bfile.chan_map,
        device=torch_device, dtype='int16'
    )

    # First and first, last and last, last of original and last of concat
    for i,j in [(0,0), (21,21), (22,67)]:
        b1 = bfile.padded_batch_to_torch(i, skip_preproc=True)
        b2 = bfile3.padded_batch_to_torch(j)
        assert torch.allclose(b1, b2)

back to top

Software Heritage — Copyright (C) 2015–2026, The Software Heritage developers. License: GNU AGPLv3+.
The source code of Software Heritage itself is available on our development forge.
The source code files archived by Software Heritage are available under their own copyright and licenses.
Terms of use: Archive access, API— Content policy— Contact— JavaScript license information— Web API