"""
Training progression for an example mouse
@author: Anne Urai, Gaelle Chapuis, Miles Wells
21 April 2020
"""
import os
import copy
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import seaborn as sns
import datajoint as dj
from paper_behavior_functions import seaborn_style, figpath, \
FIGURE_HEIGHT, FIGURE_WIDTH, EXAMPLE_MOUSE, dj2pandas, plot_psychometric
# import wrappers etc
from ibl_pipeline import subject, behavior, acquisition
from ibl_pipeline.analyses import behavior as behavioral_analyses
# Virtual module over the 'group_shared_end_criteria' schema; the script
# accesses its SessionEndCriteriaImplemented table when building the
# per-session trial query further down.
endcriteria = dj.create_virtual_module(
    'SessionEndCriteriaImplemented', 'group_shared_end_criteria')
def plot_contrast_heatmap(mouse, lab, ax, xlims):
    """
    Draw a heatmap of the fraction of rightward choices per signed contrast
    (rows) and session date (columns) for one mouse.

    This function is copied from
    IBL-pipeline/prelim_analyses/behavioral_snapshots/behavior_plots.py

    Parameters
    ----------
    mouse : str
        Subject nickname used to restrict the query.
    lab : str
        Lab name used to restrict the query.
    ax : matplotlib.axes.Axes
        Axes to draw the heatmap in; an inset axes holding the colorbar is
        attached to its right-hand side.
    xlims : sequence of two date-likes
        First and last day of the date axis; both are handed to
        ``pd.date_range`` so that every day in between gets a column.

    Returns
    -------
    None
        Nothing is drawn when the query returns no sessions.
    """
    cmap = copy.copy(plt.get_cmap('vlag'))
    cmap.set_bad(color="w")  # remove rectangles without data, should be white

    query = (
        behavioral_analyses.BehavioralSummaryByDate.PsychResults * subject.Subject *
        subject.SubjectLab & 'subject_nickname="%s"' % mouse & 'lab_name="%s"' % lab)
    session_date, signed_contrasts, prob_choose_right, prob_left_block = query.proj(
        'signed_contrasts', 'prob_choose_right', 'session_date', 'prob_left_block').fetch(
        'session_date', 'signed_contrasts', 'prob_choose_right', 'prob_left_block')

    if not len(session_date):
        return

    # reshape to long format: one row per (session, contrast); the per-session
    # scalars are repeated once per contrast without touching the fetched arrays
    n_contrasts = [len(contrasts) for contrasts in signed_contrasts]
    result = pd.DataFrame({
        'session_date': np.concatenate(
            [np.repeat(date, n) for date, n in zip(session_date, n_contrasts)]),
        'signed_contrasts': np.concatenate(signed_contrasts) * 100,
        'prob_choose_right': np.concatenate(prob_choose_right),
        'prob_left_block': np.concatenate(
            [np.repeat(block, n) for block, n in zip(prob_left_block, n_contrasts)]),
    })

    # only use the unbiased block for now
    result = result[result.prob_left_block == 0]
    result = result.round({'signed_contrasts': 2})

    # keyword arguments: pandas >= 2.0 no longer accepts them positionally
    pp2 = result.pivot(index="signed_contrasts", columns="session_date",
                       values="prob_choose_right")
    pp2 = pp2.reindex(sorted(result.signed_contrasts.unique()))

    # evenly spaced date axis
    x = pd.date_range(xlims[0], xlims[1]).to_pydatetime()
    pp2 = pp2.reindex(columns=x)
    pp2 = pp2.iloc[::-1]  # reverse, red on top

    # inset axes for colorbar, to the right of plot
    axins1 = inset_axes(ax, width="5%", height="90%", loc='right',
                        bbox_to_anchor=(0.15, 0., 1, 1),
                        bbox_transform=ax.transAxes, borderpad=0,)

    # now heatmap
    sns.heatmap(pp2, linewidths=0, ax=ax, vmin=0, vmax=1, cmap=cmap, cbar=True,
                cbar_ax=axins1,
                cbar_kws={'label': 'Choose right (%)', 'shrink': 0.8, 'ticks': []})
    ax.set(ylabel="Contrast (%)", xlabel='')

    # deal with date axis and make nice looking
    ax.xaxis_date()
    ax.xaxis.set_major_locator(mdates.WeekdayLocator(byweekday=mdates.MONDAY))
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%b-%d'))
    for item in ax.get_xticklabels():
        item.set_rotation(60)
# ================================= #
# SET UP STYLE AND OUTPUT FOLDER
# ================================= #

seaborn_style()  # noqa
# rebinds the imported helper's name to the value it returns, which the
# savefig calls below join with the output file names
figpath = figpath()  # noqa
plt.close('all')
# FIGURE_WIDTH = 6 # make narrower

# ================================= #
# Look up the lab of the example mouse
# ================================= #

lab = (
    (subject.SubjectLab * subject.Subject)
    & ('subject_nickname = "%s"' % EXAMPLE_MOUSE)
).fetch1('lab_name')

# training days that are plotted individually further down
days = [2, 7, 10, 14]
# days = [2, 7, 10, 13] # request gaelle
# ================================= #
# CONTRAST HEATMAP
# ================================= #

plt.close('all')
fig, ax = plt.subplots(1, 2, figsize=(FIGURE_WIDTH / 2, FIGURE_HEIGHT))
ax[1].axis('off')  # the second axes is left blank

# first and last date shown on the heatmap's x-axis
xlims = [pd.Timestamp('2019-08-04T00'), pd.Timestamp('2019-08-31T00')]
plot_contrast_heatmap(EXAMPLE_MOUSE, lab, ax[0], xlims)

# replace the date ticks by the training-day numbers in `days` and write
# out the contrast levels on the y-axis
heatmap_xticks = [offset + 1.5 for offset in (2, 8, 11, 17)]
heatmap_ylabels = ['100', '50', '25', '12.5', '6.25', '0',
                   '-6.25', '-12.5', '-25', '-50', '-100']
ax[0].set(ylabel='Contrast (%)', xlabel='Training day',
          xticks=heatmap_xticks, xticklabels=days,
          yticklabels=heatmap_ylabels)
for tick_label in ax[0].get_xticklabels():
    tick_label.set_rotation(0)  # undo the rotation applied by the helper

plt.tight_layout()
heatmap_pdf = os.path.join(figpath, "figure1_example_contrastheatmap.pdf")
heatmap_png = os.path.join(figpath, "figure1_example_contrastheatmap.png")
fig.savefig(heatmap_pdf)
fig.savefig(heatmap_png, dpi=600)
# ================================================================== #
# PSYCHOMETRIC FUNCTIONS AND SESSION TIME COURSES FOR THE EXAMPLE DAYS
# ================================================================== #

# per-date behavioural summary of the example mouse
example_subject = subject.Subject & 'subject_nickname = "%s"' % EXAMPLE_MOUSE
example_lab = subject.SubjectLab & 'lab_name="%s"' % lab
b = example_subject * example_lab * behavioral_analyses.BehavioralSummaryByDate
behav = b.fetch(format='frame').reset_index()

# renumber the training days so that the first one is day 1
first_training_day = behav['training_day'].min()
behav['training_day'] = behav['training_day'] - first_training_day + 1
# Hand-picked x-axis limits (in minutes) for the time-course panels, one
# entry per position in `days`.
session_xlims = ([0, 60], [0, 80], [0, 45], [0, 60])

for didx, day in enumerate(days):
    # only the first panel of each row keeps its y-axis label and tick labels
    is_first_panel = didx == 0

    # get data for today
    print(day)
    thisdate = behav[behav.training_day ==
                     day]['session_date'].dt.strftime('%Y-%m-%d').item()
    b = (
        (subject.Subject & 'subject_nickname = "%s"' % EXAMPLE_MOUSE)
        * (subject.SubjectLab & 'lab_name="%s"' % lab)
        * (acquisition.Session.proj(session_date='date(session_start_time)')
           & 'session_date = "%s"' % thisdate)
        * behavior.TrialSet.Trial()
        * endcriteria.SessionEndCriteriaImplemented()
    )
    behavtmp = dj2pandas(b.fetch(format='frame').reset_index())

    # unclear how this can be empty - but if it happens, skip
    # (checked before any column is touched)
    if behavtmp.empty:
        continue
    behavtmp['trial_start_time'] = behavtmp.trial_start_time / 60  # in minutes

    # PSYCHOMETRIC FUNCTIONS
    fig, ax = plt.subplots(1, 1, figsize=(FIGURE_WIDTH/5, FIGURE_HEIGHT*0.9))
    plot_psychometric(behavtmp.signed_contrast,
                      behavtmp.choice_right,
                      behavtmp.trial_id,
                      ax=ax, color='k')
    ax.set(xlabel="\u0394 Contrast (%)")
    if is_first_panel:
        ax.set(ylabel="Rightward choices (%)")
    else:
        ax.set(ylabel=" ", yticklabels=[])
    sns.despine(trim=True)
    plt.tight_layout()
    fig.savefig(os.path.join(
        figpath, "figure1_example_psychfunc_day%d.pdf" % (day)))
    fig.savefig(os.path.join(
        figpath, "figure1_example_psychfunc_day%d.png" % (day)), dpi=600)

    # ================================================================== #
    # WITHIN-SESSION DISENGAGEMENT CRITERIA
    # ================================================================== #

    plt.close('all')
    fig, ax = plt.subplots(2, 1, sharex=True,
                           figsize=(FIGURE_WIDTH/5, FIGURE_HEIGHT*1.5))

    # top panel: running median (window of 20 trials) of `rt` over the session
    sns.lineplot(x='trial_start_time', y='rt', color='black', ci=None,
                 data=behavtmp[['trial_start_time', 'rt']].rolling(20).median(),
                 ax=ax[0])
    ax[0].set(xlabel="", ylim=[0.1, 20])
    ax[0].set_yscale("log")

    # fix xlims; sessions without a hand-picked limit span their own trials
    # instead of silently inheriting the previous session's limit
    if didx < len(session_xlims):
        xlim = session_xlims[didx]
    else:
        xlim = [0, behavtmp['trial_start_time'].max()]
    ax[0].set(yticks=[0.1, 1, 10, 20],
              yticklabels=['0.1', '1', '10', ''], xlim=xlim)
    if is_first_panel:
        ax[0].set(ylabel="Trial duration (s)")
    else:
        ax[0].set(ylabel=" ", yticklabels=[])

    # bottom panel: sliding performance on easy trials
    # NaN-aware rolling mean (window of 50 trials): the sum with NaNs counted
    # as 0, divided by the number of non-NaN entries in the window; from :
    # https://stackoverflow.com/questions/36988123/pandas-groupby-and-rolling-apply-ignoring-nans
    g1 = behavtmp[['trial_start_time', 'correct_easy']].copy()
    g1['correct_easy'] = g1.correct_easy * 100
    g2 = g1.fillna(0).copy()
    s = g2.rolling(50).sum() / g1.rolling(50).count()  # the actual computation
    sns.lineplot(x='trial_start_time', y='correct_easy', color='black', ci=None,
                 data=s, ax=ax[1])
    if is_first_panel:
        ax[1].set(ylabel="Performance (%)\non easy trials")
    else:
        ax[1].set(ylabel=" ", yticklabels=[])
    ax[1].set(xlabel='Time (min)', ylim=[25, 110], yticks=[25, 50, 75, 100],
              xlim=ax[0].get_xlim(), xticks=[0, 20, 40, 60, 80])

    # INDICATE THE REASON AND TRIAL AT WHICH SESSION SHOULD HAVE ENDED
    idx = behavtmp.trial_id == behavtmp.end_status_index.unique()[0]
    # .item() requires exactly one trial to match the end-status index
    end_x = behavtmp.loc[idx, 'trial_start_time'].values.item()
    ax[0].axvline(x=end_x, color='darkgrey', linestyle=':')
    ax[1].axvline(x=end_x, color='darkgrey', linestyle=':')
    print(behavtmp.end_status.unique()[0])

    ax[0].set(title='Day %d: %d trials' % (day, behavtmp.shape[0]))
    sns.despine(trim=True)
    plt.tight_layout(h_pad=-0.05)
    fig.savefig(os.path.join(
        figpath, "figure1_example_disengagement_day%d.pdf" % day))
    fig.savefig(os.path.join(
        figpath, "figure1_example_disengagement_day%d.png" % day), dpi=600)

    print(didx)
    print(thisdate)