https://github.com/samsydco/HBN
Raw File
Tip revision: 278127d07b721c73679c11d0d1836631df778323 authored by samsydco on 16 May 2022, 17:49:03 UTC
Update README
Tip revision: 278127d
8_HMM_stats.py
#!/usr/bin/env python3

'''
# In ROIs with similar k's and high LL's in both or one age group:

1) Get patterns by HMM([Young_train, Old_train])
2) LLs = find_events(Young_test), find_events(Old_test)
3) Pick K by maximizing avg LLs
4A) LL diff at this K
4B) AUC diff at this K
Permutations only for 4
Perm: find_events(Perm1), find_events(Perm2)
All of this train/test set loop
'''

import os
import glob
import tqdm
import numpy as np
import deepdish as dd
import brainiak.eventseg.event
from HMM_settings import *

bins = np.arange(nbinseq)
nbins = len(bins)
bin_tmp = [0,4]
task = 'DM'
nTR_ = nTR[0]
nshuff2perm=1000


for seed in tqdm.tqdm(seeds):
	seedsavedir = HMMsavedir+seed
	if not os.path.exists(seedsavedir): os.makedirs(seedsavedir)
	for roi_short in tqdm.tqdm(ROIl):
		if os.path.exists(seedsavedir+'/'+roi_short+'.h5'):
			roidict = dd.io.load(seedsavedir+'/'+roi_short+'.h5')
			nshuff_ = len(roidict['ll_diff'][1:])
			p_ll_ = np.sum(abs(roidict['ll_diff'][0])<abs(roidict['ll_diff'][1:]) )/nshuff_
			p_auc = np.sum(abs(roidict['auc_diff'][0])<abs(roidict['auc_diff'][1:]) )/nshuff_
			nshuff2 = nshuff2perm + nshuff_
			nshuff_all = 0
			p_ll_all = p_aucall = 1
			if os.path.exists(pvals_file):
				nshuff_all = len(dd.io.load(pvals_file, '/roidict/'+roi_short+'/auc_diff/shuff'))
				p_ll_all = dd.io.load(pvals_file, '/roidict/'+roi_short+'/ll_diff/p')
				p_aucall = dd.io.load(pvals_file, '/roidict/'+roi_short+'/auc_diff/p')
				if nshuff_all > nshuff2: nshuff2 = nshuff_all
			shuffl = np.arange(nshuff_+1,nshuff2+1)
		else:
			roidict = {'bin_'+str(b):{} for b in bins}
			for b in bins:
				roidict['bin_'+str(b)]['D'] =  dd.io.load(roidir+seed+'/'+roi_short+'.h5', '/DM/bin_'+str(b)+'/D')
			nshuff2 = nshuff_ = nshuff
			shuffl = np.arange(nshuff2+1)
			p_ll_ = p_auc = 0 #Default = Do the test
		if ((p_ll_<0.05 or p_auc<0.05) and nshuff_ < nshuff2perm) \
		or (p_ll_== 0 or p_auc == 0) or (nshuff_ < nshuff_all and (p_ll_all < 0.05 or p_aucall < 0.05)):
			D = [roidict['bin_'+str(b)]['D'] for b in bins]
			if not os.path.exists(seedsavedir+'/'+roi_short+'.h5'):
				tune_ll = np.zeros((nbins,nsplit,len(k_list)))
				for split,Ls in enumerate(kf.split(np.arange(nsub),y)):
					Dtrain = [np.mean(d[Ls[0]],axis=0).T for d in [D[bi] for bi in bin_tmp]]
					Dtest  = [np.mean(d[Ls[1]],axis=0).T for d in D]
					for ki,k in enumerate(k_list):
						hmm = brainiak.eventseg.event.EventSegment(n_events=k)
						hmm.fit(Dtrain)
						for b in bins:
							_, tune_ll[b,split,ki] = hmm.find_events(Dtest[b])
				#Now calculating best_k from average ll:
				best_k = k_list[np.argmax(np.mean(np.mean(tune_ll,0),0))]
				roidict['best_k'] = best_k
				tune_seg = np.zeros((len(shuffl),nbins,nsplit,nTR_,best_k))
				tune_ll = np.zeros((len(shuffl),nbins,nsplit))
			else:
				best_k = roidict['best_k']
				tune_seg = np.append(roidict['tune_seg_perm'], np.zeros((len(shuffl),nbins,nsplit,nTR_,best_k)),axis=0)
				tune_ll = np.append(roidict['tune_ll_perm'], np.zeros((len(shuffl),nbins,nsplit)),axis=0)
			for split,Ls in enumerate(kf.split(np.arange(nsub),y)):
				Dtrain = [np.mean(d[Ls[0]],axis=0).T for d in [D[bi] for bi in bin_tmp]]
				Dtest_all  = np.concatenate([d[Ls[1]] for d in D])
				nsubLO = len(Ls[1])
				subl = np.arange(nsubLO*nbins) # subject list to be permuted!
				hmm = brainiak.eventseg.event.EventSegment(n_events=best_k)
				hmm.fit(Dtrain)
				for shuff in shuffl:		
					for bi in range(nbins):
						if shuff != 0:
							# RANDOMIZE
							subl = np.random.permutation(nsubLO*nbins)
						idx = np.arange(bi*nsubLO,(bi+1)*nsubLO)
						Dtest = np.mean(Dtest_all[subl[idx]],axis=0).T
						tune_seg[shuff,bi,split], tune_ll[shuff,bi,split] = hmm.find_events(Dtest)
			roidict['tune_ll_perm'] = tune_ll
			roidict['tune_seg_perm'] = tune_seg
	
			roidict['E_k'] = np.zeros((nshuff2+1,nbins,nTR_))
			roidict['auc'] = np.zeros((nshuff2+1,nbins))
			roidict['auc_diff'] = np.zeros(nshuff2+1)
			roidict['ll_diff'] = np.zeros(nshuff2+1)
			for shuff in range(nshuff2+1):
				for bi in range(nbins):
					roidict['E_k'][shuff,bi] = np.dot(np.mean(tune_seg[shuff,bi],axis=0), np.arange(best_k)+1)
					roidict['auc'][shuff,bi] = roidict['E_k'][shuff,bi].sum()
				roidict['auc_diff'][shuff] = ((roidict['auc'][shuff,-1] -roidict['auc'][shuff,0])/(best_k-1))*TR
				roidict['ll_diff'][shuff] = np.mean(tune_ll[shuff,-1] - tune_ll[shuff,0])/nTR_
			dd.io.save(seedsavedir+'/'+roi_short+'.h5',roidict)
				
			
		
		
	
	
back to top