"""
Psychometric functions of training mice, within and across labs

@author: Anne Urai
15 January 2020
"""
import seaborn as sns
import os
from os.path import join
import pandas as pd
import matplotlib.pyplot as plt
from paper_behavior_functions import (figpath, seaborn_style, group_colors,
                                      load_csv, query_sessions_around_criterion,
                                      institution_map, FIGURE_HEIGHT, FIGURE_WIDTH,
                                      QUERY, EXAMPLE_MOUSE, plot_psychometric,
                                      dj2pandas, plot_chronometric)
# import wrappers etc
from ibl_pipeline import reference, subject, behavior

# Initialize
seaborn_style()
figpath = figpath()
pal = group_colors()
institution_map, col_names = institution_map()
# NOTE(review): the last institution code is dropped from the panel list;
# presumably it has no panel of its own here -- confirm against institution_map().
col_names = col_names[:-1]

# %%=============================== #
# GET DATA FROM TRAINED ANIMALS
# ================================= #

if QUERY is True:
    # query sessions
    use_sessions, use_days = query_sessions_around_criterion(criterion='trained',
                                                             days_from_criterion=[2, 0],
                                                             as_dataframe=False,
                                                             force_cutoff=True)
    # list of dicts - see https://int-brain-lab.slack.com/archives/CB13FQFK4/p1607369435116300 for explanation
    sess = use_sessions.proj('task_protocol').fetch(format='frame').reset_index().to_dict('records')

    # Trial data to fetch
    trial_fields = ('trial_stim_contrast_left', 'trial_stim_contrast_right',
                    'trial_response_time', 'trial_stim_prob_left',
                    'trial_feedback_type', 'trial_stim_on_time',
                    'trial_response_choice')

    # Query trial data for sessions and subject name and lab info
    trials = (behavior.TrialSet.Trial & sess).proj(*trial_fields)

    # also get info about each subject
    subject_info = subject.Subject.proj('subject_nickname') * \
        (subject.SubjectLab * reference.Lab).proj('institution_short')

    # Fetch, join and sort data as a pandas DataFrame
    behav = dj2pandas(trials.fetch(format='frame')
                      .join(subject_info.fetch(format='frame'))
                      .sort_values(by=['institution_short', 'subject_nickname',
                                       'session_start_time', 'trial_id'])
                      .reset_index())
    behav['institution_code'] = behav.institution_short.map(institution_map)
else:
    behav = load_csv('Fig3.csv')

# print some output
print(behav.sample(n=10))

# %%=============================== #
# PSYCHOMETRIC FUNCTIONS
# ================================= #

# how many mice are there for each lab?
N = behav.groupby(['institution_code'])['subject_nickname'].nunique().to_dict()
behav['n_mice'] = behav.institution_code.map(N)
behav['institution_name'] = behav.institution_code + '\n ' + behav.n_mice.apply(str) + ' mice'

# Single source of truth for lab order: FacetGrid lays out its panels in
# col_order, so panel data, titles and colours must all follow this list.
panel_labs = list(col_names)

# plot one curve for each animal, one panel per lab
plt.close('all')
fig = sns.FacetGrid(behav,
                    col="institution_code", col_wrap=7, col_order=col_names,
                    sharex=True, sharey=True, hue="subject_uuid",
                    height=FIGURE_HEIGHT, aspect=(FIGURE_WIDTH/7)/FIGURE_HEIGHT)
fig.map(plot_psychometric, "signed_contrast", "choice_right",
        "subject_nickname", color='gray', alpha=0.7)
fig.set_titles("{col_name}")

# overlay the example mouse on the panel of the lab it belongs to
# (falls back to the first panel if its lab has no panel)
tmpdat = behav[behav['subject_nickname'].str.contains(EXAMPLE_MOUSE)]
if not tmpdat.empty:
    example_lab = tmpdat['institution_code'].iloc[0]
    example_idx = panel_labs.index(example_lab) if example_lab in panel_labs else 0
    plot_psychometric(tmpdat.signed_contrast, tmpdat.choice_right,
                      tmpdat.subject_nickname, color='black',
                      ax=fig.axes.flat[example_idx], legend=False)

# add lab means on top; select the data for each panel by the same lab code
# that FacetGrid used for it, so curve, title and colour always agree
for axidx, (ax, lab) in enumerate(zip(fig.axes.flat, panel_labs)):
    tmp_behav = behav.loc[behav.institution_code == lab, :]
    if tmp_behav.empty:
        continue
    plot_psychometric(tmp_behav.signed_contrast, tmp_behav.choice_right,
                      tmp_behav.institution_name, ax=ax, legend=False,
                      color=pal[axidx], linewidth=2)
    ax.set_title(tmp_behav.institution_name.iloc[0], color=pal[axidx])

fig.despine(trim=True)
fig.set_axis_labels("\u0394 Contrast (%)", 'Rightward choices (%)')
plt.tight_layout(w_pad=1)
fig.savefig(os.path.join(figpath, "figure3a_psychfuncs.pdf"))
fig.savefig(os.path.join(figpath, "figure3a_psychfuncs.png"), dpi=300)
print('done')

# %%
# Plot all labs, using the same lab -> colour assignment as the panels above;
# labs present in the data but without a panel are appended afterwards.
all_labs = panel_labs + [code for code in behav.institution_code.unique()
                         if code not in panel_labs]
fig, ax1 = plt.subplots(1, 1, figsize=(FIGURE_WIDTH/5, FIGURE_HEIGHT))
for i, inst in enumerate(all_labs):
    # exact match: str.contains would be a regex substring test
    tmp_behav = behav[behav['institution_code'] == inst]
    plot_psychometric(tmp_behav.signed_contrast, tmp_behav.choice_right,
                      tmp_behav.subject_nickname, ax=ax1, legend=False, color=pal[i])
ax1.set_title('All labs: %d mice' % behav['subject_nickname'].nunique())
ax1.set(xlabel='\u0394 Contrast (%)', ylabel='Rightward choices (%)')
sns.despine(trim=True)
plt.tight_layout()
fig.savefig(os.path.join(figpath, "figure3b_psychfuncs_all_labs.pdf"))
fig.savefig(os.path.join(figpath, "figure3b_psychfuncs_all_labs.png"), dpi=300)

# ================================= #
# single summary panel
# ================================= #

# Psychometric (left) and chronometric (right) functions pooled over all mice
fig, (ax_psych, ax_chron) = plt.subplots(1, 2, figsize=(8, 4))
plot_psychometric(behav.signed_contrast, behav.choice_right,
                  behav.subject_nickname, ax=ax_psych, legend=False, color='k')
ax_psych.set_title('Psychometric function', color='k', fontweight='bold')
ax_psych.set(xlabel='\u0394 Contrast (%)', ylabel='Rightward choice (%)')

plot_chronometric(behav.signed_contrast, behav.rt, behav.subject_nickname,
                  ax=ax_chron, legend=False, color='k')
ax_chron.set_title('Chronometric function', color='k', fontweight='bold')
ax_chron.set(xlabel='\u0394 Contrast (%)', ylabel='Trial duration (s)', ylim=[0, 1.4])

sns.despine(trim=True)
plt.tight_layout()
fig.savefig(os.path.join(figpath, "summary_psych_chron.pdf"))
plt.show()