# Revision edc453189104a1f76f4b2ab230cd86f2140e3f63 authored by Anne Urai on 08 April 2021, 13:13 UTC, committed by Anne Urai on 08 April 2021, 13:13 UTC
# 1 parent 22583a6
# Raw File
# figure4a_block_probabilities.py
"""
Block structure in the biased task for an example session

@author: Anne Urai
15 January 2020
"""

from ibl_pipeline import subject, acquisition, behavior
import pandas as pd
import os
import seaborn as sns
import numpy as np
from paper_behavior_functions import (seaborn_style, figpath, EXAMPLE_MOUSE,
                                      FIGURE_HEIGHT, FIGURE_WIDTH, dj2pandas)
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# INITIALIZE A FEW THINGS
# Apply the plotting style helper imported from paper_behavior_functions.
seaborn_style()
# NOTE: this rebinds the imported helper name `figpath` to its return value;
# from here on `figpath` is the directory handed to os.path.join when saving.
figpath = figpath()
# Three-colour diverging palette; `cmap_dic` maps the probabilityLeft values
# 20, 50 and 80 onto the first, second and third palette colour respectively.
cmap = sns.diverging_palette(20, 220, n=3, center="dark")
cmap_dic = dict(zip((20, 50, 80), cmap))

# ================================= #
# SCHEMATIC OF THE BLOCKS
# ================================= #

# Hand-built table: for each block type (`probability_left`, in %), the
# probability (in %) assigned to each stimulus side. Sides are coded -1 and 1
# and are labelled 'Left' and 'Right' on the x-axis below.
behav = pd.DataFrame({'probability_left': [50, 50, 20, 20, 80, 80],
                      'stimulus_side': [-1, 1, -1, 1, -1, 1],
                      'prob': [50, 50, 20, 80, 80, 20]})

# One facet per block type, coloured with the shared three-colour palette.
fig = sns.FacetGrid(behav,
                    col="probability_left", hue="probability_left", col_wrap=3,
                    col_order=[50, 20, 80], palette=cmap, sharex=True, sharey=True,
                    aspect=0.6, height=2.2)
# Give barplot an explicit category order: FacetGrid.map warns that
# categorical plots drawn without `order` may place bars inconsistently
# across facets. [-1, 1] matches the sorted order used implicitly before.
fig.map(sns.barplot, "stimulus_side", "prob", order=[-1, 1])
# Bars sit at categorical positions 0 and 1; y tick labels are hidden.
fig.set(xticks=[0, 1], xlim=[-0.5, 1.5],
        ylim=[0, 100], yticks=[0, 50, 100], yticklabels=[])
# Titles are assigned positionally, following col_order [50, 20, 80].
# NOTE(review): the probability_left=20 facet is titled '80/20' while its
# bars read Left=20, Right=80 -- confirm the titles are meant as Right/Left.
for ax, title in zip(fig.axes.flat, ['50/50', '80/20', '20/80']):
    ax.set_title(title)
    ax.set_xticklabels(['Left', 'Right'], rotation=45)
fig.set_axis_labels('', 'Probability (%)')

# Save both a high-resolution raster and a vector version of the panel.
fig.savefig(os.path.join(
    figpath, "figure4_panel_block_distribution.png"), dpi=600)
fig.savefig(os.path.join(figpath, "figure4_panel_block_distribution.pdf"))
plt.close('all')

# ================================= #
# EXAMPLE SESSION TIMECOURSE
# ================================= #

# Build the query from its parts: the example mouse, its trials, and the
# sessions whose protocol name contains "biased" within the given date window.
example_subject = subject.Subject & 'subject_nickname="%s"' % EXAMPLE_MOUSE
biased_sessions = (acquisition.Session
                   & 'task_protocol LIKE "%biased%"'
                   & 'session_start_time BETWEEN "2019-08-30" and "2019-08-31"')
b = example_subject * behavior.TrialSet.Trial * biased_sessions

# Fetch as a DataFrame ordered by session and trial, then flatten the index
# so the key fields become ordinary columns.
bdat = b.fetch(order_by='session_start_time, trial_id', format='frame')
bdat = bdat.reset_index()
behav = dj2pandas(bdat)
assert not behav.empty

# Remap contrasts of -100/100 to -35/35. Only the sign of `signed_contrast`
# is used further down in this script, so this does not alter the plots here.
behav['signed_contrast'] = behav['signed_contrast'].replace({-100: -35, 100: 35})

# %%
# One figure per session. Group on the column name itself rather than a
# one-element list: pandas >= 2.0 yields 1-tuples as keys for list groupers,
# which would make `dayidx.date()` below fail.
for dayidx, behavtmp in behav.groupby('session_start_time'):

    # Work on an independent copy so the column assignments below do not
    # write into a slice of `behav` (avoids SettingWithCopyWarning).
    behavtmp = behavtmp.copy()

    # 1. patches to show the blocks
    fig, axes = plt.subplots(ncols=1, nrows=1,
                             figsize=(FIGURE_WIDTH / 3.2, FIGURE_HEIGHT * 0.9))
    # Right x-limit: a little past the last trial, capped at 500.
    xmax = min(behavtmp.trial_id.max() + 5, 500)

    # Number consecutive runs of identical probabilityLeft: the counter
    # increments on every trial whose value differs from the previous trial.
    behavtmp['blocks'] = (behavtmp["probabilityLeft"].ne(
        behavtmp["probabilityLeft"].shift()).cumsum())

    # Shade each run with the colour of its block type.
    for _, block in behavtmp.groupby('blocks'):
        left = block.trial_id.min()
        width = block.trial_id.max() - left
        # probabilityLeft is constant within a run, so the first value suffices.
        block_color = cmap_dic[block.probabilityLeft.iloc[0]]
        axes.add_patch(patches.Rectangle((left, 0), width, 100,
                                         fc=block_color, ec='none', alpha=0.2))

    # 2. stimulus side actually presented, as a solid black line.
    # The sign of signed_contrast (-1 / 0 / +1) is mapped onto 0 / 50 / 100.
    behavtmp['stim_sign'] = 100 * \
        ((np.sign(behavtmp.signed_contrast) / 2) + 0.5)
    # The 10-row rolling mean is applied to both columns, so the x value is
    # the window-averaged trial_id; the first 9 rows are NaN and not drawn.
    sns.lineplot(x='trial_id', y='stim_sign', color='black', ci=None,
                 data=behavtmp[['trial_id', 'stim_sign']].rolling(10).mean(), ax=axes)

    # 3. ANIMAL CHOICES, rolling window, as a dotted red line on the same axis
    # Scale by 100 to share the 0-100 axis (presumably choice_right is 0/1 --
    # confirm against dj2pandas).
    behavtmp['choice_right'] = behavtmp.choice_right * 100
    sns.lineplot(x='trial_id', y='choice_right', color='firebrick', ci=None,
                 data=behavtmp[['trial_id', 'choice_right']].rolling(10).mean(), ax=axes,
                 linestyle=':')

    axes.set(xlim=[-5, xmax], xlabel='Trial number',
             ylabel=' ', ylim=[-1, 101])
    axes.yaxis.label.set_color("black")
    axes.tick_params(axis='y', colors='black')
    plt.tight_layout()

    # File names carry the session's calendar date.
    session_date = dayidx.date()
    fig.savefig(os.path.join(
        figpath, "figure4_panel_session_course_%s.png" % session_date), dpi=600)
    fig.savefig(os.path.join(
        figpath, "figure4_panel_session_course_%s.pdf" % session_date))
    plt.close('all')
# back to top