swh:1:snp:9afa324ac020f1c169cae2c8f220a29e1ff11375
Raw File
Tip revision: 8ee5e519da5cb90590865e9a692b96ad7e68a69e authored by David Shorten on 06 February 2022, 07:07:59 UTC
various
Tip revision: 8ee5e51
dist_plotting.py
import numpy as np
import h5py
import seaborn as sns
import os
import matplotlib.pyplot as plt
import matplotlib.colors as clr
from scipy import stats
from sklearn.neighbors import KernelDensity

import params
import utils

def _plot_histogram_with_kde(ax, values, colors):
    """Draw a density histogram of ``values`` with a linear-kernel KDE overlay on ``ax``.

    Outliers are discarded first: a value is kept only if its absolute
    deviation from the median is less than 10 times the median absolute
    deviation (MAD). When the MAD is zero the ratio is undefined, so no
    outlier removal is done.
    """
    dists = np.abs(values - np.median(values))
    med_dist = np.median(dists)
    if med_dist > 0:
        cleaned = values[dists / med_dist < 10]
    else:
        # Dividing by a zero MAD gives nan/inf ratios, which would discard every
        # value and make np.max below raise on an empty array.
        cleaned = values
    # NOTE(review): bandwidth and bin edges assume max_val > 0 (positive TE values) — confirm.
    max_val = np.max(cleaned)

    kde = KernelDensity(kernel='linear', bandwidth=0.1 * max_val).fit(cleaned.reshape(-1, 1))
    xx = np.linspace(0, max_val, num=1000)
    # Ten equal-width bins covering [0, max_val] inclusive. np.arange with a
    # float step stops before max_val, which excluded the top tenth of the
    # range from the (density-normalised) histogram.
    bins = np.linspace(0, max_val, num=11)
    ax.hist(cleaned, density=True, bins=bins, color=colors[0])
    ax.plot(xx, np.exp(kde.score_samples(xx.reshape(-1, 1))), linewidth=4, color=colors[1])


def _plot_normal_qq(ax, values, colors):
    """Scatter sorted ``values`` against quantiles of a normal fitted to them, plus the y = x line.

    The reference normal uses the sample mean and standard deviation of
    ``values``; plotting positions are i / (n + 1) for i = 1..n.
    """
    sample_quantiles = np.sort(values)
    n = len(sample_quantiles)
    probabilities = np.arange(1, n + 1) / (n + 1)
    # scipy's ``scale`` for the normal distribution is the standard deviation,
    # not the variance.
    theoretical_quantiles = stats.norm.ppf(probabilities, loc=np.mean(values), scale=np.std(values))
    ax.scatter(theoretical_quantiles, sample_quantiles, color=colors[0])
    low = min(np.min(sample_quantiles), np.min(theoretical_quantiles))
    high = max(np.max(sample_quantiles), np.max(theoretical_quantiles))
    line_points = np.linspace(low, high, num=1000)
    ax.plot(line_points, line_points, linewidth=4, color=colors[1])


def plot_distributions(TE, sub_base_folder, run_group_index):
    """Plot per-day TE distributions and normal QQ plots for one run group and save three PDFs.

    For every run in ``params.RUN_GROUPS[run_group_index]`` and every day index
    below ``params.MAX_NUM_DAYS_BY_GROUP[run_group_index]``, the values
    ``TE[run_index, day_index, :, :]`` are flattened, and entries equal to -1
    (missing) or 0 are removed. Three panels are then drawn:

    * a density histogram with a linear-kernel KDE overlay, after MAD-based
      outlier removal;
    * a QQ plot of log(TE) against a normal fitted to log(TE);
    * a QQ plot of TE against a normal fitted to TE.

    A panel is switched off when its day index is beyond
    ``params.DAYS[run_index]`` or fewer than five values remain.

    Args:
        TE: array indexed as ``TE[run_index, day_index, :, :]``; -1 marks a
            missing value.
        sub_base_folder: path fragment concatenated between
            ``params.BASE_FIGURE_FOLDER`` and ``params.DISTRIBUTIONS_FIGURE_FOLDER``.
        run_group_index: index selecting the run group and its per-group
            entries in ``params`` (font sizes, day counts, file prefix).
    """
    plt.rc('axes', labelsize=params.LABEL_SIZES[run_group_index])
    plt.rc('font', size=params.LABEL_SIZES[run_group_index])
    plt.rc('xtick', labelsize=params.TICK_SIZES[run_group_index])
    plt.rc('ytick', labelsize=params.TICK_SIZES[run_group_index])
    plt.rc('axes', titlesize=params.TITLE_SIZES[run_group_index])

    # The [row, day] indexing below relies on make_basic returning a 2-D grid of
    # axes (rows = runs in the group, columns = day indices).
    fig_dist, axs_dist = utils.make_basic(True, run_group_index)
    fig_QQ, axs_QQ = utils.make_basic(True, run_group_index)
    fig_QQ_straight, axs_QQ_straight = utils.make_basic(True, run_group_index)
    all_axes = (axs_dist, axs_QQ, axs_QQ_straight)

    colors = sns.color_palette(palette='colorblind')
    run_group = params.RUN_GROUPS[run_group_index]
    num_days = params.MAX_NUM_DAYS_BY_GROUP[run_group_index]
    last_row = len(run_group) - 1

    for row, run_index in enumerate(run_group):
        days = params.DAYS[run_index]
        for day_index in range(num_days):
            this_days_TE = TE[run_index, day_index, :, :].flatten()
            # Drop missing values (designated by -1) and zeros; the output
            # files are the "without_zeros" variants.
            this_days_TE = this_days_TE[(this_days_TE != -1) & (this_days_TE != 0)]
            print(run_index, day_index, len(this_days_TE), this_days_TE[:20])

            if day_index >= len(days) or this_days_TE.shape[0] < 5:
                # Empty panel.
                for axs in all_axes:
                    axs[row, day_index].axis('off')
            else:
                _plot_histogram_with_kde(axs_dist[row, day_index], this_days_TE, colors)
                _plot_normal_qq(axs_QQ[row, day_index], np.log(this_days_TE), colors)
                _plot_normal_qq(axs_QQ_straight[row, day_index], this_days_TE, colors)

            for axs in all_axes:
                ax = axs[row, day_index]
                ax.set_xlabel("")
                ax.set_ylabel("")
                # The last row's first column gets no "day N" title.
                # NOTE(review): units labels go in column 1 when run_group_index == 0,
                # yet column 0 is still the one excluded here — confirm intended.
                if day_index < len(days) and (row, day_index) != (last_row, 0):
                    ax.set_title("day " + str(days[day_index]), y=-0.3)

    # Units labels go on one bottom-row panel. They are applied once, after all
    # panels are drawn, so they override the per-panel titles/labels set above.
    if run_group and num_days > 0:
        units_col = 1 if run_group_index == 0 else 0

        ax = axs_dist[last_row, units_col]
        ax.set_title("")
        ax.set_xlabel(r'TE (nats.s$^{-1}$)' + "\n day 10")
        ax.set_ylabel("frequency")

        ax = axs_QQ[last_row, units_col]
        ax.set_title("")
        ax.set_xlabel("Inverse of \n normal CDF \n day 10")
        ax.set_ylabel(r'log TE')

        ax = axs_QQ_straight[last_row, units_col]
        ax.set_title("")
        ax.set_xlabel("Inverse of \n normal CDF, day 10")
        ax.set_ylabel(r'TE (nats.s$^{-1}$)')

    figure_folder = params.BASE_FIGURE_FOLDER + sub_base_folder + params.DISTRIBUTIONS_FIGURE_FOLDER
    prefix = params.RUN_GROUP_PREFIXES[run_group_index]
    fig_dist.savefig(figure_folder + prefix + "distributions_without_zeros.pdf")
    fig_QQ.savefig(figure_folder + prefix + "QQ_without_zeros.pdf")
    fig_QQ_straight.savefig(figure_folder + prefix + "QQ_straight_normal_without_zeros.pdf")
    # Release all three figures. Clearing only the current figure would leave
    # them open and let repeated calls accumulate figures.
    for fig in (fig_dist, fig_QQ, fig_QQ_straight):
        plt.close(fig)
back to top