Revision edc453189104a1f76f4b2ab230cd86f2140e3f63 authored by Anne Urai on 08 April 2021, 13:13 UTC, committed by Anne Urai on 08 April 2021, 13:13 UTC
1 parent 22583a6
Raw File
figure3ab_psychfuncs.py
"""
Psychometric functions of training mice, within and across labs

@author: Anne Urai
15 January 2020
"""
import seaborn as sns
import os
from os.path import join
import pandas as pd
import matplotlib.pyplot as plt
from paper_behavior_functions import (figpath, seaborn_style, group_colors, load_csv,
                                      query_sessions_around_criterion, institution_map,
                                      FIGURE_HEIGHT, FIGURE_WIDTH, QUERY, EXAMPLE_MOUSE,
                                      plot_psychometric, dj2pandas, plot_chronometric)
# import wrappers etc
from ibl_pipeline import reference, subject, behavior

# Initialize
# Apply the project's shared plotting style before any figure is created.
seaborn_style()
# NOTE: the next lines rebind the imported helper names `figpath` and
# `institution_map` to their *return values*; the helpers themselves cannot be
# called again further down this script.
figpath = figpath()  # output directory; every savefig below joins file names onto it
pal = group_colors()  # per-lab colours, indexed by the lab's panel position (pal[axidx])
# institution_map: maps institution_short -> institution_code (see .map() in the query branch)
# col_names: the facet column order used for the per-lab panels (col_order=col_names)
institution_map, col_names = institution_map()
# Drop the last entry. NOTE(review): presumably a pooled/"all labs" column that
# should not get its own facet here -- confirm in paper_behavior_functions.
col_names = col_names[:-1]

# %%=============================== #
# GET DATA FROM TRAINED ANIMALS
# ================================= #

# QUERY chooses the data source: query the DataJoint pipeline directly, or
# load the pre-exported 'Fig3.csv'. Either way the result is the trial-level
# DataFrame `behav` used by all figures below.
if QUERY is True:
    # query sessions
    # Sessions around the moment each mouse reached the 'trained' criterion.
    # NOTE(review): days_from_criterion=[2, 0] presumably means "2 days before
    # up to the criterion day" -- confirm in query_sessions_around_criterion.
    # `use_days` is not used anywhere in this script.
    use_sessions, use_days = query_sessions_around_criterion(criterion='trained',
                                                             days_from_criterion=[2, 0],
                                                             as_dataframe=False,
                                                             force_cutoff=True)

    # list of dicts - see https://int-brain-lab.slack.com/archives/CB13FQFK4/p1607369435116300 for explanation
    # The session keys are materialised as plain dicts and used below to
    # restrict the Trial table (`& sess`).
    sess = use_sessions.proj('task_protocol').fetch(format='frame').reset_index().to_dict('records')

    # Trial data to fetch
    trial_fields = ('trial_stim_contrast_left',
                    'trial_stim_contrast_right',
                    'trial_response_time',
                    'trial_stim_prob_left',
                    'trial_feedback_type',
                    'trial_stim_on_time',
                    'trial_response_choice')

    # Query trial data for sessions and subject name and lab info
    trials = (behavior.TrialSet.Trial & sess).proj(*trial_fields)

    # also get info about each subject
    # (nickname plus the short name of the subject's lab)
    subject_info = subject.Subject.proj('subject_nickname') * \
        (subject.SubjectLab * reference.Lab).proj('institution_short')

    # Fetch, join and sort data as a pandas DataFrame
    # NOTE(review): dj2pandas presumably derives the analysis columns used
    # later (signed_contrast, choice_right, rt) -- confirm in paper_behavior_functions.
    behav = dj2pandas(trials.fetch(format='frame')
                      .join(subject_info.fetch(format='frame'))
                      .sort_values(by=['institution_short', 'subject_nickname',
                                       'session_start_time', 'trial_id'])
                      .reset_index())
    # Translate lab short names into the codes used as facet columns (col_names).
    behav['institution_code'] = behav.institution_short.map(institution_map)
else:
    # NOTE(review): this branch does not add institution_code; the CSV is
    # assumed to already contain it (and subject_uuid), since both are used below.
    behav = load_csv('Fig3.csv')

# print some output
# Quick sanity check: show 10 random trials.
print(behav.sample(n=10))

# %%=============================== #
# PSYCHOMETRIC FUNCTIONS
# ================================= #

# How many mice are there for each lab? The count is shown in each panel title.
N = behav.groupby(['institution_code'])['subject_nickname'].nunique().to_dict()
behav['n_mice'] = behav.institution_code.map(N)
behav['institution_name'] = behav.institution_code + '\n ' + behav.n_mice.apply(str) + ' mice'

# Plot one curve for each animal, one panel per lab. The panels are created in
# col_names order (col_order=col_names), so every per-panel step below is keyed
# on col_names too -- never on the order in which labs happen to appear in
# `behav`, which need not match the panel order.
lab_order = list(col_names)
plt.close('all')
fig = sns.FacetGrid(behav,
                    col="institution_code", col_wrap=7, col_order=col_names,
                    sharex=True, sharey=True, hue="subject_uuid",
                    height=FIGURE_HEIGHT, aspect=(FIGURE_WIDTH/7)/FIGURE_HEIGHT)
fig.map(plot_psychometric, "signed_contrast", "choice_right",
        "subject_nickname", color='gray', alpha=0.7)
fig.set_titles("{col_name}")

# Overlay the example mouse in black on the panel of the lab it belongs to,
# instead of assuming that lab is the first panel.
tmpdat = behav[behav['subject_nickname'].str.contains(EXAMPLE_MOUSE)]
if not tmpdat.empty and tmpdat.institution_code.iloc[0] in lab_order:
    example_ax = fig.axes.flat[lab_order.index(tmpdat.institution_code.iloc[0])]
    plot_psychometric(tmpdat.signed_contrast, tmpdat.choice_right, tmpdat.subject_nickname,
                      color='black', ax=example_ax, legend=False)

# Add lab means on top, in the lab's palette colour. Each axis is paired with
# the lab it actually displays, so the mean curve and the title always belong
# to the same lab as the individual-mouse curves underneath.
for axidx, (ax, lab) in enumerate(zip(fig.axes.flat, lab_order)):
    tmp_behav = behav.loc[behav.institution_code == lab, :]
    if tmp_behav.empty:
        # Lab listed in col_names but absent from the data: leave the panel empty.
        ax.set_title(lab, color=pal[axidx])
        continue
    plot_psychometric(tmp_behav.signed_contrast, tmp_behav.choice_right,
                      tmp_behav.institution_name, ax=ax, legend=False, color=pal[axidx], linewidth=2)
    # institution_name is constant within a lab: '<code>\n <n> mice'.
    ax.set_title(tmp_behav.institution_name.iloc[0], color=pal[axidx])

fig.despine(trim=True)
fig.set_axis_labels("\u0394 Contrast (%)", 'Rightward choices (%)')
plt.tight_layout(w_pad=1)
fig.savefig(os.path.join(figpath, "figure3a_psychfuncs.pdf"))
fig.savefig(os.path.join(figpath, "figure3a_psychfuncs.png"), dpi=300)
print('done')

# %%

# Plot all labs: one psychometric curve per lab, overlaid on a single axis.
fig, ax1 = plt.subplots(1, 1, figsize=(FIGURE_WIDTH/5, FIGURE_HEIGHT))
lab_order = list(col_names)
for i, inst in enumerate(behav.institution_code.unique()):
    # Exact equality: Series.str.contains performs a (regex) substring match,
    # so one lab code could also select other codes that contain it.
    tmp_behav = behav[behav['institution_code'] == inst]
    # Colour by the lab's panel position in col_names so it matches the
    # per-lab panels (which use pal[panel index]); labs not listed in
    # col_names keep their enumeration index as before.
    color_idx = lab_order.index(inst) if inst in lab_order else i
    plot_psychometric(tmp_behav.signed_contrast, tmp_behav.choice_right,
                      tmp_behav.subject_nickname, ax=ax1, legend=False, color=pal[color_idx])
ax1.set_title('All labs: %d mice' % behav['subject_nickname'].nunique())
ax1.set(xlabel='\u0394 Contrast (%)', ylabel='Rightward choices (%)')
sns.despine(trim=True)
plt.tight_layout()
fig.savefig(os.path.join(figpath, "figure3b_psychfuncs_all_labs.pdf"))
fig.savefig(os.path.join(figpath, "figure3b_psychfuncs_all_labs.png"), dpi=300)

# ================================= #
# single summary panel
# ================================= #

# Summary panel: psychometric (left) and chronometric (right) functions for
# all mice and labs together (no per-lab split), drawn in black.
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
plot_psychometric(behav.signed_contrast, behav.choice_right,
                  behav.subject_nickname, ax=axes[0], legend=False, color='k')
axes[0].set_title('Psychometric function', color='k', fontweight='bold')
# Same y-axis label as figures 3a and 3b.
axes[0].set(xlabel='\u0394 Contrast (%)', ylabel='Rightward choices (%)')

# NOTE(review): behav.rt is assumed to be in seconds (it is labelled that way
# below) -- confirm where the column is derived.
plot_chronometric(behav.signed_contrast, behav.rt,
                  behav.subject_nickname, ax=axes[1], legend=False, color='k')
axes[1].set_title('Chronometric function', color='k', fontweight='bold')
axes[1].set(xlabel='\u0394 Contrast (%)', ylabel='Trial duration (s)', ylim=[0, 1.4])
sns.despine(trim=True)
plt.tight_layout()
fig.savefig(os.path.join(figpath, "summary_psych_chron.pdf"))
plt.show()
back to top