https://github.com/PeyracheLab/pynacollada
Tip revision: 4dd1adddad54627601f7567354fa7f0af020fc7d authored by Guillaume Viejo on 05 July 2023, 20:05:21 UTC
Merge pull request #22 from slcalgin/main
Merge pull request #22 from slcalgin/main
Tip revision: 4dd1add
eeg_processing.py
# -*- coding: utf-8 -*-
# @Author: gviejo
# @Date: 2022-01-17 14:10:40
# @Last Modified by: gviejo
# @Last Modified time: 2022-04-11 17:57:39
import numpy as np
import pynapple as nap
from scipy.signal import butter, lfilter, filtfilt
def _butter_bandpass(lowcut, highcut, fs, order=5):
    """
    Design a Butterworth band-pass filter.

    Parameters
    ----------
    lowcut, highcut : float
        Lower and upper edges of the pass band, in the same units as ``fs``.
    fs : float
        Sampling frequency of the signal to be filtered.
    order : int, optional
        Filter order handed to ``scipy.signal.butter`` (default 5).

    Returns
    -------
    tuple of ndarray
        ``(b, a)``: numerator and denominator coefficients of the filter.
    """
    # scipy's butter expects critical frequencies normalized to Nyquist.
    nyquist = fs * 0.5
    edges = [freq / nyquist for freq in (lowcut, highcut)]
    return butter(order, edges, btype='band')
def _butter_bandpass_filter(data, lowcut, highcut, fs, order=4):
    """
    Band-pass filter an array with a single forward (causal) Butterworth pass.

    Parameters
    ----------
    data : array_like
        Samples to filter; ``lfilter`` works along the last axis by default.
    lowcut, highcut : float
        Pass-band edges, in the same units as ``fs``.
    fs : float
        Sampling frequency of ``data``.
    order : int, optional
        Butterworth filter order (default 4).

    Returns
    -------
    ndarray
        The filtered samples. Because ``lfilter`` is applied once (not
        forward-backward), the output carries the filter's phase delay.
    """
    coeffs_b, coeffs_a = _butter_bandpass(lowcut, highcut, fs, order=order)
    return lfilter(coeffs_b, coeffs_a, data)
def bandpass_filter(data, lowcut, highcut, fs, order=4):
    """
    Bandpass filter an LFP signal with a Butterworth filter.

    The filter is applied with ``scipy.signal.lfilter`` (single causal pass),
    so the output is phase-delayed relative to the input.

    Parameters
    ----------
    data : Tsd/TsdFrame
        Signal to filter. For a TsdFrame, every column is filtered
        independently.
    lowcut : float
        Lower edge of the pass band, in the same units as ``fs``.
    highcut : float
        Upper edge of the pass band, in the same units as ``fs``.
    fs : float
        Sampling frequency of ``data``.
    order : int, optional
        Butterworth filter order (default 4).

    Returns
    -------
    Tsd/TsdFrame
        Filtered signal of the same type as ``data``, with the same timestamps
        (expressed in seconds), time support and, for a TsdFrame, columns.

    Raises
    ------
    RuntimeError
        If ``data`` is neither exactly a Tsd nor exactly a TsdFrame.
    """
    # Check the type before touching any pynapple attribute: otherwise inputs
    # such as plain numpy arrays fail with AttributeError on `.time_support`
    # and never reach the documented RuntimeError. Exact type checks are kept
    # on purpose so subclasses (e.g. Ts) are still rejected as before.
    if type(data) is not nap.TsdFrame and type(data) is not nap.Tsd:
        raise RuntimeError("Unknow format. Should be Tsd/TsdFrame")

    time_support = data.time_support
    time_index = data.as_units('s').index.values

    if type(data) is nap.TsdFrame:
        # Filter each channel on its own through the Tsd branch below.
        filtered = np.zeros(data.shape)
        for i, c in enumerate(data.columns):
            filtered[:, i] = bandpass_filter(data[c], lowcut, highcut, fs, order).values
        return nap.TsdFrame(
            t=time_index,
            d=filtered,
            time_support=time_support,
            time_units='s',
            columns=data.columns)

    flfp = _butter_bandpass_filter(data.values, lowcut, highcut, fs, order)
    return nap.Tsd(
        t=time_index,
        d=flfp,
        time_support=time_support,
        time_units='s')
def detect_oscillatory_events(lfp, epoch, freq_band, thres_band, duration_band, min_inter_duration, wsize=51):
    """
    Simple helper for detecting oscillatory events (e.g. ripples, spindles).

    Pipeline: restrict ``lfp`` to ``epoch``, band-pass it in ``freq_band``,
    square it, smooth the power with a ``wsize``-sample moving average applied
    forward and backward (``filtfilt``), z-score it, keep samples between the
    two thresholds of ``thres_band``, then drop intervals outside
    ``duration_band`` and merge intervals closer than ``min_inter_duration``.

    Parameters
    ----------
    lfp : Tsd
        Should be a single channel raw lfp.
    epoch : IntervalSet
        The epoch for restricting the detection.
    freq_band : tuple
        The (low, high) frequency used to bandpass the signal.
    thres_band : tuple
        The (min, max) values for thresholding the normalized (z-scored),
        smoothed squared signal after filtering.
    duration_band : tuple
        The (min, max) duration of an event, in seconds.
    min_inter_duration : float
        Events separated by less than this (in seconds) are merged.
    wsize : int, optional
        Length, in samples, of the moving-average window used for the
        digital smoothing of the squared signal (default 51).

    Returns
    -------
    IntervalSet
        The detected events, re-indexed from 0.
    Tsd
        One entry per event: timestamp of the maximum of the normalized
        signal within the event, with that maximum as the value.
    """
    restricted = lfp.restrict(epoch)
    fs = restricted.rate
    filtered = bandpass_filter(restricted, freq_band[0], freq_band[1], fs)

    # Boxcar-smoothed instantaneous power; filtfilt runs the window forward
    # and backward so the envelope is not shifted in time.
    kernel = np.ones(wsize) / wsize
    power = filtfilt(kernel, 1, np.square(filtered.values))
    zpower = (power - np.mean(power)) / np.std(power)
    nSS = nap.Tsd(t=filtered.index.values, d=zpower, time_support=epoch)

    # Round 1: keep samples above the low threshold and below the high one.
    # NOTE(review): samples above thres_band[1] are removed rather than the
    # whole event being rejected, so an event crossing the upper threshold can
    # be split into two shorter intervals -- confirm this is intended.
    in_band = nSS.threshold(thres_band[0], method='above').threshold(thres_band[1], method='below')

    # Rounds 2 and 3: duration filtering, then merging of close events.
    # Merging happens after the duration filter, so a merged event may end up
    # longer than duration_band[1].
    osc_ep = (
        in_band.time_support
        .drop_short_intervals(duration_band[0], time_units='s')
        .drop_long_intervals(duration_band[1], time_units='s')
        .merge_close_intervals(min_inter_duration, time_units='s')
        .reset_index(drop=True)
    )

    # Peak of the normalized signal inside each detected event.
    segments = [nSS.loc[start:end] for start, end in osc_ep.values]
    peak_times = np.array([segment.idxmax() for segment in segments])
    peak_values = np.array([segment.max() for segment in segments])
    osc_tsd = nap.Tsd(t=peak_times, d=peak_values, time_support=epoch)
    return osc_ep, osc_tsd
# def downsample(tsd, up, down):
# import scipy.signal
# dtsd = scipy.signal.resample_poly(tsd.values, up, down)
# dt = tsd.as_units('s').index.values[np.arange(0, tsd.shape[0], down)]
# if len(tsd.shape) == 1:
# return nap.Tsd(dt, dtsd, time_units = 's')
# elif len(tsd.shape) == 2:
# return nap.TsdFrame(dt, dtsd, time_units = 's', columns = list(tsd.columns))
# def getPeaksandTroughs(lfp, min_points):
# """
# At 250Hz (1250/5), 2 troughs cannot be closer than 20 (min_points) points (if theta reaches 12Hz);
# """
# import scipy.signal
# if isinstance(lfp, nap.time_series.Tsd):
# troughs = nap.Tsd(lfp.as_series().iloc[scipy.signal.argrelmin(lfp.values, order =min_points)[0]], time_units = 'us')
# peaks = nap.Tsd(lfp.as_series().iloc[scipy.signal.argrelmax(lfp.values, order =min_points)[0]], time_units = 'us')
# tmp = nap.Tsd(troughs.realign(peaks, align = 'next').as_series().drop_duplicates('first')) # eliminate double peaks
# peaks = peaks[tmp.index]
# tmp = nap.Tsd(peaks.realign(troughs, align = 'prev').as_series().drop_duplicates('first')) # eliminate double troughs
# troughs = troughs[tmp.index]
# return (peaks, troughs)
# elif isinstance(lfp, nap.time_series.TsdFrame):
# peaks = nap.TsdFrame(lfp.index.values, np.zeros(lfp.shape))
# troughs = nap.TsdFrame(lfp.index.values, np.zeros(lfp.shape))
# for i in lfp.keys():
# peaks[i], troughs[i] = getPeaksandTroughs(lfp[i], min_points)
# return (peaks, troughs)
# def getPhase(lfp, fmin, fmax, nbins, fsamp, power = False):
# """ Continuous Wavelets Transform
# return phase of lfp in a Tsd array
# """
# from Wavelets import MyMorlet as Morlet
# if isinstance(lfp, nap.time_series.TsdFrame):
# allphase = nap.TsdFrame(lfp.index.values, np.zeros(lfp.shape))
# allpwr = nap.TsdFrame(lfp.index.values, np.zeros(lfp.shape))
# for i in lfp.keys():
# allphase[i], allpwr[i] = getPhase(lfp[i], fmin, fmax, nbins, fsamp, power = True)
# if power:
# return allphase, allpwr
# else:
# return allphase
# elif isinstance(lfp, nap.time_series.Tsd):
# cw = Morlet(lfp.values, fmin, fmax, nbins, fsamp)
# cwt = cw.getdata()
# cwt = np.flip(cwt, axis = 0)
# wave = np.abs(cwt)**2.0
# phases = np.arctan2(np.imag(cwt), np.real(cwt)).transpose()
# cwt = None
# index = np.argmax(wave, 0)
# # memory problem here, need to loop
# phase = np.zeros(len(index))
# for i in range(len(index)) : phase[i] = phases[i,index[i]]
# phases = None
# if power:
# pwrs = cw.getpower()
# pwr = np.zeros(len(index))
# for i in range(len(index)):
# pwr[i] = pwrs[index[i],i]
# return nap.Tsd(lfp.index.values, phase), nap.Tsd(lfp.index.values, pwr)
# else:
# return nap.Tsd(lfp.index.values, phase)