https://doi.org/10.5281/zenodo.3597474
benchmark.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Reproducing Kilosort4 benchmarks"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1. Download data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, choose a simulated dataset to use for the comparison. You'll need to download and extract the data before running the rest of the notebook. For this demo, I chose the medium drift dataset and saved it to `D:/sim_med_drift`.\n",
"\n",
"All datasets: https://janelia.figshare.com/articles/dataset/Simulations_from_kilosort4_paper/25298815/1\n",
"\n",
"Medium drift: https://janelia.figshare.com/ndownloader/files/44729869"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2. Decompress data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You will also need to decompress the binary file, which has been compressed with `mtscomp`. You can see details about the package here: https://github.com/int-brain-lab/mtscomp\n",
"\n",
"Install with:\n",
"```\n",
"pip install mtscomp\n",
"```\n",
"Then, from the unzipped directory downloaded in the previous step, decompress the `.cbin` file:\n",
"```\n",
"mtsdecomp sim.imec0.ap.cbin -o data.bin\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3. Get probe information"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The probe file for the recording can be parsed from `sim.imec0.ap.meta` using a SpikeGLX script available here:\n",
"\n",
"https://github.com/jenniferColonell/SGLXMetaToCoords/blob/main/SGLXMetaToCoords.py\n",
"\n",
"Download the script and run it in your Kilosort environment with `python SGLXMetaToCoords.py`, then follow the prompt to select the `.meta` file. This saves the probe file as `sim.imec0.ap_kilosortChanMap.mat` in the same directory as the `.meta` file."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4. Run Kilosort4"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"from kilosort import run_kilosort\n",
"from kilosort.io import load_probe, load_ops\n",
"from kilosort.bench import load_GT, load_phy, compare_recordings\n",
"\n",
"# Specify paths\n",
"# NOTE: Make sure to update `download_path` to match your local directory.\n",
"download_path = Path('D:/sim_med_drift')\n",
"results_dir = download_path / 'kilosort4'\n",
"filename = download_path / 'data.bin'\n",
"probe_path = download_path / 'sim.imec0.ap_kilosortChanMap.mat'\n",
"gt_path = download_path / 'sim.imec0.ap_params.npz'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Run Kilosort4 on downloaded dataset\n",
"# You can also run it without drift correction by setting `nblocks = 0`,\n",
"# or with only rigid drift correction by setting `nblocks = 1`.\n",
"probe = load_probe(probe_path)\n",
"settings = {'n_chan_bin': 385, 'filename': filename, 'nblocks': 5}\n",
"_ = run_kilosort(settings, probe=probe)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Load `ops`\n",
"# NOTE: If you've already sorted the data, you can skip the previous step.\n",
"ops = load_ops(results_dir / 'ops.npy')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 5. Compute ground truth comparison"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 600/600 [04:25<00:00, 2.26it/s]\n"
]
}
],
"source": [
"# NOTE: Using this with the current version of Kilosort4 requires the\n",
"# updates to `kilosort.bench` added in v4.0.32.\n",
"\n",
"# Load ground truth results\n",
"st_gt, clu_gt, yclu_gt, mu_gt, Wsub, nsp = load_GT(filename, ops, gt_path)\n",
"# Load sorting results\n",
"st_new, clu_new, yclu_new, Wsub = load_phy(filename, results_dir, ops)\n",
"# Compare\n",
"fmax, fmiss, fpos, best_ind, matched_all, top_inds = \\\n",
"    compare_recordings(st_gt, clu_gt, yclu_gt, st_new, clu_new, yclu_new)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"number correct: 553\n",
"number missed: 47\n"
]
}
],
"source": [
"def num_correct(fpos, fmiss):\n",
"    # Count units whose score (1 - fpos - fmiss) is at least 0.8.\n",
"    score = 1 - fpos - fmiss\n",
"    return (score >= 0.8).sum()\n",
"\n",
"n = num_correct(fpos, fmiss)\n",
"print(f'number correct: {n}')\n",
"print(f'number missed: {600 - n}')  # 600 ground truth units\n",
"print('number correct in paper: 555')  # for medium drift dataset, from fig 4j"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "kilosort",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.21"
}
},
"nbformat": 4,
"nbformat_minor": 2
}