#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Decode in which lab a mouse was trained based on its behavioral metrics during the three sessions of the basic task variant in which the mouse reached 80% correct on easy trials. As a positive control, the time zone in which the mouse was trained is included in the dataset since the time zone provides geographical information. Decoding is performed using leave-one-out cross-validation. To control for the imbalance in the dataset (some labs have more mice than others) a fixed number of mice is randomly sub-sampled from each lab. This random sub-sampling is repeated for a large number of iterations. A shuffled null distribution is obtained by shuffling the lab labels and decoding again for each iteration. -------------- Parameters DECODER: Which decoder to use: 'bayes', 'forest', or 'regression' N_MICE: How many mice per lab to randomly sub-sample (must not exceed the number of mice in the lab with the fewest mice) ITERATIONS: Number of times to randomly sub-sample METRICS: List of strings indicating which behavioral metrics to include during decoding of lab membership METRICS_CONTROL: List of strings indicating which metrics to use for the positive control Guido Meijer September 3, 2020 """ import pandas as pd import numpy as np from os.path import join from paper_behavior_functions import (query_session_around_performance, institution_map, QUERY, fit_psychfunc, datapath, load_csv) from sklearn.ensemble import RandomForestClassifier from sklearn.naive_bayes import GaussianNB from sklearn.linear_model import LogisticRegression from sklearn.model_selection import LeaveOneOut from sklearn.metrics import f1_score, confusion_matrix # Parameters DECODER = 'bayes' # bayes, forest or regression N_MICE = 8 # how many mice per lab to sub-sample ITERATIONS = 2000 # how often to decode METRICS = ['perf_easy', 'threshold', 'bias'] METRICS_CONTROL = ['perf_easy', 'threshold', 'bias', 'time_zone'] # Decoding function with leave-one-out cross-validation def 
decoding(data, labels, clf): kf = LeaveOneOut() y_pred = np.empty(len(labels), dtype='