# utils.py
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
from scipy.stats import norm
import csv

import params


def build_networks(TE):
    """Build one directed graph per (run, day) from the transfer entropy array.

    TE has shape (run, day, source electrode, target electrode); entries of -1
    mark dead electrodes and entries of 0 mark insignificant TE values.
    """
    list_of_networks = []

    for run_index in range(params.NUM_RUNS):

        this_runs_networks = []
        for day_index in range(params.MAX_NUM_DAYS):

            # Skip padded day indices beyond the days recorded for this run
            if day_index >= len(params.DAYS[run_index]):
                continue

            G = nx.DiGraph()

            # Add a node for each of the 60 electrodes, positioned by its
            # physical location on the array; electrode 14 (presumably the
            # reference/ground channel) is excluded
            for i in range(60):
                if i != 14:
                    position = np.argwhere(params.ELECTRODE_POSITIONS == i)[0]
                    G.add_node(i, pos = [params.POS_DELTA * position[1], 1 - params.POS_DELTA * position[0]])

            # Add an edge for every significant (non-zero, non-sentinel) TE value
            for i in range(TE.shape[2]):
                for j in range(TE.shape[3]):
                    if TE[run_index, day_index, i, j] != 0 and TE[run_index, day_index, i, j] != -1 and i != 14 and j != 14:
                        G.add_edge(i, j, TE = TE[run_index, day_index, i, j], zorder = TE[run_index, day_index, i, j])

            this_runs_networks.append(G)

        list_of_networks.append(this_runs_networks)

    return list_of_networks
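
# A hedged sketch (not part of the original module) of how one of the returned
# networks might be drawn, using the 'pos' node attribute set above. The helper
# name and the styling parameters are hypothetical.
def draw_te_network(G):
    pos = nx.get_node_attributes(G, 'pos')            # physical electrode layout
    te_values = [G[u][v]['TE'] for u, v in G.edges()]
    # Scale edge widths by TE, normalised so the strongest edge has width 1
    widths = np.array(te_values) / max(te_values) if te_values else 1.0
    nx.draw(G, pos = pos, node_size = 30, width = widths, arrowsize = 5)
    plt.show()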

def remove_dead_electrodes(TE, surrogates):
    # Mark every TE and surrogate value involving a dead electrode (as either
    # source or target) with the -1 sentinel, in place; the scalar assignment
    # broadcasts over the remaining axes
    TE[:, :, params.DEAD_ELECTRODES, :] = -1
    TE[:, :, :, params.DEAD_ELECTRODES] = -1
    surrogates[:, :, params.DEAD_ELECTRODES, :, :] = -1
    surrogates[:, :, :, params.DEAD_ELECTRODES, :] = -1

def set_insignificants_to_zero(TE, surrogates):
    # Test each TE value against a Gaussian fitted to its surrogate
    # distribution and zero out the insignificant ones, in place
    # (a hedged vectorised sketch follows below)
    for run_index in range(params.NUM_RUNS):
        for day_index in range(len(params.DAYS[run_index])):
            for i in range(TE.shape[2]):
                for j in range(TE.shape[3]):
                    # Check that the surrogate values are not all equal (this happens
                    # for e.g. a dead electrode, or when there are not enough spikes)
                    if not np.all(surrogates[run_index, day_index, i, j, :] == surrogates[run_index, day_index, i, j, 0]):
                        surrogate_mean = np.mean(surrogates[run_index, day_index, i, j, :])
                        surrogate_std = np.std(surrogates[run_index, day_index, i, j, :])
                        # One-sided p-value of the observed TE under the fitted Gaussian
                        p = 1 - norm.cdf(TE[run_index, day_index, i, j], loc = surrogate_mean, scale = surrogate_std)
                        if p > params.P_CUTOFF:
                            TE[run_index, day_index, i, j] = 0
                            surrogates[run_index, day_index, i, j, :] = np.zeros(params.NUM_SURROGATES)
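
# A hedged sketch (not part of the original module) of a vectorised alternative
# to set_insignificants_to_zero; the function name is hypothetical. The std > 0
# guard plays the role of the all-equal check above, and also skips padded
# days, assuming their surrogate entries are constant sentinels.
def set_insignificants_to_zero_vectorised(TE, surrogates):
    surrogate_mean = np.mean(surrogates, axis = -1)
    surrogate_std = np.std(surrogates, axis = -1)
    valid = surrogate_std > 0
    p = np.ones(TE.shape)
    p[valid] = 1 - norm.cdf(TE[valid], loc = surrogate_mean[valid], scale = surrogate_std[valid])
    insignificant = valid & (p > params.P_CUTOFF)
    TE[insignificant] = 0
    surrogates[insignificant, :] = 0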

def bias_correct(TE, surrogates):
    # Subtract the mean of each surrogate distribution from its TE value,
    # in place, skipping the -1 dead-electrode sentinels (see the toy
    # indexing example below)
    TE[TE != -1] -= np.mean(surrogates[TE != -1, :], axis = 1)
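
# A self-contained toy example (not part of the original module) of the
# indexing used above: a boolean mask spanning an array's leading axes flattens
# them, giving a 2-D (selected entries x surrogates) view whose mean over
# axis 1 lines up with TE[TE != -1]. Shown here in 2-D/3-D miniature.
def _demo_masked_surrogate_indexing():
    te = np.array([[1.0, -1.0], [2.0, 3.0]])
    surr = np.arange(8.0).reshape(2, 2, 2)   # two surrogates per entry
    selected = surr[te != -1, :]             # shape (3, 2): one row per non-sentinel entry
    te[te != -1] -= np.mean(selected, axis = 1)
    return te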

# Consider vectorising this (a hedged sketch of one possibility follows below)
def normalise_TE_by_rate(TE, spike_rates):
    # Divide each TE value by the firing rate of its target electrode j;
    # entries whose target rate is the -1 sentinel are left at zero
    rate_normalised_TE = np.zeros(TE.shape)
    for run_index in range(params.NUM_RUNS):
        for day_index in range(len(params.DAYS[run_index])):
            for j in range(params.UPPER_ELECTRODE_INDEX + 1):
                if spike_rates[run_index, day_index, j] != -1:
                    for i in range(params.UPPER_ELECTRODE_INDEX + 1):
                        rate_normalised_TE[run_index, day_index, i, j] = TE[run_index, day_index, i, j] / spike_rates[run_index, day_index, j]

    return rate_normalised_TE
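
# A hedged sketch (not part of the original module) of the vectorisation
# suggested above; the function name is hypothetical. Broadcasting the rates
# over the source axis i should match the loop, with np.where discarding the
# values computed at -1 sentinel rates.
def normalise_TE_by_rate_vectorised(TE, spike_rates):
    rates = spike_rates[:, :, np.newaxis, :]   # broadcast over the source axis
    return np.where(rates != -1, TE / rates, 0.0)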

def get_spike_rates():
    spike_rates = -1 * np.ones((params.NUM_RUNS, params.MAX_NUM_DAYS, params.UPPER_ELECTRODE_INDEX + 1))
    for run_index in range(params.NUM_RUNS):
        for day_index in range(len(params.DAYS[run_index])):
            file_name = (params.SPIKE_TRAINS_FOLDER + params.RUNS[run_index] + "/" +
                         params.ORIGINAL_SPIKE_TRAIN_FILES[run_index][day_index])
            with open(file_name, 'r') as f:
                spike_trains = list(csv.reader(f, delimiter = ','))
            for i in range(params.UPPER_ELECTRODE_INDEX + 1):
                # Only estimate a rate for electrodes with more than 100 recorded spikes
                if len(spike_trains[i]) > 100:
                    # Spike count divided by the elapsed recording span, with the 2.5e4
                    # factor converting timestamps (presumably samples at 25 kHz) to
                    # seconds; the [-2] index skips what appears to be a trailing empty field
                    spike_rates[run_index, day_index, i] = ((2.5e4 * (len(spike_trains[i]) - 1)) /
                                                            (float(spike_trains[i][-2]) - float(spike_trains[i][0])))

    return spike_rates
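
# A worked toy example (hypothetical numbers, not part of the original module)
# of the rate formula above: five spikes whose first and last timestamps are
# 100000 samples apart at 25 kHz span 4 seconds, giving 4 / 4 = 1 Hz.
def _demo_spike_rate():
    timestamps = [0.0, 20000.0, 50000.0, 80000.0, 100000.0]
    rate = (2.5e4 * (len(timestamps) - 1)) / (timestamps[-1] - timestamps[0])
    return rate   # 1.0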

def get_mean_burst_positions():
    mean_burst_positions = -1 * np.ones((params.NUM_RUNS, params.MAX_NUM_DAYS, params.UPPER_ELECTRODE_INDEX + 1))
    std_dev_burst_positions = -1 * np.ones((params.NUM_RUNS, params.MAX_NUM_DAYS, params.UPPER_ELECTRODE_INDEX + 1))
    mean_burst_times = -1 * np.ones((params.NUM_RUNS, params.MAX_NUM_DAYS, params.UPPER_ELECTRODE_INDEX + 1))
    mean_burst_precedences = -1 * np.ones((params.NUM_RUNS, params.MAX_NUM_DAYS, params.UPPER_ELECTRODE_INDEX + 1, params.UPPER_ELECTRODE_INDEX + 1))
    for run_index in range(params.NUM_RUNS):
        for day_index in range(len(params.DAYS[run_index])):
            file_name_mean_pos = (params.BURST_RECORDS_FOLDER + params.RUNS[run_index] + "/" +
                                  params.ORIGINAL_SPIKE_TRAIN_FILES[run_index][day_index] + ".mean_positions")
            #file_name_std_dev_pos = (params.BURST_RECORDS_FOLDER + params.RUNS[run_index] + "/" +
            #                         params.ORIGINAL_SPIKE_TRAIN_FILES[run_index][day_index] + ".std_dev_positions")
            #file_name_times = (params.BURST_RECORDS_FOLDER + params.RUNS[run_index] + "/" +
            #                   params.ORIGINAL_SPIKE_TRAIN_FILES[run_index][day_index] + ".mean_start_times")
            file_name_prec = (params.BURST_RECORDS_FOLDER + params.RUNS[run_index] + "/" +
                              params.ORIGINAL_SPIKE_TRAIN_FILES[run_index][day_index] + ".mean_precedences")
            # Each positions file holds one value per electrode, one per row
            for (arr, file_name) in [
                                     (mean_burst_positions, file_name_mean_pos),
                                     #(std_dev_burst_positions, file_name_std_dev_pos),
                                     #(mean_burst_times, file_name_times),
                                     ]:
                with open(file_name, 'r') as f:
                    arr[run_index, day_index, :] = [float(row[0]) for row in csv.reader(f, delimiter = ',')]
            # The precedence file holds a full electrode-by-electrode matrix
            with open(file_name_prec, 'r') as f:
                for i, row in enumerate(csv.reader(f, delimiter = ',')):
                    for j, item in enumerate(row):
                        mean_burst_precedences[run_index, day_index, i, j] = float(item)

    # Compute the sentinel mask before wrapping the arrays, so that it is
    # taken from the plain data rather than from an already-masked array
    missing = mean_burst_positions == -1
    mean_burst_positions = np.ma.array(mean_burst_positions, mask = missing)
    std_dev_burst_positions = np.ma.array(std_dev_burst_positions, mask = missing)
    mean_burst_times = np.ma.array(mean_burst_times, mask = missing)
    mean_burst_precedences = np.ma.array(mean_burst_precedences, mask = mean_burst_precedences == -1)

    return mean_burst_positions, std_dev_burst_positions, mean_burst_times, mean_burst_precedences
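
# A hedged usage sketch (not part of the original module): because the returned
# arrays are masked, summary statistics automatically ignore the -1 sentinel
# entries, e.g. when averaging burst positions across runs.
def _demo_masked_mean(mean_burst_positions):
    return mean_burst_positions.mean(axis = 0)   # masked entries are excluded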

def make_basic(multicolumn, run_group_index, is_burst_plot = False, left_pos = 0.01, special_axes_label_size = 0):
    # Set up the figure and axes grid shared by the plotting routines:
    # one row per culture in the run group, one column per recording day

    plt.rc('font', size = params.LABEL_SIZES[run_group_index])
    plt.rc('xtick', labelsize = params.TICK_SIZES[run_group_index])
    plt.rc('ytick', labelsize = params.TICK_SIZES[run_group_index])
    plt.rc('axes', titlesize = params.TITLE_SIZES[run_group_index])
    if special_axes_label_size == 0:
        plt.rc('axes', labelsize = params.LABEL_SIZES[run_group_index])
    else:
        plt.rc('axes', labelsize = special_axes_label_size)

    num_rows = len(params.RUN_GROUPS[run_group_index])
    if run_group_index == 2:
        num_rows = 2

    num_cols = 1
    if is_burst_plot:
        num_cols = params.MAX_NUM_DAYS_BY_GROUP[run_group_index] - params.BURST_PLOT_KILLED_COLS[run_group_index]
    elif multicolumn:
        num_cols = params.MAX_NUM_DAYS_BY_GROUP[run_group_index]

    #if (not multicolumn) and (run_group_index == 2):
    if run_group_index == 2:
        # Run group 2 gets a second, zero-height row, presumably so that its
        # single row of axes lines up with the other run groups' layout
        fig, axs = plt.subplots(num_rows, num_cols, figsize = (params.FIG_WIDTH, params.FIG_HEIGHTS[run_group_index]),
                                gridspec_kw = {'height_ratios': [1, 0]})
    else:
        fig, axs = plt.subplots(num_rows, num_cols, figsize = (params.FIG_WIDTH, params.FIG_HEIGHTS[run_group_index]))

    plt.subplots_adjust(hspace = params.H_SPACES[run_group_index], wspace = params.W_SPACES[run_group_index], left = 0.2)

    if run_group_index != 2:
        # Label each row with the name of its culture
        for i in range(len(params.RUN_GROUPS[run_group_index])):
            fig.text(left_pos, 1 - params.CULTURE_LABELS_TOP_GAPS[run_group_index] - params.CULTURE_LABELS_GAPS[run_group_index] * i,
                     "culture " + params.RUNS[params.RUN_GROUPS[run_group_index][i]], fontdict = {'size' : params.CULTURE_LABELS_SIZES[run_group_index]})

    return fig, axs
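
# A hedged usage sketch (not part of the original module); the run group index
# and output file name here are hypothetical.
def _demo_make_basic():
    fig, axs = make_basic(multicolumn = True, run_group_index = 0)
    # ... the plotting routines fill in axs[row, col] here ...
    fig.savefig("example_grid.pdf")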