https://github.com/ViCCo-Group/THINGS-data
Tip revision: 2d95c15d3a2cc5984ffd4a9a2c4ad3496847ca9d authored by Oliver Contier on 28 February 2023, 15:15:53 UTC
fixed howtocite
fixed howtocite
Tip revision: 2d95c15
step1_preprocessing.py
#!/usr/bin/env python3
"""
@ Lina Teichmann
INPUTS:
call from the command line with the following inputs:
-participant    participant BIDS ID (e.g., 1)
-bids_dir       path to the BIDS root directory
OUTPUTS:
epoched and cleaned data will be written into the preprocessing directory
NOTES:
This script contains the following preprocessing steps:
- channel exclusion (one malfunctioning channel ('MRO11-1609'), based on experimenter notes)
- filtering (0.1 - 40Hz)
- epoching (-100 - 1300ms) --> based on onsets of the optical sensor
- baseline correction (zscore)
- downsampling (200Hz)
Use preprocessed data ("preprocessed_P{participant}-epo.fif") saved in preprocessing directory for the next steps
"""
import mne, os
import numpy as np
import pandas as pd
from joblib import Parallel, delayed
#*****************************#
### PARAMETERS ###
#*****************************#
# Number of recording sessions; sessions 1..n_sessions are processed.
n_sessions = 12
# Only trigger events with this value on the trigger channel are kept.
trigger_amplitude = 64
# Band-pass filter edges passed to raw.filter (0.1 - 40Hz per the module notes).
l_freq = 0.1
h_freq = 40
# Epoch window relative to stimulus onset, passed as tmin/tmax to mne.Epochs
# (-100 - 1300ms per the module notes).
pre_stim_time = -0.1
post_stim_time = 1.3
# NOTE(review): not referenced anywhere in the visible part of this file -- confirm before removing.
std_deviations_above_below = 4
# Target sampling rate of the saved epochs (200Hz per the module notes).
output_resolution = 200
# Name of the stim channel that mne.find_events reads the triggers from.
trigger_channel = 'UPPT001'
#*****************************#
### HELPER FUNCTIONS ###
#*****************************#
def setup_paths(meg_dir, session):
    """Collect the run folders and event tables of one session.

    Looks inside ``{meg_dir}/ses-XX/meg/`` (``XX`` is the zero-padded session
    number) and keeps only entries whose name starts with ``sub``.

    Parameters
    ----------
    meg_dir : str
        Participant directory that contains the ``ses-XX`` folders.
    session : int
        Session number; zero-padded to two digits in the folder name.

    Returns
    -------
    run_paths : list of str
        Sorted paths of the entries ending in ``.ds``.
    event_paths : list of str
        Sorted paths of the entries ending in ``events.tsv``.
    """
    session_dir = f'{meg_dir}/ses-{str(session).zfill(2)}/meg/'
    subject_entries = [name for name in os.listdir(session_dir) if name.startswith("sub")]

    # Both lists are sorted so that index i of one corresponds to index i of the other.
    run_paths = sorted(
        os.path.join(session_dir, name) for name in subject_entries if name.endswith(".ds")
    )
    event_paths = sorted(
        os.path.join(session_dir, name) for name in subject_entries if name.endswith("events.tsv")
    )
    return run_paths, event_paths
def read_raw(curr_path,session,run,participant):
    """Load one CTF run, patch known signal dropouts and drop the faulty channel.

    Parameters
    ----------
    curr_path : str
        Path of the ``.ds`` recording to load.
    session : int
        Session number of the recording.
    run : int
        Run identifier as passed by the caller.
    participant : str
        Participant ID; compared against the strings ``'1'`` and ``'2'``.

    Returns
    -------
    raw
        The preloaded recording without channel ``'MRO11-1609'``.
    """
    raw = mne.io.read_raw_ctf(curr_path, preload=True)

    # signal dropout in one run -- replacing values with median
    # (participant, session, run) -> position on raw.times where the dropout starts
    dropout_onsets = {
        ('1', 11, 4): 13.4,
        ('2', 10, 2): 59.8,
    }
    onset = dropout_onsets.get((participant, session, run))
    if onset is not None:
        # number of samples covering 0.2 time units at the recording's sampling rate
        n_samples_exclude = int(0.2/(1/raw.info['sfreq']))
        first_sample = np.argmin(np.abs(raw.times - onset))
        # per-channel median over the whole run, repeated across the patched window
        channel_medians = np.median(raw._data, axis=1)
        raw._data[:, first_sample:first_sample + n_samples_exclude] = np.repeat(
            channel_medians[:, np.newaxis], n_samples_exclude, axis=1
        )

    raw.drop_channels('MRO11-1609')
    return raw
def read_events(event_paths,run,raw):
    """Build the events array of one run from its corrected events file.

    Trigger events are detected on the trigger channel and filtered to the
    expected trigger value; their sample positions and event codes are then
    overwritten with the ``sample`` and ``value`` columns of the events file
    (onsets corrected with the optical sensor).

    Parameters
    ----------
    event_paths : list of str
        Paths of the tab-separated events files, indexed by ``run``.
    run : int
        Index into ``event_paths``.
    raw
        Loaded recording that contains the trigger channel.

    Returns
    -------
    events : ndarray
        MNE-style events array with corrected samples and event codes.

    Raises
    ------
    ValueError
        If the number of detected triggers differs from the number of rows
        in the events file.
    """
    # load event file that has the corrected onset times (based on optical sensor and replace in the events variable)
    event_file = pd.read_csv(event_paths[run], sep='\t')
    # Assign the filled column back instead of calling fillna(inplace=True) on
    # the column accessor: the chained in-place form does not reach the frame
    # under pandas Copy-on-Write and would leave NaNs in the event codes.
    event_file['value'] = event_file['value'].fillna(999999)

    events = mne.find_events(raw, stim_channel=trigger_channel, initial_event=True)
    events = events[events[:, 2] == trigger_amplitude]

    # A one-row events file would otherwise broadcast silently over all triggers.
    if len(events) != len(event_file):
        raise ValueError(
            f'found {len(events)} trigger events but {len(event_file)} rows '
            f'in {event_paths[run]}'
        )

    events[:, 0] = event_file['sample'].to_numpy()
    events[:, 2] = event_file['value'].to_numpy()
    return events
def concat_epochs(raw, events, epochs):
    """Epoch one run and append it to the epochs collected so far.

    Parameters
    ----------
    raw
        Filtered recording of the current run.
    events : ndarray
        Events array of the current run.
    epochs
        Epochs accumulated from earlier runs, or an empty list / ``None``
        when this is the first run.

    Returns
    -------
    epochs
        The new run's epochs alone if nothing was accumulated yet, otherwise
        the concatenation of the accumulated and the new epochs.
    """
    new_epochs = mne.Epochs(
        raw, events, tmin=pre_stim_time, tmax=post_stim_time, picks='mag', baseline=None
    )

    # Do not rely on the truthiness of an Epochs object: bool() falls back to
    # len(), which raises until bad epochs have been dropped. Counting the
    # events gives the same answer whenever len() would have worked.
    previous_events = getattr(epochs, 'events', None)
    if previous_events is None:
        has_previous = bool(epochs)
    else:
        has_previous = len(previous_events) > 0

    if not has_previous:
        return new_epochs

    # concatenation needs matching device-to-head transforms, so reuse the
    # transform of the epochs accumulated so far
    new_epochs.info['dev_head_t'] = epochs.info['dev_head_t']
    return mne.concatenate_epochs([epochs, new_epochs])
def baseline_correction(epochs):
    """Z-score every epoch against its baseline window.

    The baseline is given as ``(None, 0)`` and the rescaling uses
    ``mode='zscore'``. The result is wrapped in a new ``EpochsArray`` that
    reuses the info, events, ``tmin`` and ``event_id`` of the input.

    Parameters
    ----------
    epochs
        Epochs to baseline-correct.

    Returns
    -------
    epochs
        ``mne.EpochsArray`` holding the rescaled data.
    """
    epoch_data = epochs.get_data()
    # copy=False: rescale operates on the array returned by get_data() directly
    zscored = mne.baseline.rescale(
        data=epoch_data,
        times=epochs.times,
        baseline=(None, 0),
        mode='zscore',
        copy=False,
    )
    return mne.EpochsArray(
        zscored,
        epochs.info,
        events=epochs.events,
        tmin=epochs.tmin,
        event_id=epochs.event_id,
    )
def stack_sessions(sourcedata_dir,preproc_dir,participant,session_epochs,output_resolution):
    """Concatenate all sessions, attach metadata, downsample and save.

    Parameters
    ----------
    sourcedata_dir : str
        Directory containing ``sample_attributes_P{participant}.csv``.
    preproc_dir : str
        Directory the ``preprocessed_P{participant}-epo.fif`` file is written to.
    participant : str
        Participant ID used in both file names.
    session_epochs : list
        One epochs object per session, in session order.
    output_resolution : int or float
        Target sampling rate in Hz; must divide the recording's sampling rate.

    Raises
    ------
    ValueError
        If the sampling rate is not an integer multiple of ``output_resolution``.
    """
    # concatenation needs matching device-to-head transforms; use the first session's
    for epochs in session_epochs:
        epochs.info['dev_head_t'] = session_epochs[0].info['dev_head_t']
    all_epochs = mne.concatenate_epochs(epochs_list=session_epochs, add_offset=True)
    all_epochs.metadata = pd.read_csv(f'{sourcedata_dir}/sample_attributes_P{participant}.csv')

    # Derive the decimation factor from the data's own sampling rate instead of
    # assuming 1200 Hz, and pass it on as a true integer.
    sfreq = all_epochs.info['sfreq']
    decim = sfreq / output_resolution
    if decim < 1 or decim != int(decim):
        raise ValueError(
            f'sampling rate {sfreq} Hz is not an integer multiple of '
            f'output_resolution {output_resolution} Hz'
        )
    all_epochs.decimate(decim=int(decim))

    all_epochs.save(f'{preproc_dir}/preprocessed_P{participant}-epo.fif', overwrite=True)
    print(all_epochs.info)
#*****************************#
### FUNCTION TO RUN PREPROCESSING ###
#*****************************#
def run_preprocessing(meg_dir,session,participant):
    """Preprocess every run of one session and return its baselined epochs.

    For each run: load the recording, build the corrected events, band-pass
    filter, epoch and append to the session's epochs, then drop bad epochs.
    The accumulated epochs are baseline-corrected once all runs are done.

    Parameters
    ----------
    meg_dir : str
        Participant directory that contains the session folders.
    session : int
        Session number to process.
    participant : str
        Participant ID, forwarded to ``read_raw``.

    Returns
    -------
    epochs
        Baseline-corrected epochs of the whole session.
    """
    run_paths, event_paths = setup_paths(meg_dir, session)

    session_epochs = []
    # run_idx is the 0-based position of the run in the sorted path list
    for run_idx, ds_path in enumerate(run_paths):
        raw = read_raw(ds_path, session, run_idx, participant)
        run_events = read_events(event_paths, run_idx, raw)
        raw.filter(l_freq=l_freq, h_freq=h_freq)
        session_epochs = concat_epochs(raw, run_events, session_epochs)
        session_epochs.drop_bad()

    print(session_epochs.info)
    return baseline_correction(session_epochs)
#*****************************#
### COMMAND LINE INPUTS ###
#*****************************#
if __name__=='__main__':
    import argparse

    # Both flags are mandatory; the participant ID stays a string because
    # read_raw compares it against '1' and '2'.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-participant",
        required=True,
        help='participant bids ID (e.g., 1)',
    )
    parser.add_argument(
        "-bids_dir",
        required=True,
        help='path to bids root',
    )
    args = parser.parse_args()

    bids_dir = args.bids_dir
    participant = args.participant
    meg_dir = f'{bids_dir}/sub-BIGMEG{participant}/'
    sourcedata_dir = f'{bids_dir}/sourcedata/'
    preproc_dir = f'{bids_dir}/derivatives/preprocessed/'
    # exist_ok avoids the check-then-create race of a separate os.path.exists test
    os.makedirs(preproc_dir, exist_ok=True)

    ####### Run preprocessing ########
    # one job per session; results come back in session order
    session_epochs = Parallel(n_jobs=12, backend="multiprocessing")(
        delayed(run_preprocessing)(meg_dir, session, participant)
        for session in range(1, n_sessions + 1)
    )
    stack_sessions(sourcedata_dir, preproc_dir, participant, session_epochs, output_resolution)