Raw File
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Quantify the variability of the time to trained over labs.

@author: Guido Meijer, Miles Wells
16 Jan 2020
"""
from os.path import join

import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import seaborn as sns
from scipy import stats
import scikit_posthocs as sp
from paper_behavior_functions import (query_subjects, seaborn_style, institution_map,
                                      group_colors, figpath, load_csv, datapath,
                                      EXAMPLE_MOUSE, FIGURE_HEIGHT, FIGURE_WIDTH, QUERY)
from ibl_pipeline.analyses import behavior as behavior_analysis

# Settings
fig_path = figpath()  # directory the figures are written to by the plt.savefig calls below
seaborn_style()  # presumably applies the shared plotting style of the paper -- TODO confirm
# NOTE: this rebinds the imported function name `institution_map` to the first value it
# returns, so the function itself cannot be called again after this line. That first value
# is passed to Series.map() further down to translate institution names into lab labels.
institution_map, col_names = institution_map()

# Build `training_time`: one row per mouse with the number of sessions and trials it took to
# reach trained status, plus the lab it belongs to.
if QUERY is True:
    # Query sessions
    # NOTE(review): this branch keeps sessions with session_date <= date_trained, whereas the
    # cached-data branch below uses a strict '<' -- confirm which cut-off is intended.
    use_subjects = query_subjects()
    ses = (behavior_analysis.BehavioralSummaryByDate * use_subjects)
    ses = (ses & 'session_date <= date_trained').fetch(format='frame').reset_index()
    ses['n_trials_date'] = ses['n_trials_date'].astype(int)

    # Construct dataframe (indexed by subject_nickname until the reset_index below)
    per_subject = ses.groupby('subject_nickname')
    training_time = per_subject.size().to_frame('sessions')
    # Select the column *before* summing: summing the whole frame also tries to add up the
    # non-numeric columns, which raises in pandas >= 2.0.
    training_time['trials'] = per_subject['n_trials_date'].sum()
    training_time['lab'] = per_subject['institution_short'].apply(list).str[0]

    # Change lab name into lab number
    training_time['lab_number'] = training_time.lab.map(institution_map)
    training_time = training_time.sort_values('lab_number')
    training_time = training_time.reset_index()

else:
    data = pd.read_pickle(join(datapath(), 'Fig2af.pkl')).dropna()
    use_subjects = data['subject_nickname'].unique()  # For counting the number of subjects

    # Collect one record per mouse and build the frame in a single step. DataFrame.append,
    # which used to grow the frame row by row here, was removed in pandas 2.0.
    records = []
    for subject in use_subjects:
        subject_data = data[data['subject_nickname'] == subject]
        before_trained = subject_data[
            subject_data['session_date'] < subject_data['date_trained']]
        records.append({
            'subject_nickname': subject,
            # Assumes every mouse belongs to a single institution; the first one is used
            'lab': subject_data['institution_short'].unique()[0],
            'sessions': before_trained.shape[0],
            'trials': before_trained['n_trials_date'].sum(),
        })
    training_time = pd.DataFrame(
        records, columns=['subject_nickname', 'lab', 'sessions', 'trials'])

    # Change lab name into lab number
    training_time['lab_number'] = training_time.lab.map(institution_map)
    training_time = training_time.sort_values('lab_number').reset_index(drop=True)

# Number of sessions to trained for example mouse
# `training_time` already carries 'subject_nickname' as a regular column in both loading
# branches, so no extra reset_index() is needed to select on it.
is_example_mouse = training_time['subject_nickname'].str.match(EXAMPLE_MOUSE)
if not is_example_mouse.any():
    raise ValueError('Example mouse %s not found in the training time data' % EXAMPLE_MOUSE)
# Keep a plain scalar instead of a one-element Series: the '%d' formatting further down
# would otherwise rely on the implicit int(Series) conversion, which pandas has deprecated.
example_training_time = training_time.loc[is_example_mouse, 'sessions'].iloc[0]

#  statistics
# Compare the number of sessions to trained across labs. A normality test on the pooled
# session counts decides between a Kruskal-Wallis test (followed by Dunn's posthoc test) and
# a one-way ANOVA (followed by Tukey's posthoc test). Posthoc tests are only run when the
# omnibus test is significant at the 0.05 level.
normal = stats.normaltest(training_time['sessions']).pvalue
sessions_per_lab = [lab_data['sessions'].values
                    for _, lab_data in training_time.groupby('lab')]
if normal < 0.05:
    # Normality rejected: non-parametric route
    kruskal = stats.kruskal(*sessions_per_lab)
    if kruskal.pvalue < 0.05:  # Proceed to posthocs
        posthoc = sp.posthoc_dunn(training_time,
                                  val_col='sessions',
                                  group_col='lab_number')
else:
    # Normality not rejected: parametric route
    anova = stats.f_oneway(*sessions_per_lab)
    if anova.pvalue < 0.05:
        posthoc = sp.posthoc_tukey(training_time,
                                   val_col='sessions',
                                   group_col='lab_number')


# %% PLOT

# Set figure style and color palette
use_palette = [[0.6, 0.6, 0.6]] * len(np.unique(training_time['lab']))
use_palette = use_palette + [[1, 1, 0.2]]
lab_colors = group_colors()

# Add all mice to dataframe separately for plotting
# Swarm plot data: the per-lab rows plus one otherwise empty row labelled 'All', so that the
# x-axis gets an 'All' category without any individual points drawn in it.
training_time_no_all = training_time.copy()
training_time_no_all.loc[training_time_no_all.shape[0] + 1, 'lab_number'] = 'All'
# Box plot data: every mouse once under its own lab and once more under 'All'.
training_time_all = training_time.copy()
training_time_all['lab_number'] = 'All'
# pd.concat replaces DataFrame.append, which was removed in pandas 2.0
training_time_all = pd.concat([training_time, training_time_all])

# print
print(training_time_all.reset_index().groupby(['lab_number'])['subject_nickname'].nunique())

f, (ax1) = plt.subplots(1, 1, figsize=(FIGURE_WIDTH/3, FIGURE_HEIGHT))
sns.set_palette(lab_colors)
sns.swarmplot(y='sessions', x='lab_number', hue='lab_number', data=training_time_no_all,
              palette=lab_colors, ax=ax1, marker='.')
axbox = sns.boxplot(y='sessions', x='lab_number', data=training_time_all,
                    color='white', showfliers=False, ax=ax1)
# Draw the last ('All') box in black. Since Matplotlib 3.5 the box patches are listed in
# Axes.patches rather than Axes.artists, so fall back to the former when the latter is empty.
boxes = axbox.artists if len(axbox.artists) > 0 else axbox.patches
boxes[-1].set_edgecolor('black')
# The indexing assumes 5 line artists per box (presumably 2 whiskers, 2 caps and the median,
# as fliers are switched off); recolor the ones belonging to the last box.
for j in range(5 * (len(boxes) - 1), 5 * len(boxes)):
    axbox.lines[j].set_color('black')
ax1.set(ylabel='Days to trained', xlabel='', ylim=[0, 60])
# Hide the hue legend if one was created (get_legend() returns None when there is none)
hue_legend = ax1.get_legend()
if hue_legend is not None:
    hue_legend.set_visible(False)
# [tick.set_color(lab_colors[i]) for i, tick in enumerate(ax1.get_xticklabels())]
plt.setp(ax1.xaxis.get_majorticklabels(), rotation=40)
sns.despine(trim=True)
plt.tight_layout()
plt.savefig(join(fig_path, 'figure2g_time_to_trained.pdf'))
plt.savefig(join(fig_path, 'figure2g_time_to_trained.png'), dpi=300)


# SAME FOR TRIALS TO TRAINED
# Swarm plot of the per-lab points with white box plots on top (including the pooled 'All'
# group), this time for the number of trials to trained.
f, (ax1) = plt.subplots(1, 1, figsize=(FIGURE_WIDTH/5, FIGURE_HEIGHT))
sns.set_palette(lab_colors)
sns.swarmplot(y='trials', x='lab_number', hue='lab_number', data=training_time_no_all,
              palette=lab_colors, ax=ax1, marker='.')
axbox = sns.boxplot(y='trials', x='lab_number', data=training_time_all,
                    color='white', showfliers=False, ax=ax1)
# Draw the last ('All') box in black. Since Matplotlib 3.5 the box patches are listed in
# Axes.patches rather than Axes.artists, so fall back to the former when the latter is empty.
boxes = axbox.artists if len(axbox.artists) > 0 else axbox.patches
boxes[-1].set_edgecolor('black')
# The indexing assumes 5 line artists per box (presumably 2 whiskers, 2 caps and the median,
# as fliers are switched off); recolor the ones belonging to the last box.
for j in range(5 * (len(boxes) - 1), 5 * len(boxes)):
    axbox.lines[j].set_color('black')
ax1.set(ylabel='Trials to trained', xlabel='')
# Hide the hue legend if one was created (get_legend() returns None when there is none)
hue_legend = ax1.get_legend()
if hue_legend is not None:
    hue_legend.set_visible(False)
# [tick.set_color(lab_colors[i]) for i, tick in enumerate(ax1.get_xticklabels())]
plt.setp(ax1.xaxis.get_majorticklabels(), rotation=40)
# Show the y tick labels in thousands of trials, e.g. 20000 -> '20K'
format_fcn = ticker.FuncFormatter(lambda x, pos: '{:,.0f}'.format(x / 1e3) + 'K')
ax1.yaxis.set_major_formatter(format_fcn)
sns.despine(trim=True)
plt.tight_layout()
plt.savefig(join(fig_path, 'suppfig_trials_to_trained.pdf'))
plt.savefig(join(fig_path, 'suppfig_trials_to_trained.png'), dpi=300)


# sns.swarmplot(y='trials', x='lab_number', hue='lab_number', data=training_time_no_all,
#               palette=lab_colors, ax=ax2)
# axbox = sns.boxplot(y='trials', x='lab_number', data=training_time_all,
#                     color='white', showfliers=False, ax=ax2)
# axbox.artists[-1].set_edgecolor('black')
# for j in range(5 * (len(axbox.artists) - 1), 5 * len(axbox.artists)):
#     axbox.lines[j].set_color('black')
# ax2.set(ylabel='Trials to trained', xlabel='', ylim=[0, 50000])
# ax2.get_legend().set_visible(False)
# # [tick.set_color(lab_colors[i]) for i, tick in enumerate(ax1.get_xticklabels())]
# plt.setp(ax2.xaxis.get_majorticklabels(), rotation=40)

# Get stats in text
# Interquartile range per lab
iqtr = training_time.groupby(['lab'])[
    'sessions'].quantile(0.75) - training_time.groupby(['lab'])[
    'sessions'].quantile(0.25)

# Training time as a whole (in sessions, i.e. days)
m_train = training_time['sessions'].mean()
s_train = training_time['sessions'].std()
slowest = training_time['sessions'].max()
fastest = training_time['sessions'].min()

# Print information used in the paper
print('For mice that learned the task, the average training took %.1f ± %.1f days (s.d., '
      'n = %d), similar to the %d days of the example mouse from Lab 1 (Figure 2a, black). The '
      'fastest learner met training criteria in %d days, the slowest %d days'
      % (m_train, s_train, len(use_subjects), example_training_time, fastest, slowest))

# Training time in trials, expressed in thousands of trials
m_train = training_time['trials'].mean() / 1000
s_train = training_time['trials'].std() / 1000
slowest = training_time['trials'].max() / 1000
fastest = training_time['trials'].min() / 1000
# Trials to trained of the example mouse, in thousands. The sentence below reports trials,
# so the example mouse's session count (`example_training_time`) must not be reused here.
example_trials = training_time.loc[
    training_time['subject_nickname'].str.match(EXAMPLE_MOUSE), 'trials'].iloc[0] / 1000

# Note that '%d' truncates the fractional part of the values given in thousands
print('In trials, the average training took %.1fK ± %.1fK trials (s.d., '
      'n = %d), similar to the %dK trials of the example mouse from Lab 1 (Figure 2a, black). The '
      'fastest learner met training criteria in %dK trials, the slowest %dK trials.'
      % (m_train, s_train, len(use_subjects), example_trials, fastest, slowest))
back to top