Revision c7ddda11647093e8a0173dbd2a1986ac6239c821 authored by Daniel Millman on 22 October 2020, 20:12:19 UTC, committed by Daniel Millman on 22 October 2020, 20:12:19 UTC
1 parent f0ed474
Raw File
chisq_categorical.py
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Wed Jun  5 15:52:22 2019

@author: dan
"""

import numpy as np

def chisq_from_stim_table(stim_table,
                          columns,
                          mean_sweep_events,
                          num_shuffles=1000,
                          verbose=False):
    #  stim_table is a pandas DataFrame with len = num_sweeps
    #  columns is a list of column names that define the categories (e.g. ['Ori','Contrast'])
    #  mean_sweep_events is a numpy array with shape (num_sweeps,num_cells)
    
    sweep_categories = stim_table_to_categories(stim_table,columns,verbose=verbose)
    p_vals = compute_chi_shuffle(mean_sweep_events,sweep_categories,num_shuffles=num_shuffles)
    
    return p_vals

def compute_chi_shuffle(mean_sweep_events,
                        sweep_categories,
                        num_shuffles=1000):

    #  mean_sweep_events is a numpy array with shape (num_sweeps,num_cells)
    #  sweep_conditions is a numpy array with shape (num_sweeps)
    #       sweep_conditions gives the category label for each sweep
    
    (num_sweeps,num_cells) = np.shape(mean_sweep_events) 
    
    assert len(sweep_categories) == num_sweeps
    
    sweep_categories_dummy = make_category_dummy(sweep_categories)
    
    expected = compute_expected(mean_sweep_events,sweep_categories_dummy)
    observed = compute_observed(mean_sweep_events,sweep_categories_dummy)
    chi_actual = compute_chi(observed,expected)
    
    chi_shuffle = np.zeros((num_cells,num_shuffles))
    for ns in range(num_shuffles):
        #print 'shuffle ' + str(ns+1) + ' of ' + str(num_shuffles)
        
        shuffle_sweeps = np.random.choice(num_sweeps,size=(num_sweeps,))
        shuffle_sweep_events = mean_sweep_events[shuffle_sweeps]
        
        shuffle_expected = compute_expected(shuffle_sweep_events,sweep_categories_dummy)
        shuffle_observed = compute_observed(shuffle_sweep_events,sweep_categories_dummy)
        
        chi_shuffle[:,ns] = compute_chi(shuffle_observed,shuffle_expected)
    
    p_vals = np.mean(chi_actual.reshape(num_cells,1)<chi_shuffle,axis=1)
    
    return p_vals

def compute_chi_from_blanks(mean_sweep_events,
                        sweep_categories,
                        num_shuffles=1000):
    #  mean_sweep_events is a numpy array with shape (num_sweeps,num_cells)
    #  sweep_conditions is a numpy array with shape (num_sweeps)
    #       sweep_conditions gives the category label for each sweep
    
    (num_sweeps,num_cells) = np.shape(mean_sweep_events) 
    
    assert len(sweep_categories) == num_sweeps
    
    sweep_categories_dummy = make_category_dummy(sweep_categories)
    
    expected = compute_expected_for_blank(mean_sweep_events,sweep_categories_dummy)
    observed = compute_observed(mean_sweep_events,sweep_categories_dummy)
    chi_pos_actual, chi_neg_actual = compute_chi_rectified(observed,expected)
    
    chi_pos_shuffle = np.zeros((num_cells,num_shuffles))
    chi_neg_shuffle = np.zeros((num_cells,num_shuffles))
    for ns in range(num_shuffles):
        #print 'shuffle ' + str(ns+1) + ' of ' + str(num_shuffles)
        
        shuffle_sweeps = np.random.choice(num_sweeps,size=(num_sweeps,))
        shuffle_sweep_events = mean_sweep_events[shuffle_sweeps]
        
        shuffle_expected = compute_expected_for_blank(shuffle_sweep_events,sweep_categories_dummy)
        shuffle_observed = compute_observed(shuffle_sweep_events,sweep_categories_dummy)
        
        chi_pos, chi_neg = compute_chi_rectified(shuffle_observed,shuffle_expected)
        chi_pos_shuffle[:,ns] = chi_pos
        chi_neg_shuffle[:,ns] = chi_neg
    
    p_vals_pos = np.mean(chi_pos_actual.reshape(num_cells,1)<chi_pos_shuffle,axis=1)
    p_vals_neg = np.mean(chi_neg_actual.reshape(num_cells,1)<chi_neg_shuffle,axis=1)
    
    return p_vals_pos, p_vals_neg

def stim_table_to_categories(stim_table,
                             columns,
                             verbose=False):
    # get the categories for all sweeps with each unique combination of 
    #   parameters in 'columns' being one category
    # sweeps with non-finite values in ANY column (e.g. np.NaN) are labeled 
    #   as blank sweeps (category = -1)
    
    num_sweeps = len(stim_table)
    num_params = len(columns)
    
    unique_params = []
    options_per_column = []
    max_combination = 1
    for column in columns:
        column_params = np.unique(np.array(stim_table[column].values))
        #column_params = column_params[np.isfinite(column_params)]
        unique_params.append(column_params)
        options_per_column.append(len(column_params))
        max_combination*=len(column_params)

    category = 0
    sweep_categories = -1*np.ones((num_sweeps,))
    curr_combination = np.zeros((num_params,),dtype=np.int)
    options_per_column = np.array(options_per_column).astype(np.int)
    all_tried = False
    while not all_tried:
        
        matches_combination = np.ones((num_sweeps,),dtype=np.bool)
        for i_col,column in enumerate(columns):
            param = unique_params[i_col][curr_combination[i_col]]
            matches_param = np.array(stim_table[column].values) == param
            matches_combination *= matches_param
            
        if np.any(matches_combination):
            sweep_categories[matches_combination] = category
            if verbose:
                print('Category ' + str(category))
                for i_col,column in enumerate(columns):
                    param = unique_params[i_col][curr_combination[i_col]]
                    print(column + ': ' + str(param))
            
            category+=1
              
        #advance the combination
        curr_combination = advance_combination(curr_combination,options_per_column)
        all_tried = curr_combination[0]==options_per_column[0]
    
    if verbose:    
        blank_sweeps = sweep_categories==-1
        print('num blank: ' + str(blank_sweeps.sum()))
        
    return sweep_categories
    
def advance_combination(curr_combination,
                        options_per_column):
    
    num_cols = len(curr_combination)
    
    might_carry = True
    col = num_cols-1
    while might_carry:
        curr_combination[col] += 1
        if col==0 or curr_combination[col]<options_per_column[col]:
            might_carry = False
        else:
            curr_combination[col] = 0
            col-=1
            
    return curr_combination
    

def make_category_dummy(sweep_categories):
    #makes a dummy variable version of the sweep category list
    
    num_sweeps = len(sweep_categories)
    categories = np.sort(np.unique(sweep_categories))
    num_categories = len(categories)
    
    sweep_category_mat = np.zeros((num_sweeps,num_categories),dtype=np.bool)
    for i_cat,category in enumerate(categories):
        category_idx = np.argwhere(sweep_categories==category)[:,0]
        sweep_category_mat[category_idx,i_cat] = True
    
    return sweep_category_mat

def compute_observed(mean_sweep_events,sweep_conditions):

    (num_sweeps,num_conditions) = np.shape(sweep_conditions)
    num_cells = np.shape(mean_sweep_events)[1]   
    
    observed_mat = (mean_sweep_events.T).reshape(num_cells,num_sweeps,1) * sweep_conditions.reshape(1,num_sweeps,num_conditions)
    observed = np.sum(observed_mat,axis=1)
    
    return observed
    
def compute_expected(mean_sweep_events,sweep_conditions):   
    
    num_conditions = np.shape(sweep_conditions)[1]
    num_cells = np.shape(mean_sweep_events)[1]
    
    sweeps_per_condition = np.sum(sweep_conditions,axis=0)
    events_per_sweep = np.mean(mean_sweep_events,axis=0)
    
    expected = sweeps_per_condition.reshape(1,num_conditions) * events_per_sweep.reshape(num_cells,1) 
    
    return expected

def compute_expected_for_blank(mean_sweep_events,sweep_conditions):   
    #assumes the blank sweeps are assigned to the first category!
    
    blank_sweeps = sweep_conditions[:,0]
    
    num_conditions = np.shape(sweep_conditions)[1]
    num_cells = np.shape(mean_sweep_events)[1]
    
    sweeps_per_condition = np.sum(sweep_conditions,axis=0)
    events_per_blank_sweep = np.mean(mean_sweep_events[blank_sweeps],axis=0)
    
    expected = sweeps_per_condition.reshape(1,num_conditions) * events_per_blank_sweep.reshape(num_cells,1) 
    
    return expected

def compute_chi(observed,expected):

    chi = (observed - expected) ** 2 /expected
    chi = np.where(expected>0,chi,0.0)  
    return np.sum(chi,axis=1)

def compute_chi_rectified(observed,expected):
    
    chi = (observed - expected) ** 2 /expected
    chi = np.where(expected>0,chi,0.0) 
    
    chi_pos = np.where(observed>expected,chi,0.0).sum(axis=1)
    chi_neg = np.where(observed<expected,chi,0.0).sum(axis=1)
    
    return chi_pos, chi_neg
back to top