{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "aa21cf58",
   "metadata": {},
   "source": [
    "# DeepLabCut Process Position\n",
    "\n",
    "## Dhruv Mehrotra, 2022\n",
    "\n",
    "\n",
    "In this notebook, we will learn how to analyze DeepLabCut position data from a given sub-region of your environment. In particular, this example deals with a mouse running on a radial arm maze, but the same idea applies to any sub-region of any environment.\n",
    "\n",
    "\n",
    "Let's get right into it! First, import the necessary libraries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ee9c3d76",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pynapple as nap\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy.io\n",
    "import os, sys\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from pylab import *\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1440e2c9",
   "metadata": {},
   "source": [
    "Next, load the DeepLabCut output from your data directory. Here the data are read from an HDF5 (.h5) file, but you can swap in a reader for whatever format you are working with (CSV, MAT file, etc.)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "39db3389",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th>scorer</th>\n",
       "      <th colspan=\"9\" halign=\"left\">DLC_mobnet_100_unimplantedJan4shuffle1_200000</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>bodyparts</th>\n",
       "      <th colspan=\"3\" halign=\"left\">nose</th>\n",
       "      <th colspan=\"3\" halign=\"left\">leftear</th>\n",
       "      <th colspan=\"3\" halign=\"left\">rightear</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>coords</th>\n",
       "      <th>x</th>\n",
       "      <th>y</th>\n",
       "      <th>likelihood</th>\n",
       "      <th>x</th>\n",
       "      <th>y</th>\n",
       "      <th>likelihood</th>\n",
       "      <th>x</th>\n",
       "      <th>y</th>\n",
       "      <th>likelihood</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>325.915680</td>\n",
       "      <td>87.834717</td>\n",
       "      <td>0.047951</td>\n",
       "      <td>327.809082</td>\n",
       "      <td>89.872162</td>\n",
       "      <td>0.031015</td>\n",
       "      <td>383.521667</td>\n",
       "      <td>158.932144</td>\n",
       "      <td>0.019942</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>330.679016</td>\n",
       "      <td>92.426193</td>\n",
       "      <td>0.005103</td>\n",
       "      <td>111.054222</td>\n",
       "      <td>1007.833008</td>\n",
       "      <td>0.006857</td>\n",
       "      <td>384.808868</td>\n",
       "      <td>158.491882</td>\n",
       "      <td>0.008222</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>330.229675</td>\n",
       "      <td>92.469246</td>\n",
       "      <td>0.003538</td>\n",
       "      <td>991.470459</td>\n",
       "      <td>573.574402</td>\n",
       "      <td>0.003131</td>\n",
       "      <td>384.428528</td>\n",
       "      <td>158.548752</td>\n",
       "      <td>0.012223</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>324.419312</td>\n",
       "      <td>85.861443</td>\n",
       "      <td>0.028133</td>\n",
       "      <td>328.421967</td>\n",
       "      <td>90.776245</td>\n",
       "      <td>0.030893</td>\n",
       "      <td>384.233154</td>\n",
       "      <td>157.995850</td>\n",
       "      <td>0.030514</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>324.482239</td>\n",
       "      <td>87.202263</td>\n",
       "      <td>0.014028</td>\n",
       "      <td>327.500519</td>\n",
       "      <td>90.704430</td>\n",
       "      <td>0.016099</td>\n",
       "      <td>384.890259</td>\n",
       "      <td>158.675751</td>\n",
       "      <td>0.016726</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>154432</th>\n",
       "      <td>991.857971</td>\n",
       "      <td>175.494125</td>\n",
       "      <td>0.001959</td>\n",
       "      <td>423.084625</td>\n",
       "      <td>199.708374</td>\n",
       "      <td>0.001155</td>\n",
       "      <td>384.718536</td>\n",
       "      <td>157.750381</td>\n",
       "      <td>0.002946</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>154433</th>\n",
       "      <td>991.476501</td>\n",
       "      <td>174.963913</td>\n",
       "      <td>0.001489</td>\n",
       "      <td>991.325317</td>\n",
       "      <td>573.928833</td>\n",
       "      <td>0.001893</td>\n",
       "      <td>384.466309</td>\n",
       "      <td>157.800858</td>\n",
       "      <td>0.005548</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>154434</th>\n",
       "      <td>991.720520</td>\n",
       "      <td>175.118958</td>\n",
       "      <td>0.001419</td>\n",
       "      <td>417.429535</td>\n",
       "      <td>191.197449</td>\n",
       "      <td>0.001276</td>\n",
       "      <td>386.694214</td>\n",
       "      <td>157.707367</td>\n",
       "      <td>0.031571</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>154435</th>\n",
       "      <td>328.655090</td>\n",
       "      <td>86.512009</td>\n",
       "      <td>0.001663</td>\n",
       "      <td>314.363403</td>\n",
       "      <td>-13.260569</td>\n",
       "      <td>0.001334</td>\n",
       "      <td>384.249390</td>\n",
       "      <td>158.987778</td>\n",
       "      <td>0.011806</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>154436</th>\n",
       "      <td>991.879517</td>\n",
       "      <td>174.817444</td>\n",
       "      <td>0.001389</td>\n",
       "      <td>422.413788</td>\n",
       "      <td>196.893265</td>\n",
       "      <td>0.002721</td>\n",
       "      <td>386.354828</td>\n",
       "      <td>157.463684</td>\n",
       "      <td>0.075369</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>154437 rows × 9 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "scorer    DLC_mobnet_100_unimplantedJan4shuffle1_200000              \\\n",
       "bodyparts                                          nose               \n",
       "coords                                                x           y   \n",
       "0                                            325.915680   87.834717   \n",
       "1                                            330.679016   92.426193   \n",
       "2                                            330.229675   92.469246   \n",
       "3                                            324.419312   85.861443   \n",
       "4                                            324.482239   87.202263   \n",
       "...                                                 ...         ...   \n",
       "154432                                       991.857971  175.494125   \n",
       "154433                                       991.476501  174.963913   \n",
       "154434                                       991.720520  175.118958   \n",
       "154435                                       328.655090   86.512009   \n",
       "154436                                       991.879517  174.817444   \n",
       "\n",
       "scorer                                                                \\\n",
       "bodyparts                leftear                            rightear   \n",
       "coords    likelihood           x            y likelihood           x   \n",
       "0           0.047951  327.809082    89.872162   0.031015  383.521667   \n",
       "1           0.005103  111.054222  1007.833008   0.006857  384.808868   \n",
       "2           0.003538  991.470459   573.574402   0.003131  384.428528   \n",
       "3           0.028133  328.421967    90.776245   0.030893  384.233154   \n",
       "4           0.014028  327.500519    90.704430   0.016099  384.890259   \n",
       "...              ...         ...          ...        ...         ...   \n",
       "154432      0.001959  423.084625   199.708374   0.001155  384.718536   \n",
       "154433      0.001489  991.325317   573.928833   0.001893  384.466309   \n",
       "154434      0.001419  417.429535   191.197449   0.001276  386.694214   \n",
       "154435      0.001663  314.363403   -13.260569   0.001334  384.249390   \n",
       "154436      0.001389  422.413788   196.893265   0.002721  386.354828   \n",
       "\n",
       "scorer                            \n",
       "bodyparts                         \n",
       "coords              y likelihood  \n",
       "0          158.932144   0.019942  \n",
       "1          158.491882   0.008222  \n",
       "2          158.548752   0.012223  \n",
       "3          157.995850   0.030514  \n",
       "4          158.675751   0.016726  \n",
       "...               ...        ...  \n",
       "154432     157.750381   0.002946  \n",
       "154433     157.800858   0.005548  \n",
       "154434     157.707367   0.031571  \n",
       "154435     158.987778   0.011806  \n",
       "154436     157.463684   0.075369  \n",
       "\n",
       "[154437 rows x 9 columns]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_directory = '/media/DataDhruv/Recordings/unimplanted/211231'\n",
    "# DeepLabCut writes its .h5 output with pandas, so pd.read_hdf can load it directly\n",
    "tracking_data = pd.read_hdf(os.path.join(data_directory, '1819-211231_1.h5'), mode='r')\n",
    "\n",
    "tracking_data\n"
   ]
  },
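  {
   "cell_type": "markdown",
   "id": "c1a7e001",
   "metadata": {},
   "source": [
    "If your DeepLabCut output is a CSV file rather than HDF5, a reader along the following lines should give the same three-level column layout (`scorer` / `bodyparts` / `coords`). This is a minimal sketch. It assumes the CSV has three header rows plus a frame-index column. It writes a tiny made-up table in that layout to a temporary file and reads it back with `pd.read_csv(..., header=[0, 1, 2], index_col=0)`. The file name `example_dlc.csv`, the scorer name `DLC_model` and all values are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1a7e002",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tempfile\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "# Tiny made-up table with DeepLabCut's three column levels (hypothetical values).\n",
    "demo_columns = pd.MultiIndex.from_product(\n",
    "    [['DLC_model'], ['nose', 'leftear'], ['x', 'y', 'likelihood']],\n",
    "    names=['scorer', 'bodyparts', 'coords'],\n",
    ")\n",
    "example = pd.DataFrame([[1.0, 2.0, 0.9, 3.0, 4.0, 0.8],\n",
    "                        [1.5, 2.5, 0.7, 3.5, 4.5, 0.6]], columns=demo_columns)\n",
    "\n",
    "# Write it out as a CSV with three header rows (scorer, bodyparts, coords).\n",
    "path = os.path.join(tempfile.mkdtemp(), 'example_dlc.csv')\n",
    "example.to_csv(path)\n",
    "\n",
    "# header=[0, 1, 2] rebuilds the three column levels; index_col=0 restores the frame index.\n",
    "csv_data = pd.read_csv(path, header=[0, 1, 2], index_col=0)\n",
    "csv_data.columns.names"
   ]
  },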
  {
   "cell_type": "markdown",
   "id": "a3cdd606",
   "metadata": {},
   "source": [
    "Here, we see that `tracking_data` has 9 columns. For each of the 3 body parts I labelled (nose, left ear and right ear), there is an x position, a y position and a likelihood, which is DeepLabCut's confidence in that estimate. We are only interested in the x and y positions, so we will extract these and create a new DataFrame with just the relevant data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e8dba974",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th>scorer</th>\n",
       "      <th colspan=\"6\" halign=\"left\">DLC_mobnet_100_unimplantedJan4shuffle1_200000</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>bodyparts</th>\n",
       "      <th colspan=\"2\" halign=\"left\">nose</th>\n",
       "      <th colspan=\"2\" halign=\"left\">leftear</th>\n",
       "      <th colspan=\"2\" halign=\"left\">rightear</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>coords</th>\n",
       "      <th>x</th>\n",
       "      <th>y</th>\n",
       "      <th>x</th>\n",
       "      <th>y</th>\n",
       "      <th>x</th>\n",
       "      <th>y</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>325.915680</td>\n",
       "      <td>87.834717</td>\n",
       "      <td>327.809082</td>\n",
       "      <td>89.872162</td>\n",
       "      <td>383.521667</td>\n",
       "      <td>158.932144</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>330.679016</td>\n",
       "      <td>92.426193</td>\n",
       "      <td>111.054222</td>\n",
       "      <td>1007.833008</td>\n",
       "      <td>384.808868</td>\n",
       "      <td>158.491882</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>330.229675</td>\n",
       "      <td>92.469246</td>\n",
       "      <td>991.470459</td>\n",
       "      <td>573.574402</td>\n",
       "      <td>384.428528</td>\n",
       "      <td>158.548752</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>324.419312</td>\n",
       "      <td>85.861443</td>\n",
       "      <td>328.421967</td>\n",
       "      <td>90.776245</td>\n",
       "      <td>384.233154</td>\n",
       "      <td>157.995850</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>324.482239</td>\n",
       "      <td>87.202263</td>\n",
       "      <td>327.500519</td>\n",
       "      <td>90.704430</td>\n",
       "      <td>384.890259</td>\n",
       "      <td>158.675751</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>154432</th>\n",
       "      <td>991.857971</td>\n",
       "      <td>175.494125</td>\n",
       "      <td>423.084625</td>\n",
       "      <td>199.708374</td>\n",
       "      <td>384.718536</td>\n",
       "      <td>157.750381</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>154433</th>\n",
       "      <td>991.476501</td>\n",
       "      <td>174.963913</td>\n",
       "      <td>991.325317</td>\n",
       "      <td>573.928833</td>\n",
       "      <td>384.466309</td>\n",
       "      <td>157.800858</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>154434</th>\n",
       "      <td>991.720520</td>\n",
       "      <td>175.118958</td>\n",
       "      <td>417.429535</td>\n",
       "      <td>191.197449</td>\n",
       "      <td>386.694214</td>\n",
       "      <td>157.707367</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>154435</th>\n",
       "      <td>328.655090</td>\n",
       "      <td>86.512009</td>\n",
       "      <td>314.363403</td>\n",
       "      <td>-13.260569</td>\n",
       "      <td>384.249390</td>\n",
       "      <td>158.987778</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>154436</th>\n",
       "      <td>991.879517</td>\n",
       "      <td>174.817444</td>\n",
       "      <td>422.413788</td>\n",
       "      <td>196.893265</td>\n",
       "      <td>386.354828</td>\n",
       "      <td>157.463684</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>154437 rows × 6 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "scorer    DLC_mobnet_100_unimplantedJan4shuffle1_200000              \\\n",
       "bodyparts                                          nose               \n",
       "coords                                                x           y   \n",
       "0                                            325.915680   87.834717   \n",
       "1                                            330.679016   92.426193   \n",
       "2                                            330.229675   92.469246   \n",
       "3                                            324.419312   85.861443   \n",
       "4                                            324.482239   87.202263   \n",
       "...                                                 ...         ...   \n",
       "154432                                       991.857971  175.494125   \n",
       "154433                                       991.476501  174.963913   \n",
       "154434                                       991.720520  175.118958   \n",
       "154435                                       328.655090   86.512009   \n",
       "154436                                       991.879517  174.817444   \n",
       "\n",
       "scorer                                                      \n",
       "bodyparts     leftear                 rightear              \n",
       "coords              x            y           x           y  \n",
       "0          327.809082    89.872162  383.521667  158.932144  \n",
       "1          111.054222  1007.833008  384.808868  158.491882  \n",
       "2          991.470459   573.574402  384.428528  158.548752  \n",
       "3          328.421967    90.776245  384.233154  157.995850  \n",
       "4          327.500519    90.704430  384.890259  158.675751  \n",
       "...               ...          ...         ...         ...  \n",
       "154432     423.084625   199.708374  384.718536  157.750381  \n",
       "154433     991.325317   573.928833  384.466309  157.800858  \n",
       "154434     417.429535   191.197449  386.694214  157.707367  \n",
       "154435     314.363403   -13.260569  384.249390  158.987778  \n",
       "154436     422.413788   196.893265  386.354828  157.463684  \n",
       "\n",
       "[154437 rows x 6 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
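    "# Keep the x and y columns of each body part (positions 0,1 / 3,4 / 6,7) and drop the likelihoods\n",
    "hd_cols = [0, 1, 3, 4, 6, 7]\n",
    "hd_data = tracking_data.iloc[:, hd_cols]\n",
    "\n",
    "hd_data"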
   ]
  },
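  {
   "cell_type": "markdown",
   "id": "c1a7e003",
   "metadata": {},
   "source": [
    "The integer positions in `hd_cols` depend on the column order in the file. A possible alternative is to select by label instead. `DataFrame.drop` with `level='coords'` removes every `likelihood` column, wherever it sits. The sketch below shows this on a small made-up frame so that it runs on its own. It assumes the level names `scorer` / `bodyparts` / `coords` shown in the table above. The names `demo`, `demo_xy` and `positional` are illustrative only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1a7e004",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# Small made-up frame with the same three column levels as tracking_data.\n",
    "demo_columns = pd.MultiIndex.from_product(\n",
    "    [['DLC_model'], ['nose', 'leftear', 'rightear'], ['x', 'y', 'likelihood']],\n",
    "    names=['scorer', 'bodyparts', 'coords'],\n",
    ")\n",
    "demo = pd.DataFrame(np.arange(18, dtype=float).reshape(2, 9), columns=demo_columns)\n",
    "\n",
    "# Drop every 'likelihood' column by label, whatever its position.\n",
    "demo_xy = demo.drop(columns='likelihood', level='coords')\n",
    "\n",
    "# For this column order, the result matches the positional selection used above.\n",
    "positional = demo.iloc[:, [0, 1, 3, 4, 6, 7]]\n",
    "same = list(demo_xy.columns) == list(positional.columns) and np.array_equal(\n",
    "    demo_xy.to_numpy(), positional.to_numpy()\n",
    ")\n",
    "same"
   ]
  },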
  {
   "cell_type": "markdown",
   "id": "f97b3f34",
   "metadata": {},
   "source": [
    "Now, from these data we shall compute the centroid of the 3 body parts, which will serve as a proxy for head position.\n",
    "The centroid, or geometric centre, is the arithmetic mean position of all the points (in our case, the nose and the two ears).\n",
    "\n",
    "First, we select the columns containing the x- and y-coordinates of all parts, storing their positions in `x_cols` and `y_cols`, respectively. We then create new DataFrames holding just the x-coordinates and just the y-coordinates, which we call `all_x_coords` and `all_y_coords`, respectively."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2f51bf19",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th>scorer</th>\n",
       "      <th colspan=\"3\" halign=\"left\">DLC_mobnet_100_unimplantedJan4shuffle1_200000</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>bodyparts</th>\n",
       "      <th>nose</th>\n",
       "      <th>leftear</th>\n",
       "      <th>rightear</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>coords</th>\n",
       "      <th>x</th>\n",
       "      <th>x</th>\n",
       "      <th>x</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>325.915680</td>\n",
       "      <td>327.809082</td>\n",
       "      <td>383.521667</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>330.679016</td>\n",
       "      <td>111.054222</td>\n",
       "      <td>384.808868</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>330.229675</td>\n",
       "      <td>991.470459</td>\n",
       "      <td>384.428528</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>324.419312</td>\n",
       "      <td>328.421967</td>\n",
       "      <td>384.233154</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>324.482239</td>\n",
       "      <td>327.500519</td>\n",
       "      <td>384.890259</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>154432</th>\n",
       "      <td>991.857971</td>\n",
       "      <td>423.084625</td>\n",
       "      <td>384.718536</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>154433</th>\n",
       "      <td>991.476501</td>\n",
       "      <td>991.325317</td>\n",
       "      <td>384.466309</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>154434</th>\n",
       "      <td>991.720520</td>\n",
       "      <td>417.429535</td>\n",
       "      <td>386.694214</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>154435</th>\n",
       "      <td>328.655090</td>\n",
       "      <td>314.363403</td>\n",
       "      <td>384.249390</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>154436</th>\n",
       "      <td>991.879517</td>\n",
       "      <td>422.413788</td>\n",
       "      <td>386.354828</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>154437 rows × 3 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "scorer    DLC_mobnet_100_unimplantedJan4shuffle1_200000              \\\n",
       "bodyparts                                          nose     leftear   \n",
       "coords                                                x           x   \n",
       "0                                            325.915680  327.809082   \n",
       "1                                            330.679016  111.054222   \n",
       "2                                            330.229675  991.470459   \n",
       "3                                            324.419312  328.421967   \n",
       "4                                            324.482239  327.500519   \n",
       "...                                                 ...         ...   \n",
       "154432                                       991.857971  423.084625   \n",
       "154433                                       991.476501  991.325317   \n",
       "154434                                       991.720520  417.429535   \n",
       "154435                                       328.655090  314.363403   \n",
       "154436                                       991.879517  422.413788   \n",
       "\n",
       "scorer                 \n",
       "bodyparts    rightear  \n",
       "coords              x  \n",
       "0          383.521667  \n",
       "1          384.808868  \n",
       "2          384.428528  \n",
       "3          384.233154  \n",
       "4          384.890259  \n",
       "...               ...  \n",
       "154432     384.718536  \n",
       "154433     384.466309  \n",
       "154434     386.694214  \n",
       "154435     384.249390  \n",
       "154436     386.354828  \n",
       "\n",
       "[154437 rows x 3 columns]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Within hd_data, the columns alternate x, y for each body part\n",
    "x_cols = [0, 2, 4]\n",
    "y_cols = [1, 3, 5]\n",
    "all_x_coords = hd_data.iloc[:, x_cols]\n",
    "all_y_coords = hd_data.iloc[:, y_cols]\n",
    "all_x_coords"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55350cd0",
   "metadata": {},
   "source": [
    "Wonderful! Now to compute the centroid. Since it is a mean, we first need the sum of the observations (one per body part) at each time point. We compute the sums for each coordinate separately below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5deb543f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0         1037.246429\n",
       "1          826.542107\n",
       "2         1706.128662\n",
       "3         1037.074432\n",
       "4         1036.873016\n",
       "             ...     \n",
       "154432    1799.661133\n",
       "154433    2367.268127\n",
       "154434    1795.844269\n",
       "154435    1027.267883\n",
       "154436    1800.648132\n",
       "Length: 154437, dtype: float64"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sum across body parts (axis=1) to get one total per time point\n",
    "x_sum = all_x_coords.sum(axis=1)\n",
    "y_sum = all_y_coords.sum(axis=1)\n",
    "\n",
    "x_sum"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "80633a53",
   "metadata": {},
   "source": [
    "Now, we need to divide the sum by the number of body parts. In our case this is 3, but rather than hard-coding that value we will read it from the data, so the same code keeps working if you track a different number of body parts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c4408a51",
   "metadata": {},
   "outputs": [],
   "source": [
    "length = all_x_coords.shape[1]  # number of body parts (one x column per part)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "94ba2ab9",
   "metadata": {},
   "source": [
    "Now, we will compute the centroid and store it in a NumPy array called `hd_centroid`, whose two columns hold the x and y positions of the head. From it we extract these positions into the variables `x` and `y`, respectively."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ad98bac4",
   "metadata": {},
   "outputs": [],
   "source": [
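    "# Centroid = mean position across body parts at each time point\n",
    "x_cent = x_sum / length\n",
    "y_cent = y_sum / length\n",
    "\n",
    "# Column 0 holds the x-centroid, column 1 the y-centroid\n",
    "hd_centroid = np.zeros((len(x_cent), 2))\n",
    "hd_centroid[:, 0] = x_cent\n",
    "hd_centroid[:, 1] = y_cent\n",
    "\n",
    "x = hd_centroid[:, 0]\n",
    "y = hd_centroid[:, 1]\n"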
   ]
  },
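  {
   "cell_type": "markdown",
   "id": "c1a7e005",
   "metadata": {},
   "source": [
    "Because the centroid is just a mean, the sum-then-divide steps above can also be written with `mean(axis=1)`. The sketch below uses made-up coordinates for three body parts to check that both routes agree. It uses `DataFrame.xs` to pull every `x` (or `y`) column out of the `coords` level, so the number of body parts is never hard-coded. The data and the `demo_*` names are hypothetical. The two routes differ only when some coordinates are missing: `mean` skips NaN values, whereas dividing the sum by the column count does not adjust for them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1a7e006",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# Made-up x/y coordinates for three body parts over two frames.\n",
    "demo_columns = pd.MultiIndex.from_product(\n",
    "    [['DLC_model'], ['nose', 'leftear', 'rightear'], ['x', 'y']],\n",
    "    names=['scorer', 'bodyparts', 'coords'],\n",
    ")\n",
    "demo_xy = pd.DataFrame([[0.0, 0.0, 3.0, 0.0, 0.0, 3.0],\n",
    "                        [1.0, 1.0, 4.0, 1.0, 1.0, 4.0]], columns=demo_columns)\n",
    "\n",
    "# Pull out all x (and all y) columns by level name: one column per body part.\n",
    "demo_x = demo_xy.xs('x', axis=1, level='coords')\n",
    "demo_y = demo_xy.xs('y', axis=1, level='coords')\n",
    "\n",
    "# Route 1: sum across body parts, then divide by how many there are.\n",
    "x_via_sum = demo_x.sum(axis=1) / demo_x.shape[1]\n",
    "\n",
    "# Route 2: take the mean directly, giving an (n_frames, 2) centroid array.\n",
    "demo_centroid = np.column_stack([demo_x.mean(axis=1), demo_y.mean(axis=1)])\n",
    "demo_centroid"
   ]
  },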
  {
   "cell_type": "markdown",
   "id": "c13006fd",
   "metadata": {},
   "source": [
    "Now, let us build a time-indexed position object. Time to use Pynapple! This recording was acquired at 120 Hz, so frame *i* occurs at *i*/120 seconds. We make these timestamps first, as below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "0c1f62ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "fs = 120  # camera frame rate in Hz\n",
    "timestamps = x_cent.index.values / fs  # frame index -> time in seconds"
   ]
  },
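  {
   "cell_type": "markdown",
   "id": "c1a7e007",
   "metadata": {},
   "source": [
    "As a quick check of the frame-to-time conversion, consider frames numbered 0, 1, 2, ... and sampled at `fs` Hz. Frame *i* sits at *i* / `fs` seconds, so consecutive frames are 1/120 s (about 8.33 ms) apart. The tables above report 154437 frames, so the last frame falls at 154436 / 120 ≈ 1286.967 s. This matches the final timestamp in the position table. The sketch below assumes the frame index is contiguous and starts at 0, as it does here. The `demo_*` names are illustrative and leave the notebook's own `fs` and `timestamps` untouched."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1a7e008",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "demo_fs = 120           # frame rate in Hz\n",
    "demo_n_frames = 154437  # number of rows in tracking_data\n",
    "\n",
    "# Frame i is placed at i / fs seconds, as in the cell above.\n",
    "demo_timestamps = np.arange(demo_n_frames) / demo_fs\n",
    "\n",
    "frame_period = demo_timestamps[1] - demo_timestamps[0]  # 1/120 s, roughly 8.33 ms\n",
    "duration = demo_timestamps[-1]                          # roughly 1286.97 s\n",
    "round(duration, 6)"
   ]
  },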
  {
   "cell_type": "markdown",
   "id": "a3c79d4a",
   "metadata": {},
   "source": [
    "Now, we create the position object, a Pynapple `TsdFrame`, as below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "5bbdb539",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>x</th>\n",
       "      <th>y</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Time (s)</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0.000000</th>\n",
       "      <td>345.748810</td>\n",
       "      <td>112.213008</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.008333</th>\n",
       "      <td>275.514036</td>\n",
       "      <td>419.583694</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.016667</th>\n",
       "      <td>568.709554</td>\n",
       "      <td>274.864133</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.025000</th>\n",
       "      <td>345.691477</td>\n",
       "      <td>111.544512</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.033333</th>\n",
       "      <td>345.624339</td>\n",
       "      <td>112.194148</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1286.933333</th>\n",
       "      <td>599.887044</td>\n",
       "      <td>177.650960</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1286.941667</th>\n",
       "      <td>789.089376</td>\n",
       "      <td>302.231201</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1286.950000</th>\n",
       "      <td>598.614756</td>\n",
       "      <td>174.674591</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1286.958333</th>\n",
       "      <td>342.422628</td>\n",
       "      <td>77.413073</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1286.966667</th>\n",
       "      <td>600.216044</td>\n",
       "      <td>176.391464</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>154437 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                      x           y\n",
       "Time (s)                           \n",
       "0.000000     345.748810  112.213008\n",
       "0.008333     275.514036  419.583694\n",
       "0.016667     568.709554  274.864133\n",
       "0.025000     345.691477  111.544512\n",
       "0.033333     345.624339  112.194148\n",
       "...                 ...         ...\n",
       "1286.933333  599.887044  177.650960\n",
       "1286.941667  789.089376  302.231201\n",
       "1286.950000  598.614756  174.674591\n",
       "1286.958333  342.422628   77.413073\n",
       "1286.966667  600.216044  176.391464\n",
       "\n",
       "[154437 rows x 2 columns]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "position = np.vstack([x, y]).T\n",
    "position = nap.TsdFrame(t = timestamps, d = position, columns = ['x', 'y'], time_units = 's')\n",
    "\n",
    "position\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "873c444a",
   "metadata": {},
   "source": [
    "This looks good, but it does not give us an idea of what the data really represent. Let's plot it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "59aba406",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<AxesSubplot:>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# x and y are already NumPy arrays, so they can be passed to seaborn directly\n",
    "sns.scatterplot(x=x, y=y)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c30bd4c7",
   "metadata": {},
   "source": [
    "Aha! The plot makes clear what the data represent: the positions trace out the radial arm maze. Some points fall outside the bounds of the maze because DeepLabCut occasionally detects the mouse in the wrong place, but these points are irrelevant to us, since we only care about the centre of the maze. We will define this central region as a circle centred on (xth, yth) with radius rth.\n",
    "\n",
    "(xth, yth) are approximate coordinates, in pixels, of the centre of this maze. Pick values that suit your own video and application; the radius rth can be adjusted in the same way.\n",
    "\n",
    "We will go with the following parameter values:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "731765a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "xth = 683.4\n",
    "yth = 484.4\n",
    "rth = 200"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c26a321",
   "metadata": {},
   "source": [
    "Let us visualize this by plotting the position data with the circle overlaid. We will then only consider the parts of the trajectory that lie within the circle."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "4afcc25d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.patches.Circle at 0x7ffa263e00d0>"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "circle1 = plt.Circle((xth, yth), rth, color='k', fill = False)\n",
    "ax = sns.scatterplot(data = position, x = x, y = y)\n",
    "ax.add_patch(circle1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a80fdfe",
   "metadata": {},
   "source": [
    "Great! Now we can restrict our position data to the centre of the maze.\n",
    "\n",
    "To do this, we first compute the distance of every point from the centre; points whose distance is smaller than the radius of the circle lie inside it. Time to put Pynapple to use!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "9e9bf7c1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Time (s)\n",
       "0.000000       502.525107\n",
       "0.008333       413.003769\n",
       "0.016667       238.870630\n",
       "0.025000       503.058904\n",
       "0.033333       502.622715\n",
       "                  ...    \n",
       "1286.933333    317.914119\n",
       "1286.941667    210.607966\n",
       "1286.950000    321.120486\n",
       "1286.958333    530.946257\n",
       "1286.966667    319.043616\n",
       "Length: 154437, dtype: float64"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Euclidean distance of every frame from the centre of the maze (xth, yth)\n",
    "d = np.sqrt((x - xth)**2 + (y - yth)**2)\n",
    "dist_center = nap.Tsd(t = timestamps, d = d, time_units = 's')\n",
    "\n",
    "dist_center"
   ]
  },
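  {
   "cell_type": "markdown",
   "id": "5d7e2a10",
   "metadata": {},
   "source": [
    "If you want to sanity-check this step without Pynapple, the NumPy sketch below computes the same Euclidean distance on a few made-up points and turns it into a boolean inside-the-circle mask. The toy coordinates and the _demo variable names are illustrative only, and np.hypot is used here as an equivalent of the square-root expression above.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical toy positions in pixels, NOT taken from the recording above\n",
    "x_demo = np.array([683.4, 700.0, 900.0, 683.4])\n",
    "y_demo = np.array([484.4, 500.0, 484.4, 700.0])\n",
    "\n",
    "# Centre and radius of the region of interest, as defined earlier\n",
    "xth, yth, rth = 683.4, 484.4, 200\n",
    "\n",
    "# Euclidean distance from the centre; np.hypot(a, b) equals np.sqrt(a**2 + b**2)\n",
    "d_demo = np.hypot(x_demo - xth, y_demo - yth)\n",
    "\n",
    "# True where a sample lies strictly inside the circle\n",
    "inside = d_demo < rth\n",
    "print(inside)  # [ True  True False False]\n",
    "```"
   ]
  },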
  {
   "cell_type": "markdown",
   "id": "708d2218",
   "metadata": {},
   "source": [
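    "Pynapple time! We use the `threshold` method of the Tsd to keep only the samples whose distance from the centre is below rth. The `time_support` of the result is an IntervalSet containing the epochs during which the mouse is inside the circle."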
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "78c87b40",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Time (s)\n",
       "0.041667       111.789845\n",
       "0.050000       112.349458\n",
       "0.233333       112.429611\n",
       "0.250000       198.200149\n",
       "0.316667       195.779983\n",
       "                  ...    \n",
       "1286.775000    191.168783\n",
       "1286.783333    172.081024\n",
       "1286.816667    191.184213\n",
       "1286.841667    190.762090\n",
       "1286.916667    191.413067\n",
       "Length: 84709, dtype: float64"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
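    "# keep only the samples whose distance from the centre is below rth\n",
    "within_center = dist_center.threshold(rth, 'below')\n",
    "# epochs (IntervalSet) during which the mouse is inside the circle\n",
    "ep = within_center.time_support\n",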
    "\n",
    "within_center"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bacf588d",
   "metadata": {},
   "source": [
    "As you can see, all values in within_center are now below our threshold (rth), and only 84709 of the original 154437 samples remain.\n",
    "\n",
    "What does its time support look like?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "907437d8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "           start          end\n",
      "0       0.037500     0.054167\n",
      "1       0.229166     0.237500\n",
      "2       0.245833     0.254166\n",
      "3       0.312500     0.320834\n",
      "4       0.470834     0.479166\n",
      "..           ...          ...\n",
      "213  1286.737500  1286.745834\n",
      "214  1286.770834  1286.787500\n",
      "215  1286.812500  1286.820834\n",
      "216  1286.837500  1286.845834\n",
      "217  1286.912500  1286.920833\n",
      "\n",
      "[218 rows x 2 columns]\n"
     ]
    }
   ],
   "source": [
    "print(ep)"
   ]
  },
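  {
   "cell_type": "markdown",
   "id": "8c3f61b4",
   "metadata": {},
   "source": [
    "The start and end values above appear to fall halfway between consecutive frames: the first epoch, for instance, starts at 0.0375 s, midway between the samples at 0.033333 s (outside the circle) and 0.041667 s (inside it). The NumPy sketch below illustrates that idea by turning a boolean inside-the-circle mask into contiguous epochs whose boundaries are the midpoints between samples. This is an assumption inferred from the printed output, not Pynapple's actual implementation, and mask_to_epochs is a hypothetical helper name.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def mask_to_epochs(t, mask):\n",
    "    '''Hypothetical helper: [start, end] times of each run of True in mask.\n",
    "\n",
    "    Boundaries sit halfway between the last sample outside a run and the\n",
    "    first sample inside it (and likewise at the end of the run). Runs that\n",
    "    touch the edges of the recording use the first or last timestamp.\n",
    "    '''\n",
    "    t = np.asarray(t, dtype=float)\n",
    "    m = np.asarray(mask, dtype=bool).astype(int)\n",
    "    n = len(t)\n",
    "    # +1 where a run of True begins, -1 just after a run ends\n",
    "    edges = np.diff(np.concatenate(([0], m, [0])))\n",
    "    first = np.flatnonzero(edges == 1)      # first sample inside each run\n",
    "    last = np.flatnonzero(edges == -1) - 1  # last sample inside each run\n",
    "    # Neighbouring samples just outside each run (clipped to valid indices)\n",
    "    prev_t = t[np.maximum(first - 1, 0)]\n",
    "    next_t = t[np.minimum(last + 1, n - 1)]\n",
    "    starts = np.where(first > 0, (prev_t + t[first]) / 2, t[first])\n",
    "    ends = np.where(last < n - 1, (t[last] + next_t) / 2, t[last])\n",
    "    return np.column_stack([starts, ends])\n",
    "\n",
    "# Toy example at 120 Hz (the frame rate implied by the timestamps above);\n",
    "# the mask values are made up, NOT taken from this recording\n",
    "t = np.arange(10) / 120\n",
    "inside = np.array([0, 0, 1, 1, 0, 0, 0, 1, 1, 1], dtype=bool)\n",
    "print(mask_to_epochs(t, inside))\n",
    "```"
   ]
  },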
  {
   "cell_type": "markdown",
   "id": "2566a658",
   "metadata": {},
   "source": [
    "We now have the epochs during which the animal is inside the central circle. Next, we separate these epochs into trials. For our purposes, a trial must last longer than 0.7 s and can have a maximal duration of 20 s, so we drop the epochs outside this range."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "51fea121",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>start</th>\n",
       "      <th>end</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>3.712500</td>\n",
       "      <td>4.804166</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>6.695834</td>\n",
       "      <td>7.912500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>9.329166</td>\n",
       "      <td>12.562500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>17.812500</td>\n",
       "      <td>18.612500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>22.937500</td>\n",
       "      <td>23.829166</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>1037.437500</td>\n",
       "      <td>1040.520834</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>1048.729166</td>\n",
       "      <td>1051.954167</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>1075.654166</td>\n",
       "      <td>1078.345834</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>1089.045833</td>\n",
       "      <td>1091.445834</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>1093.912500</td>\n",
       "      <td>1275.204167</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>76 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "          start          end\n",
       "0      3.712500     4.804166\n",
       "1      6.695834     7.912500\n",
       "2      9.329166    12.562500\n",
       "3     17.812500    18.612500\n",
       "4     22.937500    23.829166\n",
       "..          ...          ...\n",
       "71  1037.437500  1040.520834\n",
       "72  1048.729166  1051.954167\n",
       "73  1075.654166  1078.345834\n",
       "74  1089.045833  1091.445834\n",
       "75  1093.912500  1275.204167\n",
       "\n",
       "[76 rows x 2 columns]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# discard epochs shorter than 0.7 s\n",
    "ep = ep.drop_short_intervals(0.7, time_units = 's')\n",
    "\n",
    "ep"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "336a76b6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>start</th>\n",
       "      <th>end</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>3.712500</td>\n",
       "      <td>4.804166</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>6.695834</td>\n",
       "      <td>7.912500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>9.329166</td>\n",
       "      <td>12.562500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>17.812500</td>\n",
       "      <td>18.612500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>22.937500</td>\n",
       "      <td>23.829166</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>1034.495834</td>\n",
       "      <td>1035.462500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>1037.437500</td>\n",
       "      <td>1040.520834</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>1048.729166</td>\n",
       "      <td>1051.954167</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>1075.654166</td>\n",
       "      <td>1078.345834</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>1089.045833</td>\n",
       "      <td>1091.445834</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>71 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "          start          end\n",
       "0      3.712500     4.804166\n",
       "1      6.695834     7.912500\n",
       "2      9.329166    12.562500\n",
       "3     17.812500    18.612500\n",
       "4     22.937500    23.829166\n",
       "..          ...          ...\n",
       "66  1034.495834  1035.462500\n",
       "67  1037.437500  1040.520834\n",
       "68  1048.729166  1051.954167\n",
       "69  1075.654166  1078.345834\n",
       "70  1089.045833  1091.445834\n",
       "\n",
       "[71 rows x 2 columns]"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# discard epochs longer than 20 s, then renumber the remaining trials from 0\n",
    "ep = ep.drop_long_intervals(20, time_units = 's')\n",
    "ep = ep.reset_index(drop=True)\n",
    "ep"
   ]
  },
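  {
   "cell_type": "markdown",
   "id": "a4e9d0c7",
   "metadata": {},
   "source": [
    "Taken together, drop_short_intervals and drop_long_intervals act as a filter on each epoch's duration (end minus start). The pandas sketch below applies both bounds in one step to a plain DataFrame with start and end columns, using toy values rather than the epochs above. It uses strict inequalities at both bounds; how Pynapple handles an epoch lying exactly on a threshold is not assumed here.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# Toy epochs in seconds, NOT the epochs computed above\n",
    "epochs = pd.DataFrame({'start': [0.0, 2.0, 5.0, 40.0],\n",
    "                       'end': [1.0, 2.5, 9.0, 70.0]})\n",
    "\n",
    "min_dur, max_dur = 0.7, 20.0  # seconds, the bounds used in this notebook\n",
    "duration = epochs['end'] - epochs['start']\n",
    "\n",
    "# Keep epochs longer than min_dur and shorter than max_dur, then renumber the rows\n",
    "trials = epochs[(duration > min_dur) & (duration < max_dur)].reset_index(drop=True)\n",
    "print(trials)\n",
    "```"
   ]
  },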
  {
   "cell_type": "markdown",
   "id": "c16d1d5e",
   "metadata": {},
   "source": [
    "Now, we will plot the trajectories corresponding to our trials: "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "61819f01",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.legend.Legend at 0x7ffa241f16d0>"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.figure()\n",
    "sns.scatterplot(data = position, x = x, y = y)\n",
    "plt.scatter(position['x'].restrict(ep), position['y'].restrict(ep), zorder = 2, label = 'selected data')\n",
    "plt.legend(loc = 'upper right')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "436cd2b5",
   "metadata": {},
   "source": [
    "And voilà! We have now obtained the trajectory of the animal within the circle of interest!\n",
    "\n",
    "We can go one step further and consider only those trials where the animal goes from the departure arm to any arm \"in front\" of it. We call these \"forward trials\". To identify a forward trial, we use the following logic:\n",
    "\n",
    "1. The y-position at the end of the trial is larger than the y-position at the start of the trial.\n",
    "\n",
    "2. The increase in y-position over the trial is larger than half the radius of our circle (rth/2).\n",
    "\n",
    "So, first we define the y-position relative to the centre of the maze as a Pynapple Tsd:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "ee969447",
   "metadata": {},
   "outputs": [],
   "source": [
    "# y-position relative to the centre of the maze\n",
    "dy = nap.Tsd(t = timestamps, d = y - yth, time_units = 's')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9bf4f989",
   "metadata": {},
   "source": [
    "Now, for each trial, we compute diffy, the change in y between the last and the first sample, and keep the trials that satisfy the two conditions above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "0d4c7dbf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>start</th>\n",
       "      <th>end</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>3.712500</td>\n",
       "      <td>4.804166</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>17.812500</td>\n",
       "      <td>18.612500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>45.120834</td>\n",
       "      <td>46.187500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>54.229167</td>\n",
       "      <td>54.954166</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>78.929167</td>\n",
       "      <td>79.795834</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>119.229166</td>\n",
       "      <td>120.279166</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>151.295834</td>\n",
       "      <td>152.462500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>193.704166</td>\n",
       "      <td>194.629166</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>220.104166</td>\n",
       "      <td>221.595834</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>252.279166</td>\n",
       "      <td>253.362500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>263.329166</td>\n",
       "      <td>264.287500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>290.395834</td>\n",
       "      <td>291.920834</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>298.479166</td>\n",
       "      <td>299.445834</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>328.287500</td>\n",
       "      <td>329.429166</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>366.204166</td>\n",
       "      <td>367.270834</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>433.812500</td>\n",
       "      <td>434.770834</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>464.170834</td>\n",
       "      <td>465.137500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>542.204166</td>\n",
       "      <td>543.304166</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>640.045834</td>\n",
       "      <td>653.904166</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>674.945834</td>\n",
       "      <td>682.537500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>746.012500</td>\n",
       "      <td>747.095834</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>988.612500</td>\n",
       "      <td>989.595834</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>1019.379166</td>\n",
       "      <td>1020.329166</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>1034.495834</td>\n",
       "      <td>1035.462500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>1075.654166</td>\n",
       "      <td>1078.345834</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "          start          end\n",
       "0      3.712500     4.804166\n",
       "1     17.812500    18.612500\n",
       "2     45.120834    46.187500\n",
       "3     54.229167    54.954166\n",
       "4     78.929167    79.795834\n",
       "5    119.229166   120.279166\n",
       "6    151.295834   152.462500\n",
       "7    193.704166   194.629166\n",
       "8    220.104166   221.595834\n",
       "9    252.279166   253.362500\n",
       "10   263.329166   264.287500\n",
       "11   290.395834   291.920834\n",
       "12   298.479166   299.445834\n",
       "13   328.287500   329.429166\n",
       "14   366.204166   367.270834\n",
       "15   433.812500   434.770834\n",
       "16   464.170834   465.137500\n",
       "17   542.204166   543.304166\n",
       "18   640.045834   653.904166\n",
       "19   674.945834   682.537500\n",
       "20   746.012500   747.095834\n",
       "21   988.612500   989.595834\n",
       "22  1019.379166  1020.329166\n",
       "23  1034.495834  1035.462500\n",
       "24  1075.654166  1078.345834"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
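    "diffy = []\n",
    "for i in ep.index.values:\n",
    "    # dy samples falling inside trial i\n",
    "    tmp = dy.restrict(ep.loc[[i]])\n",
    "    # net change in y between the last and first sample of the trial\n",
    "    diffy.append(tmp.iloc[-1] - tmp.iloc[0])\n",
    "diffy = pd.Series(data = diffy)\n",
    "# forward trials: y increases by more than half the radius\n",
    "diffy2 = diffy[diffy > rth/2]\n",
    "\n",
    "ep_fwd = ep.loc[diffy2.index]\n",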
    "\n",
    "ep_fwd"
   ]
  },
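  {
   "cell_type": "markdown",
   "id": "e2b58f39",
   "metadata": {},
   "source": [
    "The loop above calls restrict once per trial to find the first and last dy sample in each epoch. As an alternative sketch, assuming the timestamps are sorted and every epoch contains at least one sample, np.searchsorted can locate those samples for all epochs at once. The helper forward_mask and the toy arrays are hypothetical; y_rel plays the role of dy, and the selection rule (a net rise greater than rth/2) is the one used in the cell above.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def forward_mask(t, y_rel, starts, ends, min_rise):\n",
    "    '''Hypothetical helper: True for epochs whose y rises by more than min_rise.\n",
    "\n",
    "    Assumes t is sorted and every [start, end] epoch contains at least one sample.\n",
    "    '''\n",
    "    t = np.asarray(t, dtype=float)\n",
    "    y_rel = np.asarray(y_rel, dtype=float)\n",
    "    i_first = np.searchsorted(t, starts, side='left')    # first sample with t >= start\n",
    "    i_last = np.searchsorted(t, ends, side='right') - 1  # last sample with t <= end\n",
    "    net_change = y_rel[i_last] - y_rel[i_first]\n",
    "    return net_change > min_rise\n",
    "\n",
    "# Toy data (NOT from this recording): two epochs, the first one rising by 150 px\n",
    "t = np.arange(8) * 0.1\n",
    "y_rel = np.array([0.0, 10.0, 160.0, 20.0, 10.0, 0.0, -5.0, 0.0])\n",
    "starts = np.array([0.05, 0.35])\n",
    "ends = np.array([0.25, 0.65])\n",
    "print(forward_mask(t, y_rel, starts, ends, min_rise=200 / 2))\n",
    "```"
   ]
  },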
  {
   "cell_type": "markdown",
   "id": "d515d8a8",
   "metadata": {},
   "source": [
    "We are left with a subset of 25 epochs that are forward trials. Let us now visualize them:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "fee18369",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.legend.Legend at 0x7ffa240ffd90>"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "start = nap.Ts(ep_fwd['start'].values)\n",
    "ends = nap.Ts(ep_fwd['end'].values)\n",
    "\n",
    "plt.plot(dy.restrict(ep_fwd))\n",
    "plt.plot(start.value_from(dy).index.values, start.value_from(dy).values ,'x', label = 'start')\n",
    "plt.plot(ends.value_from(dy).index.values, ends.value_from(dy).values ,'*', label = 'end')\n",
    "plt.legend(loc = 'upper right')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6ede309",
   "metadata": {},
   "source": [
    "This plot confirms that the selected trials are forward trials: the y-position at the start of each trial is lower than the y-position at its end.\n",
    "\n",
    "You can now save these variables (for example, ep_fwd) as CSV files to use with your other analysis scripts.\n",
    "\n",
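    "A minimal sketch of the saving step with pandas, using a toy DataFrame with start and end columns as a stand-in for ep_fwd. Depending on your Pynapple version, an IntervalSet is either already a pandas DataFrame or needs to be converted first (recent versions document an `as_dataframe()` method), so check the documentation for the version you have installed. The file name forward_trials.csv is hypothetical.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import tempfile\n",
    "import pandas as pd\n",
    "\n",
    "# Toy stand-in for ep_fwd (seconds); NOT the trials found above\n",
    "ep_demo = pd.DataFrame({'start': [1.0, 5.0], 'end': [2.0, 6.5]})\n",
    "\n",
    "# Hypothetical output location, here inside a temporary directory\n",
    "out_path = os.path.join(tempfile.mkdtemp(), 'forward_trials.csv')\n",
    "ep_demo.to_csv(out_path, index=False)\n",
    "\n",
    "# Reading the file back recovers the same start/end columns\n",
    "loaded = pd.read_csv(out_path)\n",
    "print(loaded)\n",
    "```\n",
    "\n",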
    "I hope this tutorial was helpful. If you have any questions, comments or suggestions, please feel free to reach out to the Pynacollada Team! "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}