Raw File
step2c_data_quality-noiseceiling.py
#!/usr/bin/env python3

"""
@ Lina Teichmann

    INPUTS: 
    Call from the command line with the following input:
        -bids_dir   path to the BIDS root directory

    OUTPUTS:
    Calculates and plots noise ceilings based on the 200 repeat images for all sensors and each sensor group

    NOTES:
    If it doesn't exist, the script makes a figures folder in the BIDS derivatives folder
  
"""

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import mne,os
from scipy.stats import zscore

#*****************************#
### PARAMETERS ###
#*****************************#
# Study design: number of participants, sessions per participant and
# repeated images per session.
n_participants = 4
n_sessions = 12
n_images = 200

# Sensor groups: `names` holds the single-letter code that is compared with the
# third character of each channel name, `labs` the matching plot label.
names = ['O', 'T', 'P', 'F', 'C']
labs = ['Occipital', 'Temporal', 'Parietal', 'Frontal', 'Central']

# Line colours; the plotting functions index this list by participant.
colors = ['mediumseagreen', 'steelblue', 'goldenrod', 'indianred', 'grey']

# Text settings applied to every figure created by this script.
plt.rcParams.update({
    'font.size': '16',
    'font.family': 'Helvetica',
})

#*****************************#
### HELPER FUNCTIONS ###
#*****************************#
def load_epochs(preproc_dir,all_epochs = None):
    """
    Read the preprocessed epochs file of every participant.

    Arguments:
        preproc_dir: str
            Folder containing one 'preprocessed_P<p>-epo.fif' file per participant
            (p = 1 ... n_participants).
        all_epochs: list or None
            Optional existing list to append to. If None, a new list is created
            for this call.
    returns:
        all_epochs: list
            One epochs object per participant, in participant order. The files are
            opened with preload=False, so the data are not read into memory here.
    """
    # A `[]` default would be created once and shared between calls, so a second
    # call would also return the epochs gathered by the first one.
    if all_epochs is None:
        all_epochs = []
    for p in range(1,n_participants+1):
        epochs = mne.read_epochs(f'{preproc_dir}/preprocessed_P{p}-epo.fif', preload=False)
        all_epochs.append(epochs)
    return all_epochs

def kknc(data: np.ndarray, n: "int | None" = None):
    """
    Calculate the noise ceiling reported in the NSD paper (Allen et al., 2021)
    Arguments:
        data: np.ndarray
            Should be shape (ntargets, nrepetitions, nobservations)
        n: int or None
            Number of trials averaged to calculate the noise ceiling. If None (or 0), n will be the number of repetitions.
    returns:
        nc: np.ndarray of shape (ntargets)
            Noise ceiling in percent (0-100), assuming n trials are averaged.
    """
    if not n:
        n = data.shape[-2]
    # z-score across observations so every repetition has unit variance
    normalized = zscore(data, axis=-1)
    # noise variance: variance across repetitions, averaged over observations
    noise_var = np.mean(np.var(normalized, axis=-2, ddof=1), axis=-1)
    # signal variance is the remainder of the unit variance, floored at zero
    signal_var = np.clip(1 - noise_var, 0., None)
    # Same quantity as 100 * ncsnr^2 / (ncsnr^2 + 1/n) with ncsnr = signal_sd / noise_sd,
    # rearranged so that a noise variance of exactly zero yields 100 rather than
    # inf / inf = nan. The denominator cannot be zero for n > 0: signal_var is
    # only zero when noise_var >= 1.
    nc = 100 * signal_var / (signal_var + noise_var / n)
    return nc

def calculate_noise_ceiling(all_epochs,all_nc = None):
    """
    Calculate the noise ceiling of every sensor at every time point, per participant.

    Arguments:
        all_epochs: list
            One epochs object per participant. The metadata columns 'trial_type',
            'session_nr' and 'things_category_nr' are read from each of them.
        all_nc: list or None
            Optional existing list to append to. If None, a new list is created
            for this call.
    returns:
        all_nc: list
            One array of shape (n_channels, n_time) per participant, holding the
            value returned by kknc with n = n_sessions.
    """
    # A `[]` default would be shared between calls and keep accumulating results.
    if all_nc is None:
        all_nc = []
    # the number of time points is taken from the first participant and reused for all
    n_time = len(all_epochs[0].times)
    for p in range(n_participants):
        epochs = all_epochs[p]
        n_channels = len(epochs.ch_names)

        # select repetition trials only and load epoched data into memory
        epochs_rep = epochs[(epochs.metadata['trial_type']=='test')]
        epochs_rep.load_data()

        # select session data and sort based on category number, so that the
        # image axis is in the same order in every session
        res_mat = np.empty([n_channels,n_sessions,n_images,n_time])
        for sess in range(n_sessions):
            epochs_curr = epochs_rep[epochs_rep.metadata['session_nr']==sess+1]
            sort_order = np.argsort(epochs_curr.metadata['things_category_nr'])
            epochs_curr = epochs_curr[sort_order]
            # (trials, channels, time) -> (channels, trials, time)
            # NOTE(review): `_data` is a private attribute; `get_data()` is presumably
            # the public equivalent - confirm before switching.
            res_mat[:,sess,:,:] = np.transpose(epochs_curr._data, (1,0,2))

        # run noise ceiling: sessions act as repetitions, images as observations
        nc = np.empty([n_channels,n_time])
        for t in range(n_time):
            nc[:,t] = kknc(data=res_mat[:,:,:,t],n=n_sessions)
        all_nc.append(nc)
    return all_nc

def make_supplementary_plot(all_epochs,fig_dir,all_nc=None):
    """
    Plot the noise ceiling of every sensor, one panel per sensor group and participant.

    Each row of the figure is one sensor group (sensors whose name has the group
    letter as its third character). Each column is one participant, followed by a
    final column showing the sensor layout with the group's sensors highlighted.
    The figure is saved as 'data_quality-noiseceiling_all.pdf' in fig_dir.

    Arguments:
        all_epochs: list
            One epochs object per participant; `times`, `ch_names` and `info` are read.
        fig_dir: str
            Folder the figure is written to.
        all_nc: list or None
            One (n_channels, n_time) noise-ceiling array per participant. If None,
            the module-level variable `all_nc` is used, which is what callers
            passing only two arguments have always relied on.
    """
    if all_nc is None:
        # Backward compatibility: the function used to read `all_nc` straight from
        # the module namespace, where the command-line section creates it.
        try:
            all_nc = globals()['all_nc']
        except KeyError:
            raise NameError("name 'all_nc' is not defined; pass all_nc explicitly") from None

    plt.close('all')
    fig = plt.figure(num=2,figsize = (12,8))
    # One row per sensor group, one column per participant plus one for the sensor
    # map. (The grid used to be declared with rows and columns swapped, which only
    # worked because both counts happen to be 5.)
    gs1 = gridspec.GridSpec(len(names), n_participants+1)
    gs1.update(wspace=0.2, hspace=0.2)
    ctf_layout = mne.find_layout(all_epochs[1].info)
    for i,n in enumerate(names):
        # layout entries belonging to this sensor group
        picks = np.where([ch[2]==n for ch in ctf_layout.names])[0]
        for p in range(n_participants):
            ax = fig.add_subplot(gs1[i,p])
            times_ms = all_epochs[p].times*1000
            picks_epochs = np.where([s[2]==n for s in all_epochs[p].ch_names])[0]
            group_nc = all_nc[p][picks_epochs,:]
            # thin line per sensor, thick line for the group average
            for sensor_nc in group_nc:
                ax.plot(times_ms,sensor_nc,color=colors[p],label=labs[i],lw=0.1,alpha=0.2)
            ax.plot(times_ms,np.mean(group_nc,axis=0),color=colors[p],label=labs[i],lw=1.5)
            ax.set_ylim([0,100])

            # participant titles on the top row only
            if i ==0:
                ax.set_title('M' + str(p+1))

            # x tick labels and axis label on the bottom row only
            if i < len(names)-1:
                plt.setp(ax.get_xticklabels(), visible=False)
            else:
                ax.set_xlabel('time (ms)')

            # y tick labels and group name on the first column only
            if p == 0:
                plt.setp(ax.get_yticklabels(), visible=True)
                ax.set_ylabel(labs[i])
            else:
                plt.setp(ax.get_yticklabels(), visible=False)
            ax.spines['right'].set_visible(False)
            ax.spines['top'].set_visible(False)

            # row index 2 (the middle of the five groups) also carries the overall y-axis title
            if i == 2 and p == 0:
                ax.set_ylabel('Explained Variance (%)\n' + labs[i])

        #  plot sensor locations
        ax2 = fig.add_subplot(gs1[i,n_participants])
        ax2.plot(ctf_layout.pos[:,0],ctf_layout.pos[:,1],color='gainsboro',marker='.',linestyle='',markersize=3)
        ax2.plot(ctf_layout.pos[picks,0],ctf_layout.pos[picks,1],color='grey',marker='.',linestyle='',markersize=3)
        ax2.axis('equal')
        ax2.axis('off')
    fig.savefig(f'{fig_dir}/data_quality-noiseceiling_all.pdf')

def make_main_plot(all_epochs,all_nc,fig_dir=None):
    """
    Plot the average noise ceiling of every sensor group, one panel per group.

    Each panel shows one line per participant: the mean noise ceiling over the
    sensors whose name has the group letter as its third character. An inset
    shows the sensor layout with the group's sensors highlighted. The figure is
    saved as 'data_quality-noiseceiling_avgd.pdf' in fig_dir.

    Arguments:
        all_epochs: list
            One epochs object per participant; `times`, `ch_names` and `info` are read.
        all_nc: list
            One (n_channels, n_time) noise-ceiling array per participant.
        fig_dir: str or None
            Folder the figure is written to. If None, the module-level variable
            `fig_dir` is used, which is what callers passing only two arguments
            have always relied on.
    """
    if fig_dir is None:
        # Backward compatibility: the function used to read `fig_dir` straight from
        # the module namespace, where the command-line section creates it.
        try:
            fig_dir = globals()['fig_dir']
        except KeyError:
            raise NameError("name 'fig_dir' is not defined; pass fig_dir explicitly") from None

    plt.close('all')
    fig = plt.figure(num=1,figsize = (12,3))
    gs1 = gridspec.GridSpec(1,len(names),wspace=0.1,)
    ctf_layout = mne.find_layout(all_epochs[1].info)

    for i,n in enumerate(names):
        ax = fig.add_subplot(gs1[i])

        # plot noise ceilings: one group-average line per participant
        for p in range(n_participants):
            picks_epochs = np.where([s[2]==n for s in all_epochs[p].ch_names])[0]
            ax.plot(all_epochs[p].times*1000,np.mean(all_nc[p][picks_epochs,:],axis=0),color=colors[p],label='M'+str(p+1),lw=2)

        ax.set_ylim([0,90])
        ax.set_xlim([all_epochs[1].times[0]*1000,all_epochs[1].times[-1]*1000])
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        # a single legend, attached explicitly to the last panel
        if i == len(names)-1:
            ax.legend(frameon=False, bbox_to_anchor=(1, 0.5))
        ax.set_xlabel('time (ms)')
        if i == 0:
            ax.set_ylabel('Explained variance (%)')
        else:
            plt.setp(ax.get_yticklabels(), visible=False)

        #  plot sensor locations
        picks = np.where([ch[2]==n for ch in ctf_layout.names])[0]
        ax2 = ax.inset_axes([0.55, 0.55, 0.5, 0.5])
        ax2.plot(ctf_layout.pos[:,0],ctf_layout.pos[:,1],color='darkgrey',marker='.',linestyle='',markersize=2)
        ax2.plot(ctf_layout.pos[picks,0],ctf_layout.pos[picks,1],color='k',marker='.',linestyle='',markersize=2)
        ax2.axis('equal')
        ax2.axis('off')
        ax2.set_title(labs[i],y=0.8,fontsize=14)

    fig.savefig(f'{fig_dir}/data_quality-noiseceiling_avgd.pdf')



#*****************************#
### COMMAND LINE INPUTS ###
#*****************************#
if __name__=='__main__':
    import argparse

    # the BIDS root is the only command-line input
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument("-bids_dir", required=True, help='path to bids root')
    args = cli_parser.parse_args()

    # folders derived from the BIDS root
    bids_dir = args.bids_dir
    derivatives_dir = f'{bids_dir}/derivatives'
    preproc_dir = f'{derivatives_dir}/preprocessed/'
    sourcedata_dir = f'{bids_dir}/sourcedata/'
    fig_dir = f'{derivatives_dir}/figures/'

    # make the figures folder if it is missing
    if not os.path.exists(fig_dir):
        os.makedirs(fig_dir)

    # Run: load the epochs, compute the noise ceilings, then draw both figures.
    # `all_nc` and `fig_dir` stay module-level names because the plotting
    # functions look them up there.
    all_epochs = load_epochs(preproc_dir)
    all_nc = calculate_noise_ceiling(all_epochs)
    make_supplementary_plot(all_epochs,fig_dir)
    make_main_plot(all_epochs,all_nc)

back to top