Revision 82bf6bcfb7b0868922bbe8c5bec8ef25ec6b39ad authored by Anne Urai on 05 September 2020, 16:52:26 UTC, committed by Anne Urai on 05 September 2020, 16:52:26 UTC
1 parent 373ae2d
figure1def_training.py
"""
Training progression for an example mouse
@author: Anne Urai, Gaelle Chapuis, Miles Wells
21 April 2020
"""
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import datajoint as dj
from paper_behavior_functions import seaborn_style, figpath, \
FIGURE_HEIGHT, FIGURE_WIDTH, EXAMPLE_MOUSE, dj2pandas, plot_psychometric
# import wrappers etc
from ibl_pipeline import subject, behavior, acquisition
from ibl_pipeline.analyses import behavior as behavioral_analyses
# Virtual module over the shared end-criteria schema; its
# SessionEndCriteriaImplemented table is joined with the trials further below.
endcriteria = dj.create_virtual_module(
    module_name='SessionEndCriteriaImplemented',
    schema_name='group_shared_end_criteria',
)
# grab some plotting functions from the IBL-pipeline repository
# (this is a tricky dependency: as is, it cannot be run in an interactive Python shell, so the
# whole file needs to be run as a script, e.g. `python figure1def_training.py` from the Windows command prompt)
# import sys
# sys.path.append(os.path.join(os.path.dirname(__file__),
# '../IBL-pipeline/prelim_analyses/behavioral_snapshots/'))
# import ibl_pipeline.prelim_analyses.behavioral_snapshots.behavior_plots # noqa
# this import only works if ./IBL-pipeline/prelim_analyses/behavioral_snapshots/ has been added to the iblenv environment (e.g. with `conda develop`)
import behavior_plots
# ================================= #
# INITIALIZE A FEW THINGS
# ================================= #
seaborn_style() # noqa
figpath = figpath() # noqa
plt.close('all')
# FIGURE_WIDTH = 6 # make narrower
# ================================= #
# Get lab name of example mouse
# ================================= #
lab = (subject.SubjectLab * subject.Subject & 'subject_nickname = "%s"' % EXAMPLE_MOUSE) \
.fetch1('lab_name')
days = [2, 7, 10, 14]
# ==================================================
# CONTRAST HEATMAP
# ================================= #
plt.close('all')
fig, ax = plt.subplots(1, 2, figsize=(FIGURE_WIDTH / 2, FIGURE_HEIGHT))
ax[1].axis('off')
xlims = [pd.Timestamp('2019-08-04T00'), pd.Timestamp('2019-08-31T00')]
behavior_plots.plot_contrast_heatmap(EXAMPLE_MOUSE, lab, ax[0], xlims)
ax[0].set(ylabel='Contrast (%)', xlabel='Training day',
xticks=[d + 1.5 for d in days], xticklabels=days,
yticklabels=['100', '50', '25', '12.5', '6.25', '0',
'-6.25', '-12.5', '-25', '-50', '-100'])
for item in ax[0].get_xticklabels():
item.set_rotation(-0)
plt.tight_layout()
fig.savefig(os.path.join(figpath, "figure1_example_contrastheatmap.pdf"))
fig.savefig(os.path.join(
figpath, "figure1_example_contrastheatmap.png"), dpi=600)
# ================================================================== #
# PSYCHOMETRIC AND CHRONOMETRIC FUNCTIONS FOR EXAMPLE 3 DAYS
# ================================================================== #
# make these a bit more narrow
b = ((subject.Subject & 'subject_nickname = "%s"' % EXAMPLE_MOUSE)
* (subject.SubjectLab & 'lab_name="%s"' % lab)
* behavioral_analyses.BehavioralSummaryByDate)
behav = b.fetch(format='frame').reset_index()
behav['training_day'] = behav.training_day - \
behav.training_day.min() + 1 # start at session 1
for didx, day in enumerate(days):
# get data for today
print(day)
thisdate = behav[behav.training_day ==
day]['session_date'].dt.strftime('%Y-%m-%d').item()
b = (subject.Subject & 'subject_nickname = "%s"' % EXAMPLE_MOUSE) \
* (subject.SubjectLab & 'lab_name="%s"' % lab) \
* (acquisition.Session.proj(session_date='date(session_start_time)') &
'session_date = "%s"' % thisdate) \
* behavior.TrialSet.Trial() \
* endcriteria.SessionEndCriteriaImplemented()
behavtmp = dj2pandas(b.fetch(format='frame').reset_index())
behavtmp['trial_start_time'] = behavtmp.trial_start_time / 60 # in minutes
# unclear how this can be empty - but if it happens, skip
if behavtmp.empty:
continue
# PSYCHOMETRIC FUNCTIONS
fig, ax = plt.subplots(1, 1, figsize=(FIGURE_WIDTH/5, FIGURE_HEIGHT*0.9))
plot_psychometric(behavtmp.signed_contrast,
behavtmp.choice_right,
behavtmp.trial_id,
ax=ax, color='k')
ax.set(xlabel="Contrast (%)")
if didx == 0:
ax.set(ylabel="Rightward choices (%)")
else:
ax.set(ylabel=" ", yticklabels=[])
# ax.set(title='Training day %d' % (day))
sns.despine(trim=True)
plt.tight_layout()
fig.savefig(os.path.join(
figpath, "figure1_example_psychfunc_day%d.pdf" % (day)))
fig.savefig(os.path.join(
figpath, "figure1_example_psychfunc_day%d.png" % (day)), dpi=600)
# ================================================================== #
# WITHIN-TRIAL DISENGAGEMENT CRITERIA
# ================================================================== #
plt.close('all')
fig, ax = plt.subplots(2, 1, sharex=True,
figsize=(FIGURE_WIDTH/5, FIGURE_HEIGHT*1.5))
# running median overlaid
sns.lineplot(x='trial_start_time', y='rt', color='black', ci=None,
data=behavtmp[['trial_start_time', 'rt']].rolling(20).median(), ax=ax[0])
ax[0].set(xlabel="", ylabel="RT (s)", ylim=[0.1, 20])
ax[0].set_yscale("log")
# fix xlims
if didx == 0:
xlim = [0, 60]
elif didx == 1:
xlim = [0, 80]
elif didx == 2:
xlim = [0, 45]
elif didx == 3:
xlim = [0, 60]
ax[0].set(yticks=[0.1, 1, 10, 20],
yticklabels=['0.1', '1', '10', ''], xlim=xlim)
if didx == 0:
ax[0].set(ylabel="Trial duration (s)")
else:
ax[0].set(ylabel=" ", yticklabels=[])
# right y-axis with sliding performance
# from :
# https://stackoverflow.com/questions/36988123/pandas-groupby-and-rolling-apply-ignoring-nans
g1 = behavtmp[['trial_start_time', 'correct_easy']].copy()
g1['correct_easy'] = g1.correct_easy * 100
g2 = g1.fillna(0).copy()
s = g2.rolling(50).sum() / g1.rolling(50).count() # the actual computation
sns.lineplot(x='trial_start_time', y='correct_easy', color='black', ci=None,
data=s, ax=ax[1])
if day == min(days):
ax[1].set(ylabel="Accuracy (%)")
else:
ax[1].set(ylabel=" ", yticklabels=[])
ax[1].set(xlabel='Time (min)', ylim=[25, 110], yticks=[25, 50, 75, 100],
xlim=ax[0].get_xlim(), xticks=[0, 20, 40, 60, 80])
# INDICATE THE REASON AND TRIAL AT WHICH SESSION SHOULD HAVE ENDED
idx = behavtmp.trial_id == behavtmp.end_status_index.unique()[0]
end_x = behavtmp.loc[idx, 'trial_start_time'].values.item()
ax[0].axvline(x=end_x, color='darkgrey', linestyle=':')
ax[1].axvline(x=end_x, color='darkgrey', linestyle=':')
# ax2.annotate(behavtmp.end_status.unique()[0], xy=(end_x, 100), xytext=(end_x, 105),
# arrowprops={'arrowstyle': "->", 'connectionstyle': "arc3"})
print(behavtmp.end_status.unique()[0])
ax[0].set(title='Day %d: %d trials' % (day, behavtmp.shape[0]))
sns.despine(trim=True)
plt.tight_layout(h_pad=-0.05)
fig.savefig(os.path.join(
figpath, "figure1_example_disengagement_day%d.pdf" % day))
fig.savefig(os.path.join(
figpath, "figure1_example_disengagement_day%d.png" % day), dpi=600)
print(didx)
print(thisdate)
Computing file changes ...