1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
import numpy as np
from scipy import signal as sig

# Filtering functions
def butter_lowpass(lowcut, fs, order=8, sos=False):
    '''
    Design a digital Butterworth lowpass filter.

    ``lowcut`` (same units as ``fs``) is normalised by the Nyquist frequency
    ``fs / 2`` before being handed to ``scipy.signal.butter``.
    Returns the second-order-sections array when ``sos`` is True, otherwise
    the ``(b, a)`` transfer-function coefficients.
    '''
    normalized_cutoff = lowcut / (0.5 * fs)

    if not sos:
        b, a = sig.butter(order, normalized_cutoff, analog=False, btype='low', output='ba')
        return b, a

    return sig.butter(order, normalized_cutoff, analog=False, btype='low', output='sos')

def butter_highpass(highcut, fs, order=8, sos=False):
    '''
    Design a digital Butterworth highpass filter.

    ``highcut`` is the cutoff frequency (same units as ``fs``); it is divided
    by the Nyquist frequency ``fs / 2`` before calling ``scipy.signal.butter``.
    Returns the second-order-sections array when ``sos`` is True, otherwise
    the ``(b, a)`` transfer-function coefficients.
    '''
    wn = highcut / (0.5 * fs)

    if sos:
        return sig.butter(order, wn, analog=False, btype='high', output='sos')

    numerator, denominator = sig.butter(order, wn, analog=False, btype='high', output='ba')
    return numerator, denominator

def butter_bandpass(lowcut, highcut, fs, order=8, sos=False):
    '''
    Design a digital Butterworth bandpass filter.

    Both band edges (same units as ``fs``) are normalised by the Nyquist
    frequency ``fs / 2`` before calling ``scipy.signal.butter``.
    Returns the second-order-sections array when ``sos`` is True, otherwise
    the ``(b, a)`` transfer-function coefficients.
    '''
    nyquist = 0.5 * fs
    edges = [lowcut / nyquist, highcut / nyquist]

    if sos:
        return sig.butter(order, edges, analog=False, btype='band', output='sos')

    b, a = sig.butter(order, edges, analog=False, btype='band', output='ba')
    return b, a


def butter_lowpass_filter(data, lowcut, fs, order=5, sos=False):
    '''
    Zero-phase lowpass filter ``data`` with a Butterworth design.

    The filter from ``butter_lowpass`` is applied forwards and backwards
    (``sosfiltfilt`` when ``sos`` is True, ``filtfilt`` otherwise). Note that the
    default ``order`` here (5) differs from the design helper's default (8).
    '''
    design = butter_lowpass(lowcut, fs, order=order, sos=sos)

    if not sos:
        return sig.filtfilt(*design, data)

    return sig.sosfiltfilt(design, data)

def butter_highpass_filter(data, highcut, fs, order=5, sos=False):
    '''
    Zero-phase highpass filter ``data`` with a Butterworth design.

    The filter from ``butter_highpass`` is applied forwards and backwards
    (``sosfiltfilt`` when ``sos`` is True, ``filtfilt`` otherwise). Note that the
    default ``order`` here (5) differs from the design helper's default (8).
    '''
    design = butter_highpass(highcut, fs, order=order, sos=sos)

    if not sos:
        return sig.filtfilt(*design, data)

    return sig.sosfiltfilt(design, data)

def butter_bandpass_filter(data, lowcut, highcut, fs, order=5, sos=False):
    '''
    Zero-phase bandpass filter ``data`` with a Butterworth design.

    The filter from ``butter_bandpass`` is applied forwards and backwards
    (``sosfiltfilt`` when ``sos`` is True, ``filtfilt`` otherwise). Note that the
    default ``order`` here (5) differs from the design helper's default (8).
    '''
    design = butter_bandpass(lowcut, highcut, fs, order=order, sos=sos)

    if not sos:
        return sig.filtfilt(*design, data)

    return sig.sosfiltfilt(design, data)


# Function examples adapted from https://www.programcreek.com/python/example/100546/scipy.signal.spectrogram
def power_spectrum(signal: np.ndarray,
                   fs: int,
                   window_width: int,
                   window_overlap: int) -> (np.ndarray, np.ndarray, np.ndarray):
    """
    Computes the power spectrum of the specified signal.

    A periodic Hann window with the specified width and overlap is used. The
    magnitude spectrogram returned by ``scipy.signal.spectrogram`` is squared
    and divided by ``window_width`` to obtain power.

    Parameters
    ----------
    signal: numpy.ndarray
        The input signal
    fs: int
        Sampling frequency of the input signal
    window_width: int
        Width of the Hann windows in samples
    window_overlap: int
        Overlap between Hann windows in samples

    Returns
    -------
    f: numpy.ndarray
        Array of frequency values for the first axis of the returned spectrogram
    t: numpy.ndarray
        Array of time values for the second axis of the returned spectrogram
    sxx: numpy.ndarray
        Power spectrogram of the input signal with axes [frequency, time]
    """
    # Go through the module-level ``sig`` alias: bare ``spectrogram``/``hann`` are
    # not imported in this module, and ``scipy.signal.hann`` no longer exists in
    # recent SciPy (the window lives in ``scipy.signal.windows``).
    f, t, sxx = sig.spectrogram(x=signal,
                                fs=fs,
                                window=sig.windows.hann(window_width, sym=False),
                                noverlap=window_overlap,
                                mode="magnitude")

    # Magnitude -> power, normalised by the window length
    return f, t, (1.0 / window_width) * (sxx ** 2)


def power_to_db(spectrum: np.ndarray,
                clip_below: float = None,
                clip_above: float = None) -> np.ndarray:
    """
    Convert a power spectrogram to the Decibel scale, relative to its peak.

    The output is ``10 * log10(P / max(P))``, so the strongest value maps to
    0 dB and every other value is negative. Zero (and any non-positive or NaN)
    powers are replaced by the smallest strictly positive power before the
    logarithm is taken, which avoids ``-inf``.

    Optionally, the resulting dB values can be clipped from below and/or above.

    Parameters
    ----------
    spectrum: numpy.ndarray
        The (power) spectrogram to convert; any array-like is accepted.
    clip_below: float, optional
        Lower bound applied to the output, in dB relative to the peak
    clip_above: float, optional
        Upper bound applied to the output, in dB relative to the peak

    Returns
    -------
    numpy.ndarray
        The spectrogram on the Decibel scale (0 dB = maximum power)

    Raises
    ------
    ValueError
        If ``spectrum`` contains no strictly positive value, in which case the
        peak-relative dB scale is undefined.
    """
    spectrum = np.asarray(spectrum)

    # there might be zeros, fix them to the lowest non-zero power in the spectrogram
    positive = spectrum[spectrum > 0]
    if positive.size == 0:
        raise ValueError("power_to_db: spectrum has no positive values; "
                         "the dB scale is undefined")
    epsilon = np.min(positive)

    # np.where (rather than np.maximum) also maps NaN entries to epsilon
    sxx = np.where(spectrum > epsilon, spectrum, epsilon)
    sxx = 10 * np.log10(sxx / np.max(sxx))

    if clip_below is not None:
        sxx = np.maximum(sxx, clip_below)

    if clip_above is not None:
        sxx = np.minimum(sxx, clip_above)

    return sxx


def my_FR(spikes: np.ndarray,
            duration: int,
            window_size: float,
            overlap: float) -> (np.ndarray, np.ndarray, int):
    """
    Compute the firing rate using a windowed moving average.

    Windows of width ``window_size`` are placed every
    ``window_size * (1 - overlap)`` seconds. The first window starts at 0 and
    the last one ends at or before ``duration``, so every window lies inside
    the recording. A spike at time ``s`` is counted in the window
    ``[lo, hi)`` when ``lo <= s < hi``.

    Parameters
    ----------
    spikes: numpy.ndarray
        The spike times (*not* Brian2 format, in -unitless- seconds); they do
        not need to be sorted
    duration: int
        The duration of the recording (in -unitless- seconds)
    window_size: float
        Width of the moving average window (in -unitless- seconds)
    overlap: float
        Desired overlap between the windows (percentage in [0., 1.))

    Returns
    -------
    t: numpy.ndarray
        Array of time values for the computed firing rate. These are the window centers.
    FR: numpy.ndarray
        Spikes per window (integer counts; divide by ``window_size`` for Hz)
    fs_n: int
        Sampling rate of the returned series, ``1 / step`` truncated to an int
    """
    spikes_sorted = np.sort(np.asarray(spikes, dtype=float).ravel())

    # Step between window centers; rounding (1 - overlap) strips float noise
    win_step = window_size * round(1. - overlap, 4)
    # The tolerance keeps float noise (e.g. 0.1 * 0.1 == 0.010000000000000002)
    # from truncating a 100 Hz rate down to 99
    fs_n = int(np.floor(1. / win_step + 1e-9))

    # First center is at the middle of the first window, last one no later
    # than the middle of a window ending exactly at `duration`
    half_win = window_size / 2
    c0 = half_win
    cN = duration - half_win

    # Number of centers in [c0, cN]; the tolerance keeps cN itself when the
    # span is an exact multiple of the step up to round-off. Computing the
    # count explicitly avoids float-step np.arange overshooting past cN.
    n_centers = max(int(np.floor((cN - c0) / win_step + 1e-9)) + 1, 0)
    centers = c0 + win_step * np.arange(n_centers)

    # On sorted spike times, searchsorted(..., side='left') gives the number of
    # spikes strictly below each edge, so hi - lo counts spikes in [lo_edge, hi_edge)
    lo = np.searchsorted(spikes_sorted, centers - half_win, side='left')
    hi = np.searchsorted(spikes_sorted, centers + half_win, side='left')
    counts = hi - lo

    # return centers, spike counts, and adjusted sampling rates per window
    return centers, counts, fs_n


def my_specgram(signal: np.ndarray,
                fs: int,
                window_width: int,
                window_overlap: int,
                k: int = 1,
                **kwargs: dict) -> (np.ndarray, np.ndarray, np.ndarray):
    """
    Compute a spectrogram of ``signal`` using a periodic Hann window.

    Segments are ``window_width`` samples long and consecutive segments share
    ``window_overlap`` samples. Each segment is mean-detrended and zero-padded to
    ``nfft = 2 ** (ceil(log2(window_width)) + k)`` points, i.e. the next power of
    two at or above ``window_width`` multiplied by ``2 ** k``.

    Parameters
    ----------
    signal: numpy.ndarray
        The input signal
    fs: int
        Sampling frequency of the input signal
    window_width: int
        Width of the Hann windows in samples
    window_overlap: int
        Overlap between Hann windows in samples
    k: int
        Extra power-of-two exponent applied to the FFT length (zero-padding factor 2**k)
    kwargs: dict
        Extra arguments forwarded to scipy.signal.spectrogram(); see the scipy documentation.

    Returns
    -------
    f: numpy.ndarray
        Frequencies of the first axis of the spectrogram
    t: numpy.ndarray
        Segment times of the second axis of the spectrogram
    Sxx: numpy.ndarray
        Spectrogram with axes [frequency, time]; with scipy's default
        ``mode='psd'`` this is a power spectral density
    """
    # (window_width - 1).bit_length() == ceil(log2(window_width)) for window_width >= 1
    fft_len = 2 ** ((window_width - 1).bit_length() + k)
    taper = sig.windows.hann(M=window_width, sym=False)

    freqs, times, spec = sig.spectrogram(
        x=signal,
        fs=fs,
        nfft=fft_len,
        detrend='constant',
        window=taper,
        noverlap=window_overlap,
        **kwargs,
    )
    return freqs, times, spec


def my_PSD(signal: np.ndarray,
                   fs: int,
                   window_width: int,
                   window_overlap: int,
                   k: int = 1,
                   **kwargs: dict) -> (np.ndarray, np.ndarray):
    """
    Estimate the Power Spectral Density (PSD) of ``signal`` with Welch's method.

    Segments of ``window_width`` samples, overlapping by ``window_overlap``
    samples, are tapered with a periodic Hann window and zero-padded to
    ``nfft = 2 ** (ceil(log2(window_width)) + k)`` points before averaging.

    Parameters
    ----------
    signal: numpy.ndarray
        The input signal
    fs: int
        Sampling frequency of the input signal
    window_width: int
        Width of the Hann windows in samples (also used as nperseg)
    window_overlap: int
        Overlap between Hann windows in samples
    k: int
        Extra power-of-two exponent applied to the FFT length (zero-padding factor 2**k)
    kwargs: dict
        Extra arguments forwarded to scipy.signal.welch(); see the scipy documentation.

    Returns
    -------
    f: numpy.ndarray
        Frequencies at which the PSD is evaluated
    Pxx: numpy.ndarray
        PSD of the input signal.
    """
    # (window_width - 1).bit_length() == ceil(log2(window_width)) for window_width >= 1
    exponent = (window_width - 1).bit_length() + k
    taper = sig.windows.hann(M=window_width, sym=False)

    freqs, density = sig.welch(
        x=signal,
        fs=fs,
        nfft=2 ** exponent,
        window=taper,
        nperseg=window_width,
        noverlap=window_overlap,
        **kwargs,
    )
    return freqs, density


def my_modulation_index(sig_phase: np.ndarray,
                        sig_amp: np.ndarray,
                        nbins: int=18) -> (float, float):
    """
    Computes the Modulation Index between two signals. The formalism used is the
    one provided in the following paper:

        # REF #
        Tort AB, Komorowski R, Eichenbaum H, Kopell N. Measuring phase-amplitude coupling between neuronal oscillations of different frequencies. J Neurophysiol. 2010 Aug;104(2):1195-210. doi: 10.1152/jn.00106.2010. Epub 2010 May 12. Erratum in: J Neurophysiol. 2010 Oct;104(4):2302. PubMed PMID: 20463205; PubMed Central PMCID: PMC2941206.
        # --- #

    Modulation Index (MI) is returned as a single number; the computation is based on the Kullback-Leibler distance. More details in the paper (p. 1196-7).

    Parameters
    ----------
    sig_phase: numpy.ndarray
        The phase signal xfp(t), in radians. Phases are wrapped into [-pi, pi),
        so +pi shares a bin with -pi; non-finite phases are ignored.
    sig_amp: numpy.ndarray
        The amplitude signal xfA(t); same number of samples as ``sig_phase``.
    nbins: int
        Number of equal-width phase bins covering one cycle

    Returns
    -------
    MI: float
        Modulation Index, as computed in the paper by Tort et al. (2010)
    dist_KL: float
        Kullback-Leibler distance.
    """
    sig_phase = np.asarray(sig_phase, dtype=float).ravel()
    sig_amp = np.asarray(sig_amp, dtype=float).ravel()

    # NaN/inf phases cannot be assigned to a bin; drop those samples
    finite = np.isfinite(sig_phase)
    sig_phase = sig_phase[finite]
    sig_amp = sig_amp[finite]

    # Make the bins: nbins equal-width bins spanning [-pi, pi)
    bin_edges = np.linspace(-np.pi, np.pi, nbins + 1)

    # Wrap into [-pi, pi) so that +pi (== -pi) and phases given in other
    # ranges (e.g. [0, 2*pi)) land in a valid bin
    wrapped = np.mod(sig_phase + np.pi, 2 * np.pi) - np.pi

    # np.digitize is 1-based (edges[i-1] <= x < edges[i] -> i): shift to
    # 0-based bin indices and clip to absorb round-off at the outer edges
    idx_bin = np.clip(np.digitize(wrapped, bin_edges) - 1, 0, nbins - 1)

    # Mean amplitude per phase bin; empty bins stay at 0
    amp_sum = np.bincount(idx_bin, weights=sig_amp, minlength=nbins)
    n_per_bin = np.bincount(idx_bin, minlength=nbins)
    bin_amp = np.zeros(nbins)
    np.divide(amp_sum, n_per_bin, out=bin_amp, where=n_per_bin > 0)

    # Hist. normalization step - get P(j) vals
    P_amp = bin_amp / np.sum(bin_amp)

    # Uniform reference distribution for the KL distance
    Q_amp = np.ones(nbins) / nbins

    # In the special case where observed probability in a bin is 0, this tweak
    # allows computing a meaningful KL distance nonetheless
    P_amp[P_amp == 0] = 1e-12

    # Kullback-Leibler distance
    dist_KL = np.sum(P_amp * np.log(P_amp / Q_amp))

    # Modulation Index: KL distance normalised by its maximum, log(nbins)
    MI = dist_KL / np.log(nbins)

    return MI, dist_KL


def bandpower(data, fs, band, window_sec=None, overlap=0.9, relative=False, return_PSD=False, **kwargs):
    """
    Compute the average power of the signal x in a specific frequency band.
    Code adapted from: https://raphaelvallat.com/bandpower.html

    The PSD is estimated with Welch's method (periodic Hann window, FFT length
    = next power of two at or above the window length, times 2) and integrated
    over ``band`` with Simpson's rule.

    Parameters
    ----------
    data: array_like
        Input signal in the time-domain. Multi-channel input is supported: the
        PSD is computed along the last axis and one band power per channel is
        returned.
    fs: float
        Sampling frequency of the data.
    band: list
        Lower and upper frequencies of the band of interest (both inclusive).
    window_sec: float
        Length of each window in seconds.
        If None, window_sec = 2 / band[0] (two cycles of the lowest frequency).
    overlap: float
        Fraction of each window shared with the next one (default 0.9).
    relative: boolean
        If True, return the relative power (= divided by the total power of the signal).
        If False (default), return the absolute power.
    return_PSD: boolean
        If True, also return the frequency vector and the PSD used.
    **kwargs
        Extra arguments forwarded to scipy.signal.welch().

    Return
    ------
    bp : float or numpy.ndarray
        Absolute or relative band power (one value per channel for
        multi-channel input).
    freqs, psd : numpy.ndarray
        Only when ``return_PSD`` is True.
    """
    from scipy.integrate import simpson

    band = np.asarray(band)
    low, high = band

    # Define window length
    if window_sec is not None:
        window_width = int(window_sec * fs)
    else:
        window_width = int((2 / low) * fs)

    # Calculate window overlap
    window_overlap = int(overlap * window_width)

    # Compute the modified periodogram (Welch)
    nfft = 2**((window_width - 1).bit_length() + 1)
    freqs, psd = sig.welch(x=data,
                           fs=fs,
                           nfft=nfft,
                           window=sig.windows.hann(M=window_width, sym=False),
                           nperseg=window_width,
                           noverlap=window_overlap,
                           **kwargs)

    # Drop singleton dimensions, e.g. data shaped (1, n_samples)
    if psd.size > 1:
        psd = psd.squeeze()

    # Frequency bins inside the band of interest
    idx_band = np.logical_and(freqs >= low, freqs <= high)

    # Integral approximation of the spectrum using Simpson's rule. `x` is
    # passed by keyword (keyword-only in recent SciPy) and the frequency axis
    # is the last one, so multi-channel PSDs are integrated per channel.
    bp = simpson(psd[..., idx_band], x=freqs[idx_band], axis=-1)

    if relative:
        bp = bp / simpson(psd, x=freqs, axis=-1)

    if return_PSD:
        return bp, freqs, psd

    return bp