Revision 6d7341e6ce3c6c902bc04754066a66bcc1d67fe1 authored by Gaelle Chapuis on 12 January 2021, 15:29:35 UTC, committed by Gaelle Chapuis on 12 January 2021, 15:29:35 UTC
1 parent e4d37d3
figure4a_block_probabilities.py
"""
Block structure in the biased task for an example session
@author: Anne Urai
15 January 2020
"""
from ibl_pipeline import subject, acquisition, behavior
import pandas as pd
import os
import seaborn as sns
import numpy as np
from paper_behavior_functions import (seaborn_style, figpath, EXAMPLE_MOUSE,
FIGURE_HEIGHT, FIGURE_WIDTH, dj2pandas)
import matplotlib.pyplot as plt
import matplotlib.patches as patches
# INITIALIZE A FEW THINGS
seaborn_style()
figpath = figpath()
cmap = sns.diverging_palette(20, 220, n=3, center="dark")
cmap_dic = {20: cmap[0], 50: cmap[1], 80: cmap[2]}
# ================================= #
# SCHEMATIC OF THE BLOCKS
# ================================= #
behav = pd.DataFrame({'probability_left': [50, 50, 20, 20, 80, 80],
'stimulus_side': [-1, 1, -1, 1, -1, 1],
'prob': [50, 50, 20, 80, 80, 20]})
fig = sns.FacetGrid(behav,
col="probability_left", hue="probability_left", col_wrap=3,
col_order=[50, 20, 80], palette=cmap, sharex=True, sharey=True,
aspect=0.6, height=2.2)
# fig.map(sns.distplot, "stimulus_side", kde=False, norm_hist=True, bins=2, hist_kws={'rwidth':1})
fig.map(sns.barplot, "stimulus_side", "prob")
fig.set(xticks=[-0, 1], xlim=[-0.5, 1.5],
ylim=[0, 100], yticks=[0, 50, 100], yticklabels=[])
for ax, title in zip(fig.axes.flat, ['50/50', '80/20', '20/80']):
ax.set_title(title)
ax.set_xticklabels(['Left', 'Right'], rotation=45)
fig.set_axis_labels('', 'Probability (%)')
fig.savefig(os.path.join(
figpath, "figure4_panel_block_distribution.png"), dpi=600)
fig.savefig(os.path.join(figpath, "figure4_panel_block_distribution.pdf"))
plt.close('all')
# ================================= #
# EXAMPLE SESSION TIMECOURSE
# ================================= #
b = ((subject.Subject & 'subject_nickname="%s"' % EXAMPLE_MOUSE)
* behavior.TrialSet.Trial
* (acquisition.Session & 'task_protocol LIKE "%biased%"'
& 'session_start_time BETWEEN "2019-08-30" and "2019-08-31"'))
bdat = b.fetch(order_by='session_start_time, trial_id',
format='frame').reset_index()
behav = dj2pandas(bdat)
assert not behav.empty
# if 100 in df.signed_contrast.values and not 50 in df.signed_contrast.values:
behav['signed_contrast'] = behav['signed_contrast'].replace(-100, -35)
behav['signed_contrast'] = behav['signed_contrast'].replace(100, 35)
# %%
for dayidx, behavtmp in behav.groupby(['session_start_time']):
# 1. patches to show the blocks
fig, axes = plt.subplots(ncols=1, nrows=1, figsize=(FIGURE_WIDTH/3.2, FIGURE_HEIGHT*0.9))
xmax = min([behavtmp.trial_id.max() + 5, 500])
# Loop over data points; create box from errors at each point
behavtmp['blocks'] = (behavtmp["probabilityLeft"].ne(
behavtmp["probabilityLeft"].shift()).cumsum())
for idx, blocknum in behavtmp.groupby('blocks'):
left = blocknum.trial_id.min()
width = blocknum.trial_id.max() - blocknum.trial_id.min()
axes.add_patch(patches.Rectangle((left, 0), width, 100,
fc=cmap_dic[blocknum.probabilityLeft.unique()[
0]],
ec='none', alpha=0.2))
#
# 2. actual block probabilities as grey line
behavtmp['stim_sign'] = 100 * \
((np.sign(behavtmp.signed_contrast) / 2) + 0.5)
# sns.scatterplot(x='trial_id', y='stim_sign', data=behav, color='grey',
# marker='o', ax=axes, legend=False, alpha=0.5,
# ec='none', linewidth=0, zorder=2)
sns.lineplot(x='trial_id', y='stim_sign', color='black', ci=None,
data=behavtmp[['trial_id', 'stim_sign']].rolling(10).mean(), ax=axes)
# %%
# 3. ANIMAL CHOICES, rolling window
#rightax = axes.twinx()
behavtmp['choice_right'] = behavtmp.choice_right * 100
sns.lineplot(x='trial_id', y='choice_right', color='firebrick', ci=None,
data=behavtmp[['trial_id', 'choice_right']].rolling(10).mean(), ax=axes,
linestyle=':')
# rightax.set(xlim=[-5, xmax], xlabel='Trial number',
# ylabel='Rightwards choices (%)', ylim=[-1, 101])
# rightax.yaxis.label.set_color("firebrick")
# rightax.tick_params(axis='y', colors='firebrick')
# axes.set_yticks([0, 50, 100])
# rightax.set_yticks([0, 50, 100])
# axes.set_title(' \n ')
axes.set(xlim=[-5, xmax], xlabel='Trial number',
ylabel=' ', ylim=[-1, 101])
axes.yaxis.label.set_color("black")
axes.tick_params(axis='y', colors='black')
plt.tight_layout()
fig.savefig(os.path.join(
figpath, "figure4_panel_session_course_%s.png" % dayidx.date()), dpi=600)
fig.savefig(os.path.join(
figpath, "figure4_panel_session_course_%s.pdf" % dayidx.date()))
plt.close('all')
Computing file changes ...