import numpy as np
import nibabel as nib
import argparse
import csv
import os
from collections import defaultdict
from smoothed_fdr import calc_plateaus
from utils import *
from plotutils import *
def load_nii(filename):
    """Load a NIfTI image from ``filename``.

    Returns a ``(data, affine)`` tuple: the voxel array and the image's
    affine matrix.

    ``np.asanyarray(img.dataobj)`` is nibabel's documented replacement for
    the deprecated ``get_data()``. It keeps the on-disk dtype rather than
    forcing float64, as ``get_fdata()`` would. ``img.affine`` replaces the
    deprecated ``get_affine()``. Both old accessors are gone from current
    nibabel releases.
    """
    img = nib.load(filename)
    return np.asanyarray(img.dataobj), img.affine
def save_nii(filename, data, coords):
    """Write ``data`` to ``filename`` as a NIfTI-1 image.

    ``coords`` is passed to ``nib.Nifti1Image`` as the image's affine.
    """
    nib.save(nib.Nifti1Image(data, coords), filename)
def load_edges(filename):
    """Read an undirected edge list from a two-column CSV file.

    Each row ``x,y`` adds ``y`` to ``x``'s neighbour list and ``x`` to
    ``y``'s. Duplicate edges, in either direction, are recorded only once.
    Neighbours are kept in first-seen order. Columns beyond the second are
    ignored, and blank rows are skipped.

    Returns a ``defaultdict(list)`` mapping node index to neighbour list.
    """
    edges = defaultdict(list)
    # Text mode: csv.reader needs str rows. A binary ('rb') handle yields
    # bytes on Python 3 and csv raises _csv.Error.
    with open(filename, 'r') as f:
        for line in csv.reader(f):
            if not line:
                # csv yields [] for blank lines; indexing it would raise.
                continue
            x, y = int(line[0]), int(line[1])
            if y not in edges[x]:
                edges[x].append(y)
            if x not in edges[y]:
                edges[y].append(x)
    return edges
def load_shape_lookup(filename):
    """Read a volume shape and a flat-index -> voxel-coordinate lookup table.

    The first CSV row holds the integer dimensions of the volume. Every
    following non-blank row is ``index,i,j,k``.

    Returns ``(shape, lookup)``, where ``shape`` is a list of ints and
    ``lookup`` maps ``index`` to the tuple ``(i, j, k)``.
    """
    # Text mode so csv.reader receives str on Python 3 (bytes make it fail).
    with open(filename, 'r') as f:
        reader = csv.reader(f)
        # next(reader) works on Python 2 and 3; reader.next() is Python 2 only.
        shape = [int(x) for x in next(reader)]
        lookup = {}
        for line in reader:
            if not line:
                # Skip blank rows instead of failing on line[0].
                continue
            lookup[int(line[0])] = (int(line[1]), int(line[2]), int(line[3]))
    return shape, lookup
def load_weights(filename, data, lookup, fdr=None):
    """Scatter per-node values from a CSV file into the array ``data``.

    Values are read with ``np.loadtxt``. When ``fdr`` is given, they are
    first passed through ``calc_fdr(values, fdr)``. The k-th value is then
    stored at ``data[lookup[k]]``. ``data`` is modified in place and
    nothing is returned.
    """
    values = np.loadtxt(filename, delimiter=',')
    if fdr is not None:
        values = calc_fdr(values, fdr)
    for node, value in enumerate(values):
        data[lookup[node]] = value
def blob_fdr(weights_filename, posteriors_filename, data, lookup, edges, fdr):
    """Run FDR selection on plateau-averaged posteriors and scatter results.

    The weights are mapped through ``-log(1/w - 1)`` and the result is
    grouped with ``calc_plateaus``. The mean posterior of each plateau is
    then passed to ``calc_fdr``. Every plateau member's weight is replaced
    by that plateau's ``calc_fdr`` output. All per-node values are then
    written into ``data`` via ``lookup``, in place.

    Returns the plateaus produced by ``calc_plateaus``.
    NOTE(review): weights of exactly 0 or 1 give infinite betas here;
    confirm how calc_plateaus handles them.
    """
    weights = np.loadtxt(weights_filename, delimiter=',')
    betas = -np.log(1.0 / weights - 1.)
    posteriors = np.loadtxt(posteriors_filename, delimiter=',')
    plateaus = calc_plateaus(betas, edges=edges)

    # Only the member collection of each plateau is used here.
    plateau_means = np.array(
        [posteriors[list(members)].mean() for _, members in plateaus])
    discoveries = calc_fdr(plateau_means, fdr)

    for (_, members), verdict in zip(plateaus, discoveries):
        weights[list(members)] = verdict

    for node, value in enumerate(weights):
        data[lookup[node]] = value
    return plateaus
def thresholded_posterior_clusters(posteriors_filename, data, lookup, edges, fdr, mincluster):
    """Write the mean posterior of each large, significant cluster into ``data``.

    Posteriors below ``1 - fdr`` become 0 and remaining positive values
    become 1. This binary indicator is handed to ``calc_plateaus``, which
    presumably groups connected nodes of equal value (it is defined
    elsewhere; confirm).

    A plateau is kept only if it has at least ``mincluster`` members and
    its mean posterior ``m`` satisfies ``1 - m <= fdr``. For each kept
    plateau, a summary line is printed and every member's voxel
    (``lookup[i]``) in ``data`` is set to ``m``, in place. Returns None.
    """
    posteriors = np.loadtxt(posteriors_filename, delimiter=',')
    cutoff = 1.0 - fdr

    # Binary indicator: zero below the cutoff, one for remaining positives.
    indicator = np.where(posteriors < cutoff, 0.0, posteriors)
    indicator[indicator > 0] = 1

    plateaus = calc_plateaus(indicator, edges=edges)
    for _, members in plateaus:
        size = len(members)
        if size < mincluster:
            continue
        avg = posteriors[list(members)].mean()
        # Drop clusters whose average posterior misses the target level
        # (this also discards the below-cutoff "zero" plateaus).
        if 1 - avg > fdr:
            continue
        print('Cluster size: {0}\tmean:{1}'.format(size, avg))
        for node in members:
            data[lookup[node]] = avg
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Loads a smoothed FDR result for an fMRI image and post-processes it.')
    parser.add_argument('indir', help='The directory where everything is stored.')
    parser.add_argument('--verbose', type=int, default=1, help='Print detailed progress information to the console. 0=none, 1=high-level only, 2=all details.')
    parser.add_argument('--missingval', type=float, default=0, help='The value used to signify a missing data point in the array. Typically this is zero.')
    parser.add_argument('--fdr_level', type=float, default=0.1, help='The false discovery rate level to use when reporting discoveries.')
    parser.add_argument('--mincluster', type=int, default=30, help='The minimum size a cluster must be to not be filtered out.')
    # action='store_true' already defaults to False.
    parser.add_argument('--plot', action='store_true', help='If specified, will generate 2d slice plots.')
    args = parser.parse_args()

    # Guarantee a trailing slash so file names can be appended directly.
    expdir = args.indir + ('' if args.indir.endswith('/') else '/')

    if args.verbose:
        print('Loading raw data from {0}'.format(expdir+'data.nii.gz'))
    data, coords = load_nii(expdir+'data.nii.gz')

    if args.verbose:
        print('Loading edges from {0}'.format(expdir+'edges.csv'))
    edges = load_edges(expdir+'edges.csv')

    if args.verbose:
        print('Loading shape and lookup data from {0}'.format(expdir+'lookup.csv'))
    shape, lookup = load_shape_lookup(expdir+'lookup.csv')

    # Output volumes; filled node-by-node through `lookup`.
    smoothdata = np.zeros(shape)
    smoothposts = np.zeros(shape)
    smoothdiscs = np.zeros(shape)

    if args.verbose:
        print('Original data shape: {0} Smoothed data shape: {1} (should be the same)'.format(data.shape, smoothdata.shape))

    if args.verbose:
        print('Loading smoothed weights from {0}'.format(expdir + 'weights.csv'))
    load_weights(expdir + 'weights.csv', smoothdata, lookup)

    if args.verbose:
        print('Loading posteriors from {0}'.format(expdir + 'posteriors.csv'))
    load_weights(expdir + 'posteriors.csv', smoothposts, lookup)

    if args.verbose:
        print('Filtering down to a local FDR threshold of {0}'.format(args.fdr_level))
        print('smoothdiscs size: {0}'.format(smoothdiscs.shape))
    thresholded_posterior_clusters(expdir + 'posteriors.csv', smoothdiscs, lookup, edges, args.fdr_level, args.mincluster)
    # blob_fdr() is an alternative that runs calc_fdr on posteriors averaged
    # over plateaus of the smoothed weights instead.

    if args.verbose:
        print('Saving .nii.gz versions')
    save_nii(expdir + 'weights.nii.gz', smoothdata, coords)
    save_nii(expdir + 'posteriors.nii.gz', smoothposts, coords)
    save_nii(expdir + 'thresholded_posterior_clusters.nii.gz', smoothdiscs, coords)

    if args.plot:
        if args.verbose:
            print('Plotting to {0}'.format(expdir + 'img/'))

        # Build the missing-value mask once, BEFORE writing NaN into `data`.
        # NaN never compares equal to missingval, so re-evaluating
        # `data == args.missingval` afterwards would give an empty mask and
        # leave the smoothed volumes unmasked.
        missing = (data == args.missingval)
        # NaN cannot be stored in an integer array (NIfTI data is often
        # integer-typed), so mask a floating-point copy of the raw data.
        data = data.astype(float)
        data[missing] = np.nan
        smoothdata[missing] = np.nan
        smoothposts[missing] = np.nan

        for i in range(3):
            # os.makedirs also creates the parent img/ directory if needed.
            axisdir = expdir + 'img/' + str(i)
            if not os.path.exists(axisdir):
                os.makedirs(axisdir)
            # Plot the 3D image as 2D slices taken along the i'th axis.
            plot_3d(axisdir + '/{0:04d}.pdf', data, weights=smoothdata, posteriors=smoothposts, discoveries=smoothdiscs, axis=i)