swh:1:snp:9afa324ac020f1c169cae2c8f220a29e1ff11375
Raw File
Tip revision: 8ee5e519da5cb90590865e9a692b96ad7e68a69e authored by David Shorten on 06 February 2022, 07:07:59 UTC
various
Tip revision: 8ee5e51
pairplots.py
import numpy as np
import h5py
import seaborn as sns
import os
import matplotlib.pyplot as plt
#import matplotlib.colors as clr
from matplotlib.offsetbox import AnchoredOffsetbox, TextArea, HPacker, VPacker
from matplotlib import transforms
import pandas as pd
import statsmodels.api as sm
import scipy.stats
import params

# Marker area passed to Axes.scatter (its `s` argument).
MARKER_SIZE = 3
# NOTE(review): appears unused in this module.
MARKER_LINE_WIDTH = 4
# The first two entries of seaborn's colour-blind palette:
# scatter markers use the first, regression lines the second.
COLOUR_MARKERS, COLOUR_LINES = sns.color_palette(palette = 'colorblind')[:2]

def _significance_stars(p_value, correction = 1):
    """Return the significance marker for `p_value`.

    '**' if p < 0.01 / correction, '*' if p < 0.05 / correction, otherwise ''.
    A NaN p-value fails both comparisons and therefore yields ''.
    """
    if p_value < 0.01 / correction:
        return "**"
    if p_value < 0.05 / correction:
        return "*"
    return ""


def _padded_limits(values, fraction = 0.05):
    """Return [low, high] spanning `values` plus `fraction` of their range on each side.

    Returns None when all values are identical: a zero-width range would give
    matplotlib equal lower and upper limits, so autoscaling is left in charge.
    """
    low = np.min(values)
    high = np.max(values)
    span = high - low
    if span == 0:
        return None
    return [low - fraction * span, high + fraction * span]


def plot_on_axis(axs, x, y, run_index, bonferroni_factor = None):
    """Scatter `y` against `x` on `axs`, add a linear fit and a Spearman-rho label.

    Parameters
    ----------
    axs : matplotlib Axes to draw on.
    x, y : 1-D arrays of equal length, optionally masked. Entries masked in
        either array are dropped from both before anything is drawn.
    run_index : run identifier; only used to choose the default Bonferroni factor.
    bonferroni_factor : divisor applied to the 0.01 / 0.05 thresholds for the
        red (corrected) stars. Defaults to 3 for run 11 and 19 otherwise.

    Nothing is drawn if no entry survives masking. Above the panel, rho with its
    uncorrected stars is written in black, followed by the corrected stars in red.
    """
    # Some means are fully masked (inactive electrodes) and must be removed.
    # getmaskarray always returns a full boolean array, so this also works for
    # plain ndarrays and for masked arrays whose mask is the scalar `nomask`
    # (where `x.mask | y.mask` would be a scalar and corrupt the indexing).
    keep = ~(np.ma.getmaskarray(x) | np.ma.getmaskarray(y))
    x = np.ma.getdata(x)[keep]
    y = np.ma.getdata(y)[keep]

    if x.size == 0:
        return

    axs.scatter(x, y, s = MARKER_SIZE, color = COLOUR_MARKERS)

    # A degree-1 least-squares fit is rank-deficient unless x has two distinct values.
    if np.unique(x).size > 1:
        line_fn = np.poly1d(np.polyfit(x, y, 1))
        # A straight line is fully determined by its two end points.
        x_ends = np.array([np.min(x), np.max(x)])
        axs.plot(x_ends, line_fn(x_ends), color = COLOUR_LINES)

    rho, p_rho = scipy.stats.spearmanr(x, y)

    if bonferroni_factor is None:
        # NOTE(review): presumably the number of comparisons made for each run — confirm.
        bonferroni_factor = 3 if run_index == 11 else 19
    sig_string = _significance_stars(p_rho)
    sig_string_bonf = _significance_stars(p_rho, bonferroni_factor)

    x_limits = _padded_limits(x)
    if x_limits is not None:
        axs.set_xlim(x_limits)
    y_limits = _padded_limits(y)
    if y_limits is not None:
        axs.set_ylim(y_limits)

    # Black rho text (with uncorrected stars) and red corrected stars, packed side
    # by side and centred just above the axes (y = 1.15 in axes coordinates).
    rho_box_1 = TextArea(r'$\rho =' + "{:.2f}".format(rho) + r'$' + sig_string, textprops=dict(color="k", size=12, ha='left', va='top'))
    rho_box_2 = TextArea(sig_string_bonf, textprops=dict(color="r", size=12, ha='left', va='top'))

    ybox = HPacker(children=[rho_box_1, rho_box_2], align="top", pad=0, sep=0.25)

    anchored_ybox = AnchoredOffsetbox(loc = 'upper center', child = ybox, pad = 0, bbox_to_anchor = (0.5, 1.15),
                                      bbox_transform=axs.transAxes, frameon = False)

    axs.add_artist(anchored_ybox)

def _axis_label(day, normal_labels):
    """Return a two-line axis label: the day, then the plotted quantity (TE or ratio)."""
    quantity = r'TE (nats.s$^{-1}$)' if normal_labels else "ratio"
    return "day " + str(day) + "\n" + quantity


def frame_and_plot(data, name, sub_base_folder, run_index, cols, remove_zeros = False, normal_labels = True):
    """Draw day-versus-day scatter panels for one run and save them as a PDF.

    Parameters
    ----------
    data : 2-D array whose column k holds the values for day params.DAYS[run_index][k].
    name : file-name prefix of the saved figure.
    sub_base_folder : sub-folder inserted after params.BASE_FIGURE_FOLDER.
    run_index : index into params.DAYS / params.RUNS; also selects the layout.
    cols, remove_zeros : accepted for interface compatibility; not used here.
    normal_labels : label axes with TE units if True, with "ratio" otherwise.

    Runs 0, 1, 3, 5, 8 and 10 get a single panel. Runs 0, 1 and 3 compare columns
    1 and 2; the others compare columns 0 and 1. Every other run gets a
    triangular grid comparing each pair of days from `start_index` onwards; the
    unused bottom-right cells of the grid are hidden. The figure is saved to
    BASE_FIGURE_FOLDER + sub_base_folder + CORRELATIONS_OVER_TIME_FIGURE_FOLDER +
    name + RUNS[run_index] + ".pdf" and then closed.
    """
    # Panel spacing is left at matplotlib's default. A former
    # plt.subplots_adjust(wspace=0.1, hspace=0.1) call ran before the figure
    # below existed, so it never affected the saved figure.
    # NOTE(review): the rho labels plot_on_axis anchors 0.15 axes-heights above
    # each panel probably need the default vertical gap — check before tightening.
    plt.rc('xtick', labelsize=12)
    plt.rc('ytick', labelsize=12)
    plt.rc('axes', linewidth=1.5)
    plt.rc('axes', labelsize=12)

    days = params.DAYS[run_index]

    if run_index in [0, 1, 3, 5, 8, 10]:
        fig, axs = plt.subplots(1, 1, figsize = (2, 2))
        first = 1 if run_index in [0, 1, 3] else 0
        plot_on_axis(axs, data[:, first], data[:, first + 1], run_index)
        axs.set_xlabel(_axis_label(days[first], normal_labels))
        axs.set_ylabel(_axis_label(days[first + 1], normal_labels))
    else:
        # Runs 2, 6 and 9 leave their first day out of the grid.
        start_index = 1 if run_index in [2, 6, 9] else 0
        num_days = data.shape[1]
        subplot_grid_size = num_days - start_index - 1
        # squeeze=False keeps `axs` 2-D even for a 1x1 grid, so [row][col] always works.
        fig, axs = plt.subplots(subplot_grid_size, subplot_grid_size,
                                figsize = (2 * subplot_grid_size, 2 * subplot_grid_size),
                                squeeze = False)
        for i in range(start_index, num_days - 1):
            col = i - start_index
            for j in range(i + 1, num_days):
                # Later days sit higher: the last day is in row 0.
                row = subplot_grid_size - j + start_index
                ax = axs[row][col]
                plot_on_axis(ax, data[:, i], data[:, j], run_index)

                # Keep x ticks only on the bottom panel of each column and
                # y ticks only on the first column.
                if j != i + 1:
                    ax.get_xaxis().set_ticks([])
                if i != start_index:
                    ax.get_yaxis().set_ticks([])

        # Hide the unused cells (row + col >= grid size).
        for row in range(subplot_grid_size):
            for k in range(row):
                axs[row][subplot_grid_size - k - 1].axis('off')

        for k in range(subplot_grid_size):
            axs[subplot_grid_size - 1 - k][k].set_xlabel(_axis_label(days[start_index + k], normal_labels))
        for row in range(subplot_grid_size):
            axs[row][0].set_ylabel(_axis_label(days[subplot_grid_size - row + start_index], normal_labels))

    fig.savefig(params.BASE_FIGURE_FOLDER + sub_base_folder + params.CORRELATIONS_OVER_TIME_FIGURE_FOLDER + name + params.RUNS[run_index] + ".pdf", bbox_inches = 'tight')

    # Close (not just clear) the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

def make_pairplots(TE, burst_positions, sub_base_folder, remove_zeros):
    """Save ratio, out-TE, in-TE and edge-TE pair plots for every run.

    Parameters
    ----------
    TE : array indexed as TE[run, day, a, b]. Entries equal to -1 are masked
        out before averaging. NOTE(review): the out/in names assume `a` is the
        source and `b` the target electrode — confirm.
    burst_positions : not used.
    sub_base_folder : figure sub-folder passed through to frame_and_plot.
    remove_zeros : passed through to frame_and_plot.

    Runs 8 and 10 are skipped. Runs 0, 1, 3, 4, 7 and 11 drop their last day,
    and run 5 drops its last two days, before plotting.
    """
    skipped_runs = {8, 10}
    trailing_days_to_drop = {0: 1, 1: 1, 3: 1, 4: 1, 7: 1, 11: 1, 5: 2}

    for run_index in range(params.NUM_RUNS):
        if run_index in skipped_runs:
            continue

        this_runs_TE = TE[run_index, :, :, :]
        n_drop = trailing_days_to_drop.get(run_index, 0)
        if n_drop:
            this_runs_TE = this_runs_TE[:-n_drop]

        # -1 marks missing values; mask them so the means below ignore them.
        this_runs_TE = np.ma.array(this_runs_TE, mask = this_runs_TE == -1)

        # Transposed to (entries, days) so each column is one day.
        meaned_out = np.transpose(np.mean(this_runs_TE, axis = 2))
        meaned_in = np.transpose(np.mean(this_runs_TE, axis = 1))
        # Flatten each day's matrix into one row; the edge count comes from the
        # array shape instead of the former hard-coded 3600.
        edges_per_day = int(np.prod(this_runs_TE.shape[1:]))
        edges = np.transpose(np.reshape(this_runs_TE, (this_runs_TE.shape[0], edges_per_day)))

        print(run_index)
        days = params.DAYS[run_index]
        # The 0.001 in the denominator guards against division by zero.
        ratio = meaned_out / (meaned_out + meaned_in + 0.001)
        frame_and_plot(ratio, "pairs_by_ratio", sub_base_folder, run_index, days, remove_zeros = remove_zeros, normal_labels = False)
        frame_and_plot(meaned_out, "pairs_by_out_TE", sub_base_folder, run_index, days, remove_zeros = remove_zeros)
        frame_and_plot(meaned_in, "pairs_by_in_TE", sub_base_folder, run_index, days, remove_zeros = remove_zeros)
        frame_and_plot(edges, "pairs_by_edge_TE", sub_base_folder, run_index, days, remove_zeros = remove_zeros)
back to top