# -*- coding: utf-8 -*-
"""
General functions and queries for the analysis of behavioral data from the IBL task

Guido Meijer, Anne Urai, Alejandro Pan Vazquez & Miles Wells
16 Jan 2020
"""

import seaborn as sns
import matplotlib
import os
import numpy as np
import datajoint as dj
import pandas as pd
import matplotlib.pyplot as plt

# import wrappers etc
from ibl_pipeline.utils import psychofit as psy

# ---- Module-wide settings ----
# Data source: True -> query through DataJoint, False -> use downloaded csv files
QUERY = True
# Nickname of the mouse that is used as an example
EXAMPLE_MOUSE = 'KS014'
# Sessions after this date are excluded (the cutoff was previously the 30th of Nov)
CUTOFF_DATE = '2020-03-23'
# Date after which the hardware was deemed stable
STABLE_HW_DATE = '2019-06-10'

# Figure dimensions, in inches
FIGURE_HEIGHT = 2
FIGURE_WIDTH = 8

# Session UUIDs listed for exclusion
EXCLUDED_SESSIONS = [
    'a9fb578a-9d7d-42b4-8dbc-3b419ce9f424',
]

def group_colors():
    """
    Return 7 colors taken from the seaborn "Dark2" palette
    """
    n_colors = 7
    palette = sns.color_palette("Dark2", n_colors)
    return palette

def institution_map():
    """
    Map institution short names onto anonymised lab labels

    Returns a tuple (mapping, col_names):
    mapping:      dict from institution short name to its label ('Lab 1' ... 'Lab 7')
    col_names:    list with the seven lab labels followed by 'All labs'
    """
    institutions = ('UCL', 'CCU', 'CSHL', 'NYU', 'Princeton', 'SWC', 'Berkeley')
    # Labs are numbered from 1 in the order listed above
    lab_of = {name: 'Lab %d' % number for number, name in enumerate(institutions, start=1)}
    labels = list(lab_of.values()) + ['All labs']

    return lab_of, labels

def seaborn_style():
    """
    Set seaborn style for plotting figures

    Applies the "ticks" style in the "paper" context with small fonts and tick marks, and
    changes the global matplotlib rcParams for pdf/ps font embedding. Returns nothing.
    """
    sns.set(style="ticks", context="paper",
            rc={"font.size": 9,
                "axes.titlesize": 9,
                "axes.labelsize": 9,
                "lines.linewidth": 1,
                "xtick.labelsize": 7,
                "ytick.labelsize": 7,
                "savefig.transparent": True,
                "xtick.major.size": 2.5,
                "ytick.major.size": 2.5,
                "xtick.minor.size": 2,
                "ytick.minor.size": 2,
                })
    # Embed fonts as Type 42 (TrueType) instead of matplotlib's default Type 3
    matplotlib.rcParams['pdf.fonttype'] = 42
    matplotlib.rcParams['ps.fonttype'] = 42

def figpath():
    """
    Return the directory in which figures are saved, creating it if it does not exist yet

    The directory is called 'exported_figs' and lives next to this file. The location is
    printed every time the function is called.
    """
    # Retrieve absolute path of paper-behavior dir
    repo_dir = os.path.dirname(os.path.realpath(__file__))
    # Make figure directory
    fig_dir = os.path.join(repo_dir, 'exported_figs')
    # Announce save location
    print('Figure save path: ' + fig_dir)
    # If doesn't already exist, create (exist_ok avoids a race between check and creation)
    os.makedirs(fig_dir, exist_ok=True)
    return fig_dir

def query_subjects(as_dataframe=False, from_list=False, criterion='trained'):
    """
    Query all mice for analysis of behavioral data

    Parameters
    ----------
    as_dataframe:    boolean if true returns a pandas dataframe (default is False)
    from_list:       boolean if true restricts the query to the subject uuids stored in the file
                     'uuids_trained1.npy' (array of uuid objects, path relative to the working
                     directory)
    criterion:       which criterion had to be reached by CUTOFF_DATE - 'trained' (a and b),
                     'biased' (first session with a biased task protocol) or 'ephys' (any
                     training status starting with "ready": ready4ephysrig, ready4delay and
                     ready4recording).  If None, all mice that completed a training session are
                     returned, with date_trained being the date of their first training session.

    Returns
    -------
    subjects:        DataJoint query with one row per subject and a date_trained attribute, or
                     a pandas dataframe sorted by lab_name when as_dataframe is True

    Raises
    ------
    ValueError:      if criterion is not None, 'trained', 'biased' or 'ephys'
    """
    # Validate the criterion first so that a typo fails before the database is contacted
    if criterion is None:
        restriction = None
    elif criterion == 'trained':
        restriction = 'training_status="trained_1a" OR training_status="trained_1b"'
    elif criterion == 'biased':
        restriction = 'task_protocol LIKE "%biased%"'
    elif criterion == 'ephys':
        restriction = 'training_status LIKE "ready%"'
    else:
        raise ValueError('criterion must be "trained", "biased" or "ephys"')

    from ibl_pipeline import subject, acquisition, reference
    from ibl_pipeline.analyses import behavior as behavior_analysis

    # Query all subjects with project ibl_neuropixel_brainwide_01 and get the date at which
    # they reached a given training status
    all_subjects = (subject.Subject * subject.SubjectLab * reference.Lab * subject.SubjectProject
                    & 'subject_project = "ibl_neuropixel_brainwide_01"')
    sessions = acquisition.Session * behavior_analysis.SessionTrainingStatus()
    fields = ('subject_nickname', 'sex', 'subject_birth_date', 'institution_short')

    # Without a criterion date_trained is the date of the first session of each mouse,
    # otherwise it is the date of the first session that satisfies the criterion
    if restriction is not None:
        sessions = sessions & restriction
    subj_query = all_subjects.aggr(
        sessions, *fields, date_trained='min(date(session_start_time))')

    if from_list is True:
        # NOTE(review): allow_pickle=True executes pickled content, only load trusted files
        ids = np.load('uuids_trained1.npy', allow_pickle=True)
        subj_query = subj_query & [{'subject_uuid': u_id} for u_id in ids]

    # Select subjects that reached criterion before cutoff date
    subjects = (subj_query & 'date_trained <= "%s"' % CUTOFF_DATE)
    if as_dataframe is True:
        subjects = subjects.fetch(format='frame')
        subjects = subjects.sort_values(by=['lab_name']).reset_index()

    return subjects

def query_sessions(task='all', stable=False, as_dataframe=False,
                   force_cutoff=False, criterion='biased'):
    """
    Query all sessions for analysis of behavioral data

    Parameters
    ----------
    task:            string indicating sessions of which task to return, can be 'training',
                     'biased' or 'ephys'; default is 'all', which returns every session except
                     habituation sessions
    stable:          boolean if True only return sessions with stable hardware, which means
                     sessions after STABLE_HW_DATE (default is False)
    as_dataframe:    boolean if True returns a pandas dataframe (default is False)
    force_cutoff:    boolean if True only sessions of animals that reached `criterion` by
                     CUTOFF_DATE are returned; if False the default criterion of
                     query_subjects is used to select the animals
    criterion:       criterion passed to query_subjects when force_cutoff is True: 'trained'
                     (includes a and b), 'biased' or 'ephys' (includes ready4ephysrig,
                     ready4delay and ready4recording)

    Returns
    -------
    sessions:        DataJoint query, or a pandas dataframe when as_dataframe is True

    Raises
    ------
    ValueError:      if task is not 'all', 'training', 'biased' or 'ephys'
    """
    # Validate the task first so that a typo fails before the database is contacted
    if task == 'all':
        task_restriction = 'task_protocol NOT LIKE "%habituation%"'
    elif task == 'training':
        task_restriction = 'task_protocol LIKE "%training%"'
    elif task == 'biased':
        task_restriction = 'task_protocol LIKE "%biased%"'
    elif task == 'ephys':
        task_restriction = 'task_protocol LIKE "%ephys%"'
    else:
        raise ValueError('task must be "all", "training", "biased" or "ephys"')

    from ibl_pipeline import acquisition

    # Query subjects whose sessions are included
    if force_cutoff is True:
        use_subjects = query_subjects(criterion=criterion).proj('subject_uuid')
    else:
        use_subjects = query_subjects().proj('subject_uuid')

    # Query all sessions or only training, biased or ephys if required
    sessions = acquisition.Session * use_subjects & task_restriction

    # Only use sessions up until the cutoff date
    sessions = sessions & 'date(session_start_time) <= "%s"' % CUTOFF_DATE
    # Exclude weird sessions
    sessions = sessions & dj.Not([{'session_uuid': u_id} for u_id in EXCLUDED_SESSIONS])

    # If required only output sessions with stable hardware
    if stable is True:
        sessions = sessions & 'date(session_start_time) > "%s"' % STABLE_HW_DATE

    # Transform into pandas Dataframe if requested
    if as_dataframe is True:
        sessions = sessions.fetch(
            order_by='institution_short, subject_nickname, session_start_time', format='frame')
        sessions = sessions.reset_index()

    return sessions

def query_sessions_around_criterion(criterion='trained', days_from_criterion=(2, 0),
                                    as_dataframe=False, force_cutoff=False):
    """
    Query the sessions around the day at which each mouse reached a criterion

    Parameters
    ----------
    criterion:              string indicating which criterion to use: trained, biased or ephys
    days_from_criterion:    two-element array which indicates which training days around the day
                            the mouse reached criterion to return, e.g. [3, 2] returns three days
                            before criterion reached up until 2 days after (default: [2, 0])
    as_dataframe:           return sessions and days as pandas dataframes
    force_cutoff:           boolean if True only animals that reached `criterion` by CUTOFF_DATE
                            are used; if False the default criterion of query_subjects is used
                            to select the animals

    Returns
    -------
    sessions:               The sessions around the criterion day, works in conjunction with
                            any table that has session_start_time as primary key
    days:                   The training days around the criterion day. Can be used in conjunction
                            with tables that have session_date as primary key

    Raises
    ------
    ValueError:             if criterion is not 'trained', 'biased' or 'ephys'
    """
    # Validate the criterion first so that a typo fails before the database is contacted
    if criterion == 'trained':
        restriction = 'training_status="trained_1a" OR training_status="trained_1b"'
    elif criterion == 'biased':
        restriction = 'task_protocol LIKE "%biased%" AND training_status="trained_1b"'
    elif criterion == 'ephys':
        restriction = 'training_status LIKE "ready%"'
    else:
        raise ValueError('criterion must be "trained", "biased" or "ephys"')

    from ibl_pipeline import subject, acquisition
    from ibl_pipeline.analyses import behavior as behavior_analysis

    # Query all included subjects
    if force_cutoff is True:
        use_subjects = query_subjects(criterion=criterion).proj('subject_uuid')
    else:
        use_subjects = query_subjects().proj('subject_uuid')

    # Query per subject the date at which the criterion is reached
    sessions = acquisition.Session * behavior_analysis.SessionTrainingStatus
    subj_crit = (subject.Subject * use_subjects).aggr(
        sessions & restriction, 'subject_nickname', date_criterion='min(date(session_start_time))')

    # Query the training day at which criterion is reached
    subj_crit_day = (dj.U('subject_uuid', 'day_of_crit')
                     & (behavior_analysis.BehavioralSummaryByDate * subj_crit
                        & 'session_date=date_criterion').proj(day_of_crit='training_day'))

    # Query days around the day at which criterion is reached
    days = (behavior_analysis.BehavioralSummaryByDate * subject.Subject * subj_crit_day
            & ('training_day - day_of_crit between %d and %d'
               % (-days_from_criterion[0], days_from_criterion[1]))).proj(
                   'subject_uuid', 'subject_nickname', 'session_date')

    # Use dates to query sessions
    ses_query = acquisition.Session.aggr(
        days, from_date='min(session_date)', to_date='max(session_date)')
    sessions = (acquisition.Session * ses_query & 'date(session_start_time) >= from_date'
                & 'date(session_start_time) <= to_date')
    # Exclude weird sessions
    sessions = sessions & dj.Not([{'session_uuid': u_id} for u_id in EXCLUDED_SESSIONS])

    # Transform to pandas dataframe if necessary
    if as_dataframe is True:
        sessions = sessions.fetch(format='frame').reset_index()
        days = days.fetch(format='frame').reset_index()

    return sessions, days

# ================================================================== #
# ================================================================== #

def fit_psychfunc(df):
    """
    Fit a psychometric function (erf with two lapse rates) to the trials in df

    df:         dataframe with one row per trial and the columns signed_contrast, choice
                (counted to get the number of trials per contrast) and choice2 (averaged to get
                the fraction of choices per contrast)

    Returns a one-row dataframe with the columns bias, threshold, lapselow, lapsehigh and
    ntrials. The four fit parameters are NaN when there are 4 or fewer unique contrasts.
    """
    # Rows after transposing: signed contrast, number of trials, fraction of choices
    choicedat = df.groupby('signed_contrast').agg(
        {'choice': 'count', 'choice2': 'mean'}).reset_index()
    if len(choicedat) > 4:  # need some minimum number of unique x-values
        pars, _ = psy.mle_fit_psycho(
            choicedat.values.transpose(), P_model='erf_psycho_2gammas',
            parstart=np.array([0, 20., 0.05, 0.05]),
            parmin=np.array([choicedat['signed_contrast'].min(), 5, 0., 0.]),
            parmax=np.array([choicedat['signed_contrast'].max(), 40., 1, 1]))
    else:
        pars = [np.nan, np.nan, np.nan, np.nan]

    df2 = {'bias': pars[0], 'threshold': pars[1],
           'lapselow': pars[2], 'lapsehigh': pars[3]}
    df2 = pd.DataFrame(df2, index=[0])

    # number of trials with a non-missing choice
    df2['ntrials'] = df['choice'].count()

    return df2

def plot_psychometric(x, y, subj, **kwargs):
    """
    Plot a psychometric function fitted to the average over subjects

    x:          signed contrast of every trial (in percent)
    y:          choice of every trial; its mean per contrast is plotted as the fraction of
                choices, on a y-axis that runs from 0 to 1 and is labelled 0 to 100
    subj:       subject nickname of every trial
    kwargs:     passed on to seaborn.lineplot

    When 0% contrast is present the x-axis is 'broken': the fitted curve is drawn between -29
    and 29 and the +-100% contrasts are drawn at x = +-35. Returns nothing.
    """
    # summary stats - average psychfunc over observers
    df = pd.DataFrame({'signed_contrast': x, 'choice': y,
                       'choice2': y, 'subject_nickname': subj})
    df2 = df.groupby(['signed_contrast', 'subject_nickname']).agg(
        {'choice2': 'count', 'choice': 'mean'}).reset_index()
    df2.rename(columns={"choice2": "ntrials",
                        "choice": "fraction"}, inplace=True)
    # Average the numeric columns only; the nickname strings cannot be averaged
    df2 = df2.groupby('signed_contrast')[['ntrials', 'fraction']].mean().reset_index()
    df2 = df2[['signed_contrast', 'ntrials', 'fraction']]

    # only 'break' the x-axis and remove 50% contrast when 0% is present
    broken_x_axis = 0. in df2.signed_contrast.values

    # fit psychfunc, data rows are: signed contrast, number of trials, fraction of choices
    pars, _ = psy.mle_fit_psycho(
        df2.transpose().values,  # extract the data from the df
        P_model='erf_psycho_2gammas',
        parstart=np.array([0, 20., 0.05, 0.05]),
        parmin=np.array([df2['signed_contrast'].min(), 5, 0., 0.]),
        parmax=np.array([df2['signed_contrast'].max(), 40., 1, 1]))

    if broken_x_axis:
        # plot psychfunc
        g = sns.lineplot(x=np.arange(-29, 29),
                         y=psy.erf_psycho_2gammas(pars, np.arange(-29, 29)), **kwargs)

        # plot psychfunc: -100, +100 (evaluated around +-100 but drawn around +-35)
        sns.lineplot(x=np.arange(-37, -32),
                     y=psy.erf_psycho_2gammas(pars, np.arange(-103, -98)), **kwargs)
        sns.lineplot(x=np.arange(32, 37),
                     y=psy.erf_psycho_2gammas(pars, np.arange(98, 103)), **kwargs)

        # now break the x-axis
        df['signed_contrast'] = df['signed_contrast'].replace(-100, -35)
        df['signed_contrast'] = df['signed_contrast'].replace(100, 35)
    else:
        # plot psychfunc
        g = sns.lineplot(x=np.arange(-103, 103),
                         y=psy.erf_psycho_2gammas(pars, np.arange(-103, 103)), **kwargs)

    df3 = df.groupby(['signed_contrast', 'subject_nickname']).agg(
        {'choice2': 'count', 'choice': 'mean'}).reset_index()

    # plot datapoints with errorbars on top
    if df['subject_nickname'].nunique() > 1:
        # put the kwargs into a merged dict, so that overriding does not cause an error
        # NOTE(review): 'err_style' was reconstructed from the comment above - confirm
        marker_style = {'err_style': 'bars', 'linewidth': 0, 'linestyle': 'None',
                        'mew': 0.5, 'marker': 'o', 'ci': 68}
        sns.lineplot(x=df3['signed_contrast'], y=df3['choice'],
                     **{**marker_style, **kwargs})

    if broken_x_axis:
        g.set_xticks([-35, -25, -12.5, 0, 12.5, 25, 35])
        g.set_xticklabels(['-100', '-25', '-12.5', '0', '12.5', '25', '100'],
                          size='small', rotation=60)
        g.set_xlim([-40, 40])
    else:
        g.set_xticks([-100, -50, 0, 50, 100])
        g.set_xticklabels(['-100', '-50', '0', '50', '100'],
                          size='small', rotation=60)
        g.set_xlim([-110, 110])

    g.set_ylim([0, 1.02])
    g.set_yticks([0, 0.25, 0.5, 0.75, 1])
    g.set_yticklabels(['0', '25', '50', '75', '100'])

def plot_chronometric(x, y, subj, **kwargs):
    """
    Plot a chronometric function: response time as a function of signed contrast

    x:          signed contrast of every trial (in percent)
    y:          response time of every trial, trials with a NaN are ignored
    subj:       subject nickname of every trial
    kwargs:     passed on to seaborn.lineplot

    The median response time per subject and contrast is plotted. The +-100% contrasts are
    drawn at x = +-35 and the +-50% contrasts are left out. Returns nothing.
    """
    df = pd.DataFrame(
        {'signed_contrast': x, 'rt': y, 'subject_nickname': subj})
    df.dropna(inplace=True)  # ignore NaN RTs
    df2 = df.groupby(['signed_contrast', 'subject_nickname']
                     ).agg({'rt': 'median'}).reset_index()
    # explicit copy so that the assignment below does not write to a view
    df2 = df2[['signed_contrast', 'rt', 'subject_nickname']].copy()

    # break the x-axis: draw +-100% contrast at +-35
    df2['signed_contrast'] = df2['signed_contrast'].replace({-100: -35, 100: 35})
    df2 = df2.loc[np.abs(df2.signed_contrast) != 50, :]  # remove those

    ax = sns.lineplot(x='signed_contrast', y='rt', err_style="bars", mew=0.5,
                      ci=68, data=df2, **kwargs)

    # all the points
    if df['subject_nickname'].nunique() > 1:
        # NOTE(review): the body of this branch was missing from the file. It is reconstructed
        # as a marker overlay with the style values that plot_psychometric uses for its
        # datapoints - confirm against the intended figure.
        sns.lineplot(x='signed_contrast', y='rt', data=df2,
                     **{**{'err_style': 'bars', 'linewidth': 0, 'linestyle': 'None',
                           'mew': 0.5, 'marker': 'o', 'ci': 68}, **kwargs})

    ax.set_xticks([-35, -25, -12.5, 0, 12.5, 25, 35])
    ax.set_xticklabels(['-100', '-25', '-12.5', '0', '12.5', '25', '100'],
                       size='small', rotation=45)
    ax.set_xlim([-40, 40])

    # only non-negative contrasts present: show the right half of the axis only
    if df['signed_contrast'].min() >= 0:
        ax.set_xlim([-5, 40])
        ax.set_xticks([0, 6, 12.5, 25, 35])
        ax.set_xticklabels(['0', '6.25', '12.5', '25', '100'],
                           size='small', rotation=45)

def add_n(x, y, sj, **kwargs):
    """
    Write the number of mice and the number of trials as text in the current axes

    x:          signed contrast of every trial
    y:          choice of every trial, trials with a missing choice are not counted
    sj:         subject nickname of every trial
    kwargs:     accepted but not used
    """
    df = pd.DataFrame({'signed_contrast': x, 'choice': y,
                       'choice2': y, 'subject_nickname': sj})

    # NOTE(review): only the format string survived in the file. The text position (in data
    # coordinates) and the font settings are reconstructed - confirm against the figures.
    plt.text(
        15,
        0.2,
        '%d mice, %d trials' %
        (df.subject_nickname.nunique(),
         df.choice.count()),
        fontweight='normal',
        fontsize=6,
        color='k')

def dj2pandas(behav):
    """
    Add the derived columns that the analyses use to a dataframe of trials

    behav:      dataframe with one row per trial and the columns trial_id,
                trial_stim_contrast_left, trial_stim_contrast_right, trial_response_choice
                ('CCW', 'CW' or 'No Go'), trial_stim_prob_left and trial_feedback_type.
                If trial_response_time is present, trial_stim_on_time is needed as well.

    The dataframe is modified in place and also returned. trial_stim_prob_left is renamed to
    probabilityLeft and converted to an integer percentage. Added columns: signed_contrast
    (right minus left, in percent), trial, choice (1 CCW, 0 no go, -1 CW), correct,
    choice_right, choice2, correct_easy, rt, previous_choice, previous_outcome,
    previous_contrast, previous_choice_name, previous_outcome_name and repeat.

    The previous_* columns are a plain shift by one row, so rows are presumably expected to
    be in trial order - NOTE(review): the shift also crosses session boundaries.
    """
    # make sure all contrasts are positive
    behav['trial_stim_contrast_right'] = np.abs(
        behav['trial_stim_contrast_right'])
    behav['trial_stim_contrast_left'] = np.abs(
        behav['trial_stim_contrast_left'])

    behav['signed_contrast'] = (
        behav['trial_stim_contrast_right'] - behav['trial_stim_contrast_left']) * 100

    behav['trial'] = behav['trial_id']  # for psychfuncfit
    val_map = {'CCW': 1, 'No Go': 0, 'CW': -1}
    behav['choice'] = behav['trial_response_choice'].map(val_map)

    # correct (1) when the choice has the same sign as the contrast, otherwise 0;
    # undefined at 0% contrast. Kept as float so that NaN fits without a dtype change.
    sign_matches = np.sign(behav['signed_contrast']) == behav['choice']
    behav['correct'] = sign_matches.astype(float).mask(behav['signed_contrast'] == 0)

    # code as 0 (choice -1) or 1 (choice 1), no go trials become NaN
    behav['choice_right'] = behav['choice'].map({-1: 0, 0: np.nan, 1: 1})
    behav['choice2'] = behav['choice_right']  # for psychfuncfit
    # performance on easy trials only: contrasts of 50% and higher
    behav['correct_easy'] = behav['correct'].mask(np.abs(behav['signed_contrast']) < 50)
    behav.rename(
        columns={'trial_stim_prob_left': 'probabilityLeft'}, inplace=True)
    behav['probabilityLeft'] = behav['probabilityLeft'] * 100
    behav['probabilityLeft'] = behav.probabilityLeft.astype(int)

    # compute rt
    if 'trial_response_time' in behav.columns:
        no_response = behav['choice'] == 0
        # NOTE(review): the subtracted column was missing from the file; stimulus onset
        # (trial_stim_on_time) is assumed - confirm
        rt = behav['trial_response_time'] - behav['trial_stim_on_time']
        # ignore a bunch of things for missed trials
        # don't count RT if there was no response
        behav['rt'] = rt.mask(no_response)
        # don't count the feedback if there was no response
        behav['trial_feedback_type'] = behav['trial_feedback_type'].mask(no_response)

    previous_choice = behav['choice'].shift(1)
    behav['previous_choice'] = previous_choice.mask(previous_choice == 0)
    previous_outcome = behav['trial_feedback_type'].shift(1)
    behav['previous_outcome'] = previous_outcome.mask(previous_outcome == 0)
    behav['previous_contrast'] = np.abs(behav['signed_contrast'].shift(1))
    behav['previous_choice_name'] = behav['previous_choice'].map(
        {-1: 'left', 1: 'right'})
    behav['previous_outcome_name'] = behav['previous_outcome'].map(
        {-1: 'post_error', 1: 'post_correct'})
    behav['repeat'] = (behav['choice'] == behav['previous_choice'])

    return behav

def num_star(pvalue):
    """
    Return the significance label that belongs to a p-value

    pvalue:     p-value to label

    Returns the label of the strictest threshold that the p-value is below, from
    '* p < 0.05' up to '**** p < 0.0001'. An empty string is returned when the p-value is
    not below 0.05 (which includes NaN).
    """
    # Ordered from strictest to most lenient, the first threshold that is met wins
    thresholds = (
        (0.0001, '**** p < 0.0001'),
        (0.001, '*** p < 0.001'),
        (0.01, '** p < 0.01'),
        (0.05, '* p < 0.05'),
    )
    for threshold, stars in thresholds:
        if pvalue < threshold:
            return stars

    # Not significant: previously this case failed because no label had been assigned
    return ''
