# Source: https://github.com/samsydco/HBN
# Tip revision: 7161591e3032afc987821f6fd80181a20976aec2, authored by Samantha Cohen on 16 August 2022, 18:12:06 UTC
# ("Merge branch 'master' of https://github.com/samsydco/HBN")
# File: 6_ISC_test.py
#!/usr/bin/env python3

# 1) Compute many leave-one-out values (hundreds?), by randomly selecting subjects and correlating against the N-1 other subjects.
# 2) Compute many split-half values, by randomly selecting an N/2 split of subjects and correlating the mean timecourses of the halves.
# 3) Compute many pairwise values, by randomly selecting two subjects and correlating their timecourses.
# In all cases - time how long it takes to collect these values.
# then, compute f using subsets of each of these values

import time
import h5py
import tqdm
import numpy as np
import deepdish as dd
from scipy.stats import zscore
from scipy.spatial.distance import squareform
from random import shuffle
import matplotlib.pyplot as plt
from settings import *
# Only using 233 subj
# Subject HDF5 file paths saved by the phenotype preprocessing step.
# metaphenopath / ISCpath / figurepath come from the `settings` star import above.
subord = dd.io.load(metaphenopath+'pheno_2019-05-28.h5',['/subs'])[0]
# Pre-computed ISC data file from an earlier pipeline run.
ISCf = ISCpath+'old_ISC/ISC_2019-05-28.h5'
n_subj = len(subord)
n_vox = 5  # number of randomly-sampled test vertices per task
ISCversions = ['Loo','SH','Pair']  # leave-one-out, split-half, pairwise ISC estimators

# Some math to convert between correlation values
def corr_convert(r,N,corrtype='SH'):
	"""Convert an observed ISC correlation into the latent signal strength f,
	then re-express f as the expected pairwise, split-half, and leave-one-out
	correlations for a group of size N.

	Parameters
	----------
	r : float or np.ndarray
		Observed ISC value(s), |r| < 1.
	N : int
		Number of subjects the observed r was computed from.
	corrtype : {'SH','Pair','Loo'}
		Which ISC estimator produced r; any value other than 'SH'/'Pair'
		falls through to the leave-one-out inversion.

	Returns
	-------
	r_pw, r_sh, r_loo
		Expected pairwise, split-half, and leave-one-out correlations
		(same shape as r).
	"""
	if corrtype == 'SH':
		f = 2*r/(N*(1-r))
	elif corrtype == 'Pair':
		f = r/(1-r)
	else:
		# Invert the leave-one-out relation (quadratic in f); positive root.
		# Fix: the original forced np.power(r,4,dtype=np.float16), which
		# underflows/loses precision for small r (float16 floor ~6e-8) --
		# compute the radicand at full precision instead.
		f = (N*np.square(r)+np.sqrt((N**2)*np.power(r,4)+4*np.square(r)*(N-1)*(1-np.square(r))))/(2*(N-1)*(1-np.square(r)))
	r_pw = f/(f+1)
	r_sh = (N*f)/(N*f+2)
	r_loo = (np.sqrt(N-1)*f)/(np.sqrt(f+1)*np.sqrt((N-1)*f+1))
	return r_pw,r_sh,r_loo

for task in ['DM','TP']:
	print(task)
	# Vertex indices with non-NaN data in the first subject (L and R hemispheres concatenated)
	non_nan_verts = np.where(~np.isnan(np.concatenate([dd.io.load(subord[0],['/'+task+'/L'])[0], dd.io.load(subord[0],['/'+task+'/R'])[0]], axis=0))[:,0])[0]
	dictall = {k:{'ISC':None,'Time':None} for k in ISCversions}
	_,n_time = dd.io.load(subord[0],['/'+task+'/L'])[0].shape
	# Randomly sample n_vox test vertices from the valid set
	dictall['verts'] = non_nan_verts[np.random.choice(len(non_nan_verts),n_vox,replace=False)]
	# D: (vertex, time, subject) data for the sampled vertices only
	D = np.empty((n_vox,n_time,n_subj),dtype='float16')
	# NOTE(review): handle opened without a mode and never closed -- relies on
	# the old h5py default mode and garbage collection; confirm before upgrading h5py.
	keys = list(h5py.File(ISCf)[task]['data'].keys())
	for key in keys:
		# Each key appears to hold a 250-timepoint chunk, keyed by chunk index -- TODO confirm
		D[:,0+250*int(key):250+250*int(key),:] = dd.io.load(ISCf,['/'+task+'/data/'+key], sel=dd.aslice[dictall['verts'],:,:])[0]
	# Loo
	print('Leave one out...')
	# One leave-one-out ISC value per (vertex, left-out subject)
	dictall['Loo']['ISC'] = np.zeros((n_vox,n_subj),dtype='float16')
	dictall['Loo']['Time'] = []
	# Loop across choice of leave-one-out subject
	for loo_subj in tqdm.tqdm(range(n_subj)):
		t = time.process_time()
		group = np.zeros((n_vox,n_time),dtype='float16')
		# Per-(vertex,timepoint) count of contributing subjects: start at N-1
		# and decrement for every NaN so the mean ignores missing data.
		groupn = np.ones((n_vox,n_time),dtype='int')*n_subj-1
		for i in range(n_subj):
			if i != loo_subj:
				# nansum treats a NaN entry as 0 here; groupn compensates below
				group = np.nansum(np.stack((group,D[:,:,i])),axis=0)
				nanverts = np.argwhere(np.isnan(D[:,:,i]))
				groupn[nanverts[:, 0],nanverts[:,1]] = groupn[nanverts[:,0],nanverts[:,1]]-1
		# z-scored group-average timecourse (excluding the left-out subject)
		group = zscore(group/groupn,axis=1)
		subj = zscore(D[:, :, loo_subj],axis=1)
		# Pearson r via dot product of z-scored series
		dictall['Loo']['ISC'][:,loo_subj] = np.sum(np.multiply(group,subj),axis=1)/(n_time-1)
		dictall['Loo']['Time'].append(time.process_time() - t)
	# SH
	print('Split Halves...')
	# One split-half ISC value per (vertex, random split); n_subj random splits
	dictall['SH']['ISC'] = np.zeros((n_vox,n_subj),dtype='float16')
	dictall['SH']['Time'] = []
	subjl = np.arange(n_subj)
	for sh in tqdm.tqdm(range(n_subj)):
		t = time.process_time()
		# fresh random half/half partition of subjects each iteration
		shuffle(subjl)
		groups = np.zeros((2,n_vox,n_time),dtype='float16')
		for h in [0,1]:
			group = np.zeros((n_vox,n_time),dtype='float16')
			# NOTE(review): count starts at n_subj although each half only has
			# n_subj//2 subjects; the subsequent zscore absorbs a constant
			# scale, but the NaN decrement is then relative to the wrong base --
			# presumably should be n_subj//2; confirm.
			groupn = np.ones((n_vox,n_time),dtype='int')*n_subj
			for i in subjl[0+n_subj//2*h:n_subj//2+n_subj//2*h]:
				group = np.nansum(np.stack((group,D[:,:,i])),axis=0)
				nanverts = np.argwhere(np.isnan(D[:,:,i]))
				groupn[nanverts[:, 0],nanverts[:,1]] = groupn[nanverts[:,0],nanverts[:,1]]-1
			groups[h] = zscore(group/groupn,axis=1)
		# correlation between the two half-group mean timecourses
		dictall['SH']['ISC'][:,sh] = np.sum(np.multiply(groups[0],groups[1]),axis=1)/(n_time-1)
		dictall['SH']['Time'].append(time.process_time() - t)
	# Pairwise
	print('Pairwise...')
	dictall['Pair']['Time'] = []
	# Time pairwise ISC for growing subject counts (first i+1 subjects);
	# only the final all-subject iteration's correlations are kept below.
	for i in range(n_subj):
		voxel_iscs = []
		t = time.process_time()
		for v in np.arange(n_vox):
			voxel_data = D[v, :, :i+1].T
			# Correlation matrix for all pairs of subjects (triangle)
			iscs = squareform(np.corrcoef(voxel_data), checks=False)
			voxel_iscs.append(iscs)
		dictall['Pair']['Time'].append(time.process_time() - t)
	# voxel_iscs survives from the last loop iteration (all n_subj subjects),
	# giving an (n_pairs, n_vox) array of pairwise correlations.
	dictall['Pair']['ISC'] = np.column_stack(voxel_iscs)

	n_it = 10  # random subsets drawn per subset size s
	for g in ISCversions:
		i = dictall[g]
		i['TimeCum'] = []  # cumulative compute time for the first s estimates
		i['ISCCum'] = np.zeros((n_subj-1,n_vox,n_it))  # mean ISC over random subsets of size s
		i['f'] = np.zeros((n_subj-1,n_vox,n_it))  # ISC converted to signal strength f
		for s in np.arange(2,n_subj+1):
			i['TimeCum'].append(np.sum(i['Time'][:s]))
			for it in range(n_it):
				if g != 'Pair':
					# average s randomly chosen Loo/SH estimates
					randsubjs = np.random.choice(n_subj,s,replace=False)
					i['ISCCum'][s-2,:,it] = np.nanmean(i['ISC'][:,randsubjs],axis=1)
				else:
					# average s*(s-1)/2 of the n_subj*(n_subj-1)/2 pairwise values
					randsubjs = np.random.choice(int((n_subj*n_subj-n_subj)/2),int((s*s-s)/2),replace=False)
					i['ISCCum'][s-2,:,it] = np.nanmean(i['ISC'][randsubjs,:],axis=0)
			# NOTE(review): conversion uses the full group size N = n_subj,
			# not the subset size s -- confirm this is intended.
			N = n_subj
			r = i['ISCCum'][s-2]
			# r -> f conversions; same formulas as corr_convert above
			if g == 'Pair':
				i['f'][s-2] = r/(1-r)
			if g == 'SH':
				i['f'][s-2] = 2*r/(N*(1-r))
			if g == 'Loo':
				# NOTE(review): np.power(...,dtype=np.float16) loses precision
				# for small r -- see corr_convert
				i['f'][s-2] = (N*np.square(r)+ \
				 np.sqrt((N**2)*np.power(r,4,dtype=np.float16)+4*np.square(r)*(N-1)*(1-np.square(r))))/(2*(N-1)*(1-np.square(r)))
		
	figsubj = 200  # number of subset sizes shown on the x-axis
	fig = plt.figure()
	# One row per sampled vertex: f accuracy (deviation from final f) vs compute time
	for v in range(n_vox):
		ax = fig.add_subplot(n_vox,1,v+1)
		if v == 0:
			ax.set_title(task)
		for g in ISCversions:
			i = dictall[g]
			# plot Time vs Accuracy:
			final_f = np.mean(i['f'][-1,v,:])  # converged estimate (all subjects)
			y = np.mean(i['f'][:figsubj,v,:]-final_f,axis=1)  # mean deviation across iterations
			error = np.max(i['f'][:figsubj,v,:]-final_f,axis=1)  # max deviation as error band
			ax.plot(i['TimeCum'][:figsubj],y,label=g)
			ax.fill_between(i['TimeCum'][:figsubj], y-error, y+error,alpha=0.2)
			ax.set_ylabel('f acc\nvert =\n'+str(dictall['verts'][v]),size=7)
			#if g == 'Pair':
			#	ax.set_ylim(y[0]-error[0],y[0]+error[0])
		# x-range: fastest estimator's min to slowest estimator's max
		ax.set_xlim(min(dictall['Pair']['TimeCum'][:figsubj]),max(dictall['SH']['TimeCum'][:figsubj]))
		if v == n_vox-1:
			ax.legend(bbox_to_anchor=(0.7,1,0.5,3.5))
			ax.set_xlabel('Time [s]')
	fig.savefig(figurepath+'TimeVsFacc_'+task+'_'+str(figsubj)+'.png', bbox_inches = "tight",dpi=300)
		
	# Subjects-vs-f figures at a few zoom levels (25, 50, all subjects)
	for figsubj in [25,50,n_subj-1]:
		fig = plt.figure()
		for v in range(n_vox):
			ax0 = fig.add_subplot(n_vox+1,1,v+1)
			if v == 0:
				ax0.set_title(task)
			for g in ISCversions:
				i = dictall[g]
				# plot subj vs f for vox:
				ax0.plot(np.arange(1,figsubj+1),np.mean(i['f'][:figsubj,v,:],axis=1))
				# consensus "true" f: average of the three estimators at full N
				final_f = np.mean((dictall['SH']['f'][-1,v,:]+dictall['Loo']['f'][-1,v,:]+dictall['Pair']['f'][-1,v,:])/3)
				ax0.plot([figsubj-1,figsubj],[final_f,final_f],'k-')
				ax0.set_ylabel('f\nvert =\n'+str(dictall['verts'][v]),size=7)
				ax0.set_xlim(2,figsubj)
				if v == n_vox-1:
					# bottom panel: cumulative compute time per estimator
					ax1 = fig.add_subplot(n_vox+1,1,n_vox+1)
					ax1.plot(np.arange(1,figsubj+1),i['TimeCum'][:figsubj],label=g) # plot subj vs compute time for vox
					ax1.set_xlim(2,figsubj)
		ax1.legend(bbox_to_anchor=(0.7,1,0.5,3.5))
		ax1.set_xlabel('Subjects')
		ax1.set_ylabel('Time [s]')
		fig.savefig(figurepath+'FvsCompT_'+task+'_'+str(figsubj)+'.png', bbox_inches = "tight",dpi=300)
		
	# Save everything computed for this task to ISC_test.h5.
	# Bug fix: the original evaluated `type(i) == dict` once, BEFORE the loop,
	# on the stale `i` left over from the plotting code, and its dangling
	# for/else re-created a dataset under an already-used name. The dict/array
	# decision must be made per item inside the loop. Mode 'a' is passed
	# explicitly (modern h5py requires a mode; 'a' matches the old default).
	with h5py.File(ISCpath+'ISC_test.h5','a') as hf:
		grp = hf.create_group(task)
		for g,i in dictall.items():
			if isinstance(i, dict):
				# nested dict ('Loo'/'SH'/'Pair') -> subgroup of datasets
				ds = grp.create_group(g)
				for gi,ii in i.items():
					ds.create_dataset(gi,data=ii)
			else:
				# flat entry (e.g. 'verts') -> plain dataset
				grp.create_dataset(g,data=i)
				

	
	
	
# (end of scraped page)