# step2c_data_quality-noiseceiling.py
#!/usr/bin/env python3
"""
@ Lina Teichmann
INPUTS:
call from command line with following inputs:
-bids_dir
OUTPUTS:
Calculates and plots noise ceilings based on the 200 repeat images for all sensors and each sensor group
NOTES:
If it doesn't exist, the script makes a figures folder in the BIDS derivatives folder
"""
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import mne,os
from scipy.stats import zscore
#*****************************#
### PARAMETERS ###
#*****************************#
n_participants = 4      # participants P1..P4 (reads preprocessed_P{p}-epo.fif per participant)
n_sessions = 12         # MEG sessions per participant (repetitions for the noise ceiling)
n_images = 200          # repeat ("test") images shown in every session
# sensor-group letters, matched against the 3rd character of each channel name
names = ['O','T','P','F','C']
# human-readable labels for the groups above, same order as `names`
labs = ['Occipital','Temporal','Parietal','Frontal','Central']
# one line color per participant in the plots (indexed by participant number)
colors = ['mediumseagreen','steelblue','goldenrod','indianred','grey']
plt.rcParams['font.size'] = '16'
plt.rcParams['font.family'] = 'Helvetica'
#*****************************#
### HELPER FUNCTIONS ###
#*****************************#
def load_epochs(preproc_dir, all_epochs=None):
    """Read the preprocessed epoch file for every participant.

    Parameters
    ----------
    preproc_dir : str
        Directory containing the ``preprocessed_P{p}-epo.fif`` files.
    all_epochs : list | None
        Optional list to append to. Defaults to a fresh list each call;
        the original mutable default argument (``all_epochs=[]``) would
        have accumulated epochs across repeated calls.

    Returns
    -------
    list of mne.Epochs
        One Epochs object per participant (lazily loaded, preload=False).
    """
    if all_epochs is None:
        all_epochs = []
    for p in range(1, n_participants + 1):
        epochs = mne.read_epochs(f'{preproc_dir}/preprocessed_P{p}-epo.fif', preload=False)
        all_epochs.append(epochs)
    return all_epochs
def kknc(data: np.ndarray, n: "int | None" = None):
    """
    Calculate the noise ceiling reported in the NSD paper (Allen et al., 2021).

    Arguments:
        data: np.ndarray
            Should be shape (ntargets, nrepetitions, nobservations)
        n: int or None
            Number of trials averaged to calculate the noise ceiling. If None,
            n defaults to the number of repetitions (data.shape[-2]).
            (The original annotation ``int or None`` evaluated to plain ``int``,
            and the docstring advertised a second return value ``ncav`` that
            was never returned.)
    returns:
        nc: np.ndarray of shape (ntargets,)
            Noise ceiling in percent explained variance, assuming n trials
            are averaged.
    """
    if n is None:
        # ``if not n`` would also silently replace an (invalid) n=0
        n = data.shape[-2]
    # z-score each (target, repetition) row across observations so every row
    # has unit total variance; the noise variance is then the variance of the
    # normalized responses across repetitions
    normalized = zscore(data, axis=-1)
    noisesd = np.sqrt(np.mean(np.var(normalized, axis=-2, ddof=1), axis=-1))
    # signal sd is whatever remains of the unit total variance (clipped at 0)
    sigsd = np.sqrt(np.clip(1 - noisesd ** 2, 0., None))
    ncsnr = sigsd / noisesd
    # convert the signal-to-noise ratio into % explained variance for an
    # average over n trials
    nc = 100 * ((ncsnr ** 2) / ((ncsnr ** 2) + (1 / n)))
    return nc
def calculate_noise_ceiling(all_epochs, all_nc=None):
    """Compute per-sensor noise ceilings over time for every participant.

    For each participant, the repeat ("test") trials of every session are
    sorted by THINGS category number so trials line up across sessions,
    stacked into a (channels, sessions, images, time) array, and the kknc
    noise ceiling is evaluated per channel at every time point
    (targets=channels, repetitions=sessions, observations=images).

    Parameters
    ----------
    all_epochs : list of mne.Epochs
        Output of load_epochs, one Epochs object per participant.
    all_nc : list | None
        Optional list to append to. Defaults to a fresh list each call;
        the original mutable default (``all_nc=[]``) leaked results
        across repeated calls.

    Returns
    -------
    list of np.ndarray
        One (n_channels, n_time) noise-ceiling array per participant.
    """
    if all_nc is None:
        all_nc = []
    n_time = len(all_epochs[0].times)
    for p in range(n_participants):
        n_channels = len(all_epochs[p].ch_names)
        epochs = all_epochs[p]
        # select repetition trials only and load epoched data into memory
        epochs_rep = epochs[(epochs.metadata['trial_type'] == 'test')]
        epochs_rep.load_data()
        # select session data and sort based on category number
        res_mat = np.empty([n_channels, n_sessions, n_images, n_time])
        for sess in range(n_sessions):
            epochs_curr = epochs_rep[epochs_rep.metadata['session_nr'] == sess + 1]
            sort_order = np.argsort(epochs_curr.metadata['things_category_nr'])
            epochs_curr = epochs_curr[sort_order]
            # reorder (trials, channels, time) -> (channels, trials, time)
            res_mat[:, sess, :, :] = np.transpose(epochs_curr._data, (1, 0, 2))
        # run noise ceiling independently at each time point
        nc = np.empty([n_channels, n_time])
        for t in range(n_time):
            nc[:, t] = kknc(data=res_mat[:, :, :, t], n=n_sessions)
        all_nc.append(nc)
    return all_nc
def make_supplementary_plot(all_epochs,fig_dir):
    """Plot per-sensor noise ceilings: one row per sensor group, one column
    per participant, plus a final sensor-location panel in each row.

    Thin translucent lines are individual sensors; the thick line is the
    mean over the group's sensors. Saves
    ``{fig_dir}/data_quality-noiseceiling_all.pdf``.

    NOTE(review): reads the module-level ``all_nc`` computed in __main__
    instead of taking it as a parameter (make_main_plot does take it) --
    it must exist before this function is called.
    """
    plt.close('all')
    fig = plt.figure(num=2,figsize = (12,8))
    # NOTE(review): nrows is n_participants+1 (=5); the sequential row-major
    # fill below only puts one sensor group per row because len(names)
    # happens to equal 5 as well -- confirm if either constant changes.
    gs1 = gridspec.GridSpec(n_participants+1, len(names))
    gs1.update(wspace=0.2, hspace=0.2)
    # sensor layout taken from participant index 1; presumably the CTF layout
    # is identical for all participants -- TODO confirm
    ctf_layout = mne.find_layout(all_epochs[1].info)
    counter = 0
    for i,n in enumerate(names):
        for p in range(n_participants):
            ax = fig.add_subplot(gs1[counter])
            counter+=1
            ax.clear()
            # sensors of this group: 3rd character of the channel name
            # matches the group letter in `names`
            picks_epochs = np.where([s[2]==n for s in all_epochs[p].ch_names])[0]
            # layout indices for the same group; value is independent of p and
            # the last one is reused for the location panel after this loop.
            # (The comprehension variable `i` shadows the loop index only
            # inside the comprehension scope.)
            picks = np.where([i[2]==n for i in ctf_layout.names])[0]
            # individual-sensor traces (thin, translucent), then the group mean
            [ax.plot(all_epochs[p].times*1000,ii,color=colors[p],label=labs[i],lw=0.1,alpha=0.2) for ii in all_nc[p][picks_epochs,:]]
            ax.plot(all_epochs[p].times*1000,np.mean(all_nc[p][picks_epochs,:],axis=0),color=colors[p],label=labs[i],lw=1.5)
            ax.set_ylim([0,100])
            # participant labels above the first row only
            if i ==0:
                ax.set_title('M' + str(p+1))
            # x tick labels only on the bottom row
            if i < len(names)-1:
                plt.setp(ax.get_xticklabels(), visible=False)
            else:
                ax.set_xlabel('time (ms)')
            # y tick labels / group label only in the first column
            if p == 0:
                plt.setp(ax.get_yticklabels(), visible=True)
                ax.set_ylabel(labs[i])
            else:
                plt.setp(ax.get_yticklabels(), visible=False)
            ax.spines['right'].set_visible(False)
            ax.spines['top'].set_visible(False)
            # middle row of the first column also carries the shared y axis label
            if i == 2 and p == 0:
                ax.set_ylabel('Explained Variance (%)\n' + labs[i])
        # plot sensor locations
        ax2 = fig.add_subplot(gs1[counter])
        counter+=1
        ax2.plot(ctf_layout.pos[:,0],ctf_layout.pos[:,1],color='gainsboro',marker='.',linestyle='',markersize=3)
        ax2.plot(ctf_layout.pos[picks,0],ctf_layout.pos[picks,1],color='grey',marker='.',linestyle='',markersize=3)
        ax2.axis('equal')
        ax2.axis('off')
    fig.savefig(f'{fig_dir}/data_quality-noiseceiling_all.pdf')
def make_main_plot(all_epochs, all_nc, fig_dir=None):
    """Plot the group-averaged noise-ceiling time courses (main figure).

    One panel per sensor group in `names`; each panel shows, per participant,
    the noise ceiling averaged over that group's sensors, with an inset
    highlighting the group's sensor locations on the CTF layout. Saves
    ``{fig_dir}/data_quality-noiseceiling_avgd.pdf``.

    Parameters
    ----------
    all_epochs : list of mne.Epochs
        One Epochs object per participant (provides times / channel names).
    all_nc : list of np.ndarray
        Per-participant (n_channels, n_time) noise ceilings.
    fig_dir : str | None
        Output directory. Defaults to the module-level ``fig_dir`` set in
        __main__, which the original implicitly relied on; the explicit
        parameter matches make_supplementary_plot's signature.
    """
    if fig_dir is None:
        fig_dir = globals()['fig_dir']
    plt.close('all')
    fig = plt.figure(num=1, figsize=(12, 3))
    gs1 = gridspec.GridSpec(1, len(names), wspace=0.1)
    # sensor layout from participant index 1; presumably identical across
    # participants -- TODO confirm
    ctf_layout = mne.find_layout(all_epochs[1].info)
    for i, n in enumerate(names):
        ax = fig.add_subplot(gs1[i])
        # plot noise ceilings averaged within the sensor group
        # (3rd character of the channel name encodes the group)
        picks_epochs = [np.where([s[2] == n for s in all_epochs[p].ch_names])[0] for p in range(n_participants)]
        picks = np.where([ch[2] == n for ch in ctf_layout.names])[0]
        for p in range(n_participants):
            ax.plot(all_epochs[p].times * 1000,
                    np.mean(all_nc[p][picks_epochs[p], :], axis=0),
                    color=colors[p], label='M' + str(p + 1), lw=2)
        ax.set_ylim([0, 90])
        ax.set_xlim([all_epochs[1].times[0] * 1000, all_epochs[1].times[-1] * 1000])
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        # legend and x label only on the last panel
        if i == len(names) - 1:
            plt.legend(frameon=False, bbox_to_anchor=(1, 0.5))
            ax.set_xlabel('time (ms)')
        if i == 0:
            ax.set_ylabel('Explained variance (%)')
        else:
            plt.setp(ax.get_yticklabels(), visible=False)
        # inset with this group's sensor locations highlighted
        ax2 = ax.inset_axes([0.55, 0.55, 0.5, 0.5])
        ax2.plot(ctf_layout.pos[:, 0], ctf_layout.pos[:, 1], color='darkgrey', marker='.', linestyle='', markersize=2)
        ax2.plot(ctf_layout.pos[picks, 0], ctf_layout.pos[picks, 1], color='k', marker='.', linestyle='', markersize=2)
        ax2.axis('equal')
        ax2.axis('off')
        ax2.set_title(labs[i], y=0.8, fontsize=14)
    fig.savefig(f'{fig_dir}/data_quality-noiseceiling_avgd.pdf')
#*****************************#
### COMMAND LINE INPUTS ###
#*****************************#
if __name__=='__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-bids_dir",
        required=True,
        help='path to bids root',
    )
    args = parser.parse_args()
    bids_dir = args.bids_dir
    # derived locations inside the BIDS tree
    preproc_dir = f'{bids_dir}/derivatives/preprocessed/'
    sourcedata_dir = f'{bids_dir}/sourcedata/'  # not used in this script
    fig_dir = f'{bids_dir}/derivatives/figures/'
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs
    os.makedirs(fig_dir, exist_ok=True)
    ####### Run ########
    all_epochs = load_epochs(preproc_dir)
    all_nc = calculate_noise_ceiling(all_epochs)
    # make_supplementary_plot reads the module-level all_nc computed above
    make_supplementary_plot(all_epochs, fig_dir)
    make_main_plot(all_epochs, all_nc)