"""
Choice variability across contrasts within each lab and across labs

@author: Anne Urai
2 September 2020
"""

import numpy as np
import seaborn as sns
import os
from os.path import join
import pandas as pd
import matplotlib.pyplot as plt
from paper_behavior_functions import (figpath, seaborn_style, group_colors,
                                      query_sessions_around_criterion, institution_map,
                                      FIGURE_HEIGHT, FIGURE_WIDTH, QUERY,
                                      dj2pandas)
# import wrappers etc
from ibl_pipeline import reference, subject, behavior
from sklearn.utils import shuffle

# progress bar
from tqdm.auto import tqdm
tqdm.pandas(desc="computing")

# Initialize
seaborn_style()
figpath = figpath()
pal = group_colors()
institution_map, col_names = institution_map()
col_names = col_names[:-1]
nshuf = 10000

# DEFINE ACROSS-MOUSE CHOICE VARIABILITY
def choice_variability(df):

    # FIRST, GROUP BY CONTRAST AND MOUSE
    choicedat = df.pivot_table(values='choice2', columns='subject_nickname', index='signed_contrast')

    # COMPUTE THE ACROSS-MOUSE VARIANCE FOR EACH CONTRAST, AVERAGE ACROSS
    # CONTRASTS, THEN INVERT: HIGHER VALUES = MORE CONSISTENT MICE
    choice_var = 1 / choicedat.var(axis=1).mean()

    return choice_var
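
# For intuition, a sketch of the intermediate pivot table with made-up numbers:
# one row per signed contrast, one column per mouse, entries are that mouse's
# average choice (choice2) at that contrast:
#
#   subject_nickname  mouse_a  mouse_b  mouse_c
#   signed_contrast
#   -100.0               0.03     0.05     0.02
#      0.0               0.48     0.55     0.51
#    100.0               0.97     0.95     0.98
#
# .var(axis=1) gives the across-mouse variance at each contrast; the mean of
# those variances is inverted, yielding one consistency value per lab
# (higher = mice agree more).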


# FULL TASK, ALSO INCORPORATE PROBABILITYLEFT
def choice_variability_full(df):

    # FIRST, GROUP BY CONTRAST AND MOUSE
    choicedat = df.pivot_table(values='choice2',
                               columns='subject_nickname',
                               index=['signed_contrast', 'probabilityLeft'])

    # COMPUTE THE ACROSS-MOUSE VARIANCE FOR EACH CONTRAST-PRIOR COMBINATION,
    # AVERAGE ACROSS THEM, THEN INVERT: HIGHER VALUES = MORE CONSISTENT MICE
    choice_var = 1 / choicedat.var(axis=1).mean()

    return choice_var
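
# The full-task variant is identical except the pivot rows are
# (signed_contrast, probabilityLeft) pairs, so the across-mouse variance is
# computed separately for each combination of contrast and block prior before
# averaging and inverting.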

# ================================= #
# GET DATA FROM TRAINED ANIMALS
# ================================= #

# query sessions - BASIC TASK
use_sessions, _ = query_sessions_around_criterion(criterion='trained',
                                                  days_from_criterion=[0, 2],
                                                  as_dataframe=False,
                                                  force_cutoff=True)

trial_fields = ('trial_stim_contrast_left', 'trial_stim_contrast_right',
                'trial_response_time', 'trial_stim_prob_left',
                'trial_feedback_type', 'trial_stim_on_time', 'trial_response_choice')

# query trial data for sessions and subject name and lab info
trials = use_sessions.proj('task_protocol') * behavior.TrialSet.Trial.proj(*trial_fields)
subject_info = subject.Subject.proj('subject_nickname') * \
               (subject.SubjectLab * reference.Lab).proj('institution_short')

# Fetch, join and sort data as a pandas DataFrame
behav = dj2pandas(trials.fetch(format='frame')
                         .join(subject_info.fetch(format='frame'))
                         .sort_values(by=['institution_short', 'subject_nickname',
                                          'session_start_time', 'trial_id'])
                         .reset_index())
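# dj2pandas derives, among other columns, the signed_contrast and choice2
# fields that all analyses below rely on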

behav['institution_code'] = behav.institution_short.map(institution_map)
# exclude the 50% contrast trials from the basic-task analysis
behav.drop(behav[behav['signed_contrast'].isin([50., -50.])].index, inplace=True)
df_basic = behav.groupby(['institution_code', 'subject_nickname', 'signed_contrast']).agg(
    {'choice2': 'mean'}).reset_index().copy()
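
# df_basic now holds one row per (lab, mouse, signed contrast) with that
# mouse's average choice, e.g. (values made up):
#   institution_code  subject_nickname  signed_contrast  choice2
#   Lab 1             mouse_a                     -100.0     0.03
#   Lab 1             mouse_a                        0.0     0.48
#   ...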

# query sessions - FULL TASK
use_sessions, _ = query_sessions_around_criterion(criterion='ephys',
                                                  days_from_criterion=[2, 0],
                                                  force_cutoff=True)
use_sessions = use_sessions & 'task_protocol LIKE "%biased%"'  # only get biased sessions
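# (the LIKE clause keeps only sessions whose task_protocol string contains
# 'biased', i.e. the full task with block-wise stimulus priors)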

# join sessions with subject, lab and trial information
b = (use_sessions * subject.Subject * subject.SubjectLab * reference.Lab
     * behavior.TrialSet.Trial)

# reduce the size of the fetch
b2 = b.proj('institution_short', 'subject_nickname', 'task_protocol', 'session_uuid',
            'trial_stim_contrast_left', 'trial_stim_contrast_right', 'trial_response_choice',
            'trial_stim_prob_left', 'trial_feedback_type',
            'trial_response_time', 'trial_stim_on_time')

# construct pandas dataframe
bdat = b2.fetch(order_by='institution_short, subject_nickname, session_start_time, trial_id',
                format='frame').reset_index()
behav = dj2pandas(bdat)
behav['institution_code'] = behav.institution_short.map(institution_map)

# COMPUTE AVERAGE CHOICE BEHAVIOR FOR EACH SUBJECT
df_full = behav.groupby(['institution_code', 'probabilityLeft',
                         'subject_nickname', 'signed_contrast']).agg(
    {'choice2': 'mean'}).reset_index().copy()

# ================================================================== #
# COMPUTE CHOICE VARIABILITY FOR EACH LAB - BASIC TASK
# ================================================================== #

basic_perlab = df_basic.groupby(['institution_code']).progress_apply(
    choice_variability).reset_index()
basic_perlab['x'] = 0
basic_perlab['choice_var'] = basic_perlab[0]


# SAME, BUT ON SHUFFLED DATA
basic_perlab_shuffled = []
# make a list of all the institution codes for all subjects (preserving their frequency)
institution_codes = [gr.institution_code.unique()[0] for _, gr in df_basic.groupby(['subject_nickname'])]
for s in tqdm(range(nshuf)):

    # use scikit learn to shuffle lab labels
    shuffled_labs = shuffle(institution_codes)
    new_df = []
    # keep all choices of one mouse together - only reassign labs!
    for idx, (_, g) in enumerate(df_basic.groupby(['subject_nickname'])):
        g = g.copy()  # work on a copy to avoid writing into a groupby view
        g['new_lab'] = shuffled_labs[idx]
        new_df.append(g)
    df_basic = pd.concat(new_df)
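    # NOTE: the per-mouse relabelling above could equivalently be done with a
    # subject -> lab map (a sketch relying on groupby's sorted order; not what
    # runs here):
    #   lab_map = dict(zip(sorted(df_basic.subject_nickname.unique()), shuffled_labs))
    #   df_basic['new_lab'] = df_basic.subject_nickname.map(lab_map)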

    assert (all(df_basic.groupby(['institution_code'])['new_lab'].nunique() > 1))
    basic_perlab_shuffled.append(df_basic.groupby(['new_lab']).apply(choice_variability).mean())

# PLOT
f, ax1 = plt.subplots(1, 1, figsize=(FIGURE_WIDTH / 5, FIGURE_HEIGHT))
sns.swarmplot(data=basic_perlab, x='x', y='choice_var',
              hue='institution_code', hue_order=col_names,
              palette=pal, marker='.', ax=ax1, zorder=0)
ax1.plot(0, basic_perlab['choice_var'].mean(),
             color='black', linewidth=0, marker='_', markersize=13)
ax1.get_legend().set_visible(False)
# then, shuffled distribution next to it
sns.violinplot(x=np.concatenate((np.zeros(nshuf), np.ones(nshuf))),
               y=np.concatenate((np.full(nshuf, np.nan), basic_perlab_shuffled)),
               palette=[[1, 1, 1]], ax=ax1)
ax1.set(ylabel='Within-lab choice consistency', xlabel='', ylim=[0, 120])
ax1.set_xticklabels(['Data', 'Shuffle'], ha='center')
plt.tight_layout()
sns.despine(trim=True)
f.savefig(os.path.join(figpath, "across_mouse_var_basic.pdf"))

# WHAT IS THE P-VALUE COMPARED TO THE NULL DISTRIBUTION?
pval = np.mean(np.array(basic_perlab_shuffled) > basic_perlab['choice_var'].mean())
print('Basic task, choice consistency: p-value = %.4f' % pval)
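# This amounts to a one-sided permutation test: the p-value is the fraction of
# the 10,000 lab-label shuffles whose mean within-"lab" consistency exceeds the
# observed mean, so a small p indicates that real labs are more internally
# consistent than expected by chance.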

# ================================================================== #
# COMPUTE CHOICE VARIABILITY FOR EACH LAB - FULL TASK
# ================================================================== #

full_perlab = df_full.groupby(['institution_code']).progress_apply(
    choice_variability_full).reset_index()
full_perlab['x'] = 0
full_perlab['choice_var'] = full_perlab[0]

# SAME, BUT ON SHUFFLED DATA
full_perlab_shuffled = []
# make a list of all the institution codes for all subjects (preserving their frequency)
institution_codes = [gr.institution_code.unique()[0] for _, gr in df_full.groupby(['subject_nickname'])]
for s in tqdm(range(nshuf)):

    # use scikit learn to shuffle lab labels
    shuffled_labs = shuffle(institution_codes)
    new_df = []
    # keep all choices of one mouse together - only reassign labs!
    for idx, (_, g) in enumerate(df_full.groupby(['subject_nickname'])):
        g = g.copy()  # work on a copy to avoid writing into a groupby view
        g['new_lab'] = shuffled_labs[idx]
        new_df.append(g)
    df_full = pd.concat(new_df)

    assert (all(df_full.groupby(['institution_code'])['new_lab'].nunique() > 1))
    full_perlab_shuffled.append(df_full.groupby(['new_lab']).apply(choice_variability_full).mean())


# PLOT
f, ax1 = plt.subplots(1, 1, figsize=(FIGURE_WIDTH / 5, FIGURE_HEIGHT))
sns.swarmplot(data=full_perlab, x='x', y='choice_var',
              hue='institution_code', hue_order=col_names,
              palette=pal, marker='.', ax=ax1, zorder=0)
ax1.plot(0, full_perlab['choice_var'].mean(),
             color='black', linewidth=0, marker='_', markersize=13)
ax1.get_legend().set_visible(False)
# then, shuffled distribution next to it
sns.violinplot(x=np.concatenate((np.zeros(nshuf), np.ones(nshuf))),
               y=np.concatenate((np.full(nshuf, np.nan), full_perlab_shuffled)),
               palette=[[1, 1, 1]], ax=ax1)
ax1.set(ylabel=' ', xlabel='', ylim=[0, 120])
ax1.set_xticklabels(['Data', 'Shuffle'], ha='center')
plt.tight_layout()
sns.despine(trim=True)
f.savefig(os.path.join(figpath, "across_mouse_var_full.pdf"))

# WHAT IS THE P-VALUE COMPARED TO THE NULL DISTRIBUTION?
pval = np.mean(np.array(full_perlab_shuffled) > full_perlab['choice_var'].mean())
print('Full task, choice consistency: p-value = %.4f' % pval)


