https://github.com/deniclab/pyto_segmenter
Tip revision: 54bd5b9c2b90300a88976936a40f18779db0bbf7 authored by nrweir on 30 October 2017, 20:35:37 UTC
fixing log issues with CellSegment.py
fixing log issues with CellSegment.py
Tip revision: 54bd5b9
CellSegment.py
"""Classes and methods for segmenting cells based on cytosolic fluorescence."""
# IMPORT DEPENDENCIES
import os
import sys
import pickle
import time
import pandas as pd
import numpy as np
from skimage import io
from skimage.morphology import watershed
from scipy.ndimage.filters import gaussian_filter, maximum_filter
from scipy.ndimage.morphology import generate_binary_structure, binary_erosion
from scipy.ndimage.morphology import binary_dilation, binary_fill_holes
from scipy.ndimage.morphology import distance_transform_edt
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')
class CellSegmentObj:
    """Segmentation data from a cytosolic fluorescence image.

    All image stacks are stored as uint16 arrays (they are cast in
    ``__init__``).

    Attributes:
        f_directory (str): directory in which output sub-directories are
            created.
        filename (str): lower-cased base filename of the original raw
            fluorescence image.
        raw_img (ndarray): 3d array containing pixel intensities from
            the raw fluorescence image.
        slices: the number of z-slices in the raw image.
        height: the height of the raw image.
        width: the width of the raw image.
        gaussian_img (ndarray): gaussian-filtered fluorescence image.
        threshold (int): the cutoff used for binary thresholding of
            the gaussian image.
        threshold_img (ndarray): the output from applying the threshold
            to the gaussian image ndarray.
        filled_img (ndarray): the product of using three-dimensional
            binary hole filling (as implemented in
            scipy.ndimage.morphology) on the thresholded image.
        dist_map (ndarray): a 3D euclidean distance transform of the
            filled image.
        smooth_dist_map (ndarray): gaussian smoothed distance map.
        maxima (ndarray): an array indicating the positions of local
            maxima in the smoothed distance map.
        labs (ndarray): an int array with each local maximum in maxima
            assigned a unique integer value, and non-maxima assigned 0.
            required for watershed implementation at the next step.
        watershed_output (ndarray): a 3D array in which each object is
            assigned a unique integer value.
        filled_cells (ndarray): the product of performing 2D hole-filling
            object-by-object on the watershed_output array (holes are only
            filled if they are completely surrounded by the same cell;
            done to eliminate "tunnels" that can arise from vacuoles).
        final_cells (ndarray): the product of cleaning up the filled_cells
            segmented image using a voting filter. see the CellSegmenter
            reassign_pixels_2d method for details.
        obj_nums (list): the label of every segmented object.
        volumes (dict): maps object label to object volume.
        volumes_flag (str): unit of the values in volumes; 'pixels' until
            convert_volumes is called, 'femtoliters' afterwards.
        pdout (list): names of the attributes exported by to_pandas.
    """

    def __init__(self, f_directory, filename, raw_img, gaussian_img, threshold,
                 threshold_img, filled_img, dist_map, smooth_dist_map,
                 maxima, labs, watershed_output, filled_cells, final_cells,
                 obj_nums, volumes):
        """Initialize the CellSegmentObject with segmentation data."""
        print('creating CellSegmentObject...')
        self.f_directory = f_directory
        self.filename = os.path.basename(filename).lower()
        self.raw_img = raw_img.astype('uint16')
        self.gaussian_img = gaussian_img.astype('uint16')
        self.threshold = int(threshold)
        self.threshold_img = threshold_img.astype('uint16')
        self.filled_img = filled_img.astype('uint16')
        self.dist_map = dist_map.astype('uint16')
        self.smooth_dist_map = smooth_dist_map.astype('uint16')
        self.maxima = maxima.astype('uint16')
        self.labs = labs.astype('uint16')
        self.watershed_output = watershed_output.astype('uint16')
        self.filled_cells = filled_cells.astype('uint16')
        self.final_cells = final_cells.astype('uint16')
        self.slices = int(self.raw_img.shape[0])
        self.height = int(self.raw_img.shape[1])
        self.width = int(self.raw_img.shape[2])
        # accept any sequence of labels, not only an ndarray
        self.obj_nums = np.asarray(obj_nums).tolist()
        self.volumes = volumes
        self.volumes_flag = 'pixels'
        self.pdout = ['volumes']

    def __repr__(self):
        return 'CellSegmentObj ' + self.filename

    ## PLOTTING METHODS ##
    # Each plot_* method draws one stored stack with plot_stack and, when
    # display is True, calls plt.show() afterwards.

    def plot_raw_img(self, display=False):
        """Plot the raw image stack in grayscale."""
        self.plot_stack(self.raw_img, colormap='gray')
        if display:
            plt.show()

    def plot_gaussian_img(self, display=False):
        """Plot the gaussian-filtered stack in grayscale."""
        self.plot_stack(self.gaussian_img, colormap='gray')
        if display:
            plt.show()

    def plot_threshold_img(self, display=False):
        """Plot the thresholded stack in grayscale."""
        self.plot_stack(self.threshold_img, colormap='gray')
        if display:
            plt.show()

    def plot_filled_img(self, display=False):
        """Plot the hole-filled thresholded stack in grayscale."""
        self.plot_stack(self.filled_img, colormap='gray')
        if display:
            plt.show()

    def plot_dist_map(self, display=False):
        """Plot the distance map."""
        self.plot_stack(self.dist_map)
        if display:
            plt.show()

    def plot_smooth_dist_map(self, display=False):
        """Plot the smoothed distance map."""
        self.plot_stack(self.smooth_dist_map)
        if display:
            plt.show()

    def plot_maxima(self, display=False):
        """Plot the maxima overlaid on the smoothed distance map."""
        # dilate within each slice (1x5x5) so single-pixel maxima are visible
        vis_maxima = binary_dilation(self.maxima,
                                     structure=np.ones(shape=(1, 5, 5)))
        masked_maxima = np.ma.masked_where(vis_maxima == 0, vis_maxima)
        self.plot_maxima_stack(masked_maxima, self.smooth_dist_map)
        if display:
            plt.show()

    def plot_watershed(self, display=False):
        """Plot the raw watershed output."""
        self.plot_stack(self.watershed_output)
        if display:
            plt.show()

    def plot_filled_cells(self, display=False):
        """Plot the 2D hole-filled cells."""
        self.plot_stack(self.filled_cells)
        if display:
            plt.show()

    def plot_final_cells(self, display=False):
        """Plot the final cleaned-up cells."""
        self.plot_stack(self.final_cells)
        if display:
            plt.show()

    # OUTPUT METHODS
    # None of the output methods change the process working directory:
    # files are written through explicit paths so that batch callers can
    # keep reading inputs relative to their own working directory.

    def _stem(self):
        """Return filename truncated at the first occurrence of '.tif'."""
        cut = self.filename.find('.tif')
        if cut == -1:
            # not a tif-style name: fall back to dropping the extension
            return os.path.splitext(self.filename)[0]
        return self.filename[:cut]

    def _output_dir(self, output_dir=None, announce=True):
        """Resolve the output directory, creating it if needed.

        args:
            output_dir: target directory; a relative path is taken
                relative to f_directory. defaults to a directory named
                after the image inside f_directory.
            announce: print a message when the directory is created.
        returns: the absolute path of the (now existing) directory.
        """
        if output_dir is None:
            output_dir = self._stem()
        # os.path.join leaves an absolute output_dir untouched
        output_dir = os.path.abspath(os.path.join(self.f_directory,
                                                  output_dir))
        if not os.path.isdir(output_dir):
            if announce:
                print('creating output directory...')
            os.makedirs(output_dir)
        return output_dir

    def to_csv(self, output_dir=None):
        """Write the to_pandas() table to <output_dir>/<image name>.csv."""
        output_dir = self._output_dir(output_dir)
        for_csv = self.to_pandas()
        for_csv.to_csv(path_or_buf=os.path.join(output_dir,
                                                self._stem() + '.csv'),
                       index=True, header=True)

    def output_image(self, imageattr, output_dir=None):
        """Write one image attribute to a tif file.

        args:
            imageattr: name of the attribute to write (e.g. 'raw_img').
                the output file is named imageattr + filename.
            output_dir: see _output_dir.
        """
        output_dir = self._output_dir(output_dir)
        print('writing image' + str(imageattr))
        io.imsave(os.path.join(output_dir, str(imageattr) + self.filename),
                  getattr(self, str(imageattr)))

    def output_all_images(self, output_dir=None):
        '''Write all images to a new subdirectory.

        Write all images associated with the CellSegmentObj to a new
        directory. Name that directory according to the filename of the
        initial image that the object was derived from. By default this new
        directory is a subdirectory of f_directory.
        '''
        output_dir = self._output_dir(output_dir)
        print('writing images...')
        outputs = (
            ('raw_', self.raw_img),
            ('gaussian_', self.gaussian_img),
            ('threshold_', self.threshold_img),
            ('filled_threshold_', self.filled_img),
            ('dist_', self.dist_map),
            ('smooth_dist_', self.smooth_dist_map),
            ('maxima_', self.maxima),
            ('wshed_', self.watershed_output),
            ('filled_cells_', self.filled_cells),
            ('final_cells_', self.final_cells),
        )
        for prefix, img in outputs:
            io.imsave(os.path.join(output_dir, prefix + self.filename), img)

    def output_plots(self):
        '''Write PDFs of slice-by-slice plots.

        Output: PDF plots of each image within CellSegmentObj in a directory
        named for the original filename they were generated from. Plots are
        generated using the plot_stack method and plotting methods defined
        here. Each figure is closed once it has been saved.
        '''
        output_dir = self._output_dir()
        stem = self._stem()
        print('saving plots...')
        plots = (
            (self.plot_raw_img, 'praw_'),
            (self.plot_gaussian_img, 'pgaussian_'),
            (self.plot_threshold_img, 'pthreshold_'),
            (self.plot_filled_img, 'pfilled_'),
            (self.plot_dist_map, 'pdist_'),
            (self.plot_smooth_dist_map, 'psmooth_dist_'),
            (self.plot_maxima, 'pmaxima_'),
            (self.plot_watershed, 'pwshed_'),
            (self.plot_filled_cells, 'pfilled_cells_'),
            (self.plot_final_cells, 'pfinal_cells_'),
        )
        for draw, prefix in plots:
            draw()
            plt.savefig(os.path.join(output_dir, prefix + stem + '.pdf'))
            # release the figure; otherwise 10 figures per image pile up
            plt.close()

    def pickle(self, output_dir=None):
        '''pickle the CellSegmentObj for later loading.

        args:
            output_dir: directory to write into; a relative path is taken
                relative to the current working directory. defaults to a
                directory named after the image inside f_directory.
        '''
        if output_dir is not None:
            output_dir = os.path.abspath(output_dir)
        output_dir = self._output_dir(output_dir)
        out_path = os.path.join(output_dir,
                                'pickled_' + self._stem() + '.pickle')
        with open(out_path, 'wb') as f:
            pickle.dump(self, f, pickle.HIGHEST_PROTOCOL)

    def output_all(self):
        """Write plots, images and the pickled object to the default dir."""
        self._output_dir(announce=False)
        print('outputting all data...')
        self.output_plots()
        self.output_all_images()
        self.pickle()
        # TODO: UPDATE THIS METHOD TO INCLUDE PANDAS OUTPUT

    ## HELPER METHODS ##

    def to_pandas(self):
        '''create a pandas DataFrame of tabulated numeric data.

        the pdout attribute indicates which variables to include in the
        DataFrame. when 'volumes' is exported, a 'volumes_flag' column
        holding the current volume unit is added for every object.
        '''
        df_dict = {}
        for attr in self.pdout:
            df_dict[str(attr)] = pd.Series(getattr(self, attr))
        if 'volumes' in self.pdout:
            vflag_out = dict(zip(self.obj_nums,
                                 [self.volumes_flag] * len(self.obj_nums)))
            df_dict['volumes_flag'] = pd.Series(vflag_out)
        return pd.DataFrame(df_dict)

    def convert_volumes(self, z=0.2, x=0.0675):
        '''convert volumes from units of pixels to metric units.

        args:
            z: the distance between slices in the z-stack in units of microns.
                defaults to commonly used setting for Murray 100x objective.
            x: the linear distance between adjacent pixels in each slice in
                units of microns. x is also used for y. Defaults to appropriate
                value for images acquired on the Murray 100X TIRF objective.
        output: converts self.volumes to units of femtoliters, and changes the
            self.volumes_flag to 'femtoliters'.
        '''
        conv_factor = z * x * x
        # iterate over a snapshot of the keys; values are replaced in place
        for key in list(self.volumes):
            self.volumes[key] = float(self.volumes[key]) * conv_factor
        self.volumes_flag = 'femtoliters'

    def _plot_slices(self, nimgs, layers):
        """Draw nimgs slices, one subplot per slice, and return the figure.

        args:
            nimgs: number of slices to draw.
            layers: sequence of (stack, colormap) pairs; for every slice
                the stacks are drawn in order, later ones on top.
        The grid is a single row for fewer than 5 slices and 4 columns wide
        otherwise; unused cells of the last row are blanked.
        """
        if nimgs < 1:
            raise ValueError('cannot plot an image stack with no slices')
        ncols = min(nimgs, 4)
        nrows = (nimgs + 3) // 4
        # squeeze=False always yields a 2D axes array, even for one slice
        fig, axarr = plt.subplots(nrows, ncols, squeeze=False)
        axes = axarr.ravel()
        for i in range(nimgs):
            for stack, cmap in layers:
                axes[i].imshow(stack[i, :, :], cmap=cmap)
            axes[i].xaxis.set_visible(False)
            axes[i].yaxis.set_visible(False)
        for ax in axes[nimgs:]:
            ax.axis('off')
        fig.set_figwidth(16)
        fig.set_figheight(4 * nrows)
        return fig

    def plot_stack(self, stack_arr, colormap='jet'):
        ''' Create a matplotlib plot with each subplot containing a slice.

        Keyword arguments:
            stack_arr: a numpy ndarray containing pixel intensity values.
            colormap: the colormap to be used when displaying pixel
                intensities. defaults to jet.
        Output: the matplotlib figure, in which each slice from the image
            array is represented in a subplot. subplots are 4 columns
            across (when 4 or more slices are present) with rows to
            accommodate all slices. the figure is not shown; use
            plt.show() or plt.savefig() on it.
        '''
        # z axis of array dictates number of slices
        return self._plot_slices(stack_arr.shape[0],
                                 [(stack_arr, colormap)])

    def plot_maxima_stack(self, masked_max, smooth_dist):
        ''' Creates a matplotlib figure in which each slice from the image
        is displayed as a single subplot, in a 4-by-n matrix (n depends upon
        the number of slices in the image). smooth_dist is drawn in gray
        with masked_max overlaid in the autumn colormap. Returns the figure
        without showing it.'''
        return self._plot_slices(masked_max.shape[0],
                                 [(smooth_dist, 'gray'),
                                  (masked_max, 'autumn')])

    ## RESTRUCTURING METHODS ##

    def slim(self):
        '''remove all of the processing intermediates from the object, leaving
        only the core information required for later analysis. primarily
        intended for use when doing batch analysis of multiple images, and
        combining CellSegmentObj instances with instances of other types of
        objects segmented in a different fluorescence channel.

        returns: self, with only final_cells left of the image stacks.
        '''
        del self.raw_img
        del self.gaussian_img
        del self.threshold_img
        del self.filled_img
        del self.dist_map
        del self.smooth_dist_map
        del self.maxima
        del self.labs
        del self.watershed_output
        del self.filled_cells
        return self
class CellSegmenter:
    """Segment cells in a cytosolic fluorescence z-stack.

    Attributes:
        filename (str): path of the image to read in segment().
        threshold: intensity cutoff applied to the gaussian-filtered image.
    """

    def __init__(self, filename, threshold):
        self.filename = filename
        self.threshold = threshold

    def segment(self):
        """Run the full segmentation pipeline on self.filename.

        Steps: gaussian smoothing, thresholding, 3D hole filling, distance
        transform, maxima detection, watershed, per-cell 2D hole filling
        and voting-filter cleanup.

        returns: a CellSegmentObj holding every intermediate image, the
            object labels and the per-object volumes in pixels. its
            f_directory is the working directory at the time of the call.
        """
        # start timing
        starttime = time.time()
        # DATA IMPORT AND PREPROCESSING
        f_directory = os.getcwd()
        print('reading ' + self.filename + ' ...')
        raw_img = io.imread(self.filename)
        print('raw image imported.')
        # next step's gaussian filter assumes 100x obj and 0.2 um slices
        print('performing gaussian filtering...')
        gaussian_img = gaussian_filter(input=raw_img, sigma=(1, 2, 2))
        print('cytosolic image smoothed.')
        print('preprocessing complete.')
        # BINARY THRESHOLDING AND IMAGE CLEANUP
        print('thresholding...')
        threshold_img = np.copy(gaussian_img)
        threshold_img[threshold_img < self.threshold] = 0
        threshold_img[threshold_img > 0] = 1
        print('thresholding complete.')
        print('filling holes...')
        filled_img = binary_fill_holes(threshold_img)
        print('3d holes filled.')
        print('binary processing complete.')
        # DISTANCE AND MAXIMA TRANFORMATIONS TO FIND CELLS
        # next two steps assume 100x obj and 0.2 um slices
        print('generating distance map...')
        dist_map = distance_transform_edt(filled_img, sampling=(2, 1, 1))
        print('distance map complete.')
        print('smoothing distance map...')
        smooth_dist = gaussian_filter(dist_map, [2, 4, 4])
        print('distance map smoothed.')
        print('identifying maxima...')
        max_strel_3d = generate_binary_structure(3, 2)
        maxima = maximum_filter(smooth_dist,
                                footprint=max_strel_3d) == smooth_dist
        # clean up background/edges: flat background is trivially equal to
        # its own maximum filter, so remove the eroded background
        bgrd_3d = smooth_dist == 0
        eroded_background_3d = binary_erosion(bgrd_3d, structure=max_strel_3d,
                                              border_value=1)
        maxima = np.logical_xor(maxima, eroded_background_3d)
        print('maxima identified.')
        # WATERSHED SEGMENTATION
        labs = self.watershed_labels(maxima)
        print('watershedding...')
        cells = watershed(-smooth_dist, labs, mask=filled_img)
        print('raw watershedding complete.')
        print('filling 2d holes in cells...')
        filled_cells = self.fill_cells_2d(cells)
        print('2d hole-filling complete.')
        print('cleaning up cells...')
        clean_cells = self.reassign_pixels_3d(filled_cells)
        print('cell cleanup complete.')
        print('SEGMENTATION OPERATION COMPLETE.')
        endtime = time.time()
        runningtime = endtime - starttime
        print('time elapsed: ' + str(runningtime) + ' seconds')
        cell_nums, volumes = np.unique(clean_cells, return_counts=True)
        # labels come out of a float array; store them as integers
        cell_nums = cell_nums.astype(int)
        # drop the background label (if any) from both labels and volumes
        is_cell = cell_nums != 0
        cell_nums = cell_nums[is_cell]
        volumes = dict(zip(cell_nums.tolist(), volumes[is_cell].tolist()))
        return CellSegmentObj(f_directory, self.filename, raw_img,
                              gaussian_img, self.threshold,
                              threshold_img, filled_img, dist_map,
                              smooth_dist, maxima, labs, cells, filled_cells,
                              clean_cells, cell_nums, volumes)

    def watershed_labels(self, maxima_img):
        '''Takes a 3D array with maxima marked as non-zero pixels
        and returns an int array of the same shape with maxima numbered
        sequentially from 1 (in row-major order) and 0 elsewhere.'''
        positions = np.nonzero(maxima_img)
        label_output = np.zeros(maxima_img.shape, dtype=int)
        label_output[positions] = np.arange(1, len(positions[0]) + 1)
        return label_output

    def fill_cells_2d(self, cell_img):
        '''Go cell-by-cell in a watershed-segmented image.

        First, make a new 3D array that just includes that cell; then,
        fill holes within that cell in each 2D plane. Then add this cell
        back into a new 3D array with the same numbering scheme that the
        original watershedded image had. Every non-zero label present in
        cell_img is processed, whether or not the labels are contiguous.

        returns: a float array of the same shape as cell_img.
        '''
        # initialize output image for adding cell data in as it's generated.
        filled_cells = np.zeros(shape=cell_img.shape)
        labels = np.unique(cell_img)
        labels = labels[labels != 0]
        total = len(labels)
        nplanes = cell_img.shape[0]
        for done, label in enumerate(labels, start=1):
            c_cell = cell_img == label
            for z in range(nplanes):
                c_cell[z, :, :] = binary_fill_holes(c_cell[z, :, :])
            filled_cells[c_cell] = label
            print('cell ' + str(done) + ' of '
                  + str(total) + ' done.')
        return filled_cells

    def reassign_pixels_3d(self, filled_cells):
        '''Uses reassign_pixels_2d on each slice of a watershed-
        segmented image to clean up cells.

        returns: a float array of the same shape as filled_cells.
        '''
        reassigned_cells = np.zeros(shape=filled_cells.shape)
        for i in range(filled_cells.shape[0]):
            print(' beginning slice ' + str(i) + '...')
            reassigned_cells[i, :, :] = self.reassign_pixels_2d(
                filled_cells[i, :, :])
            print(' slice ' + str(i) + ' complete.')
        return reassigned_cells

    def reassign_pixels_2d(self, watershed_img_2d):
        '''In 2D, go over each segmented pixel in the image with a
        structuring element. Measure the frequency of each different
        pixel value that occurs within this structuring element (i.e. 1
        for one segmented cell, 2 for another, etc. etc.). Generate a
        new image with each pixel value set to that of the most common
        non-zero (non-background) pixel within its proximity. After
        doing this for every pixel, check and see if the new image is
        identical to the old one. If it is, stop; if not, repeat with
        the new image as the starting image.

        A pixel keeps its current value when the vote is tied or when no
        labelled pixel falls under the structuring element.

        returns: a float array of the same shape as watershed_img_2d.
        '''
        # diamond-shaped neighbourhood of radius 3
        strel = np.array(
            [[0, 0, 0, 1, 0, 0, 0],
             [0, 0, 1, 1, 1, 0, 0],
             [0, 1, 1, 1, 1, 1, 0],
             [1, 1, 1, 1, 1, 1, 1],
             [0, 1, 1, 1, 1, 1, 0],
             [0, 0, 1, 1, 1, 0, 0],
             [0, 0, 0, 1, 0, 0, 0]])
        old_img = watershed_img_2d
        height, width = old_img.shape[0], old_img.shape[1]
        # only pixels that are labelled in the input are ever revisited
        mask_y, mask_x = np.nonzero(old_img)
        counter = 0
        while True:
            new_img = np.zeros(old_img.shape)
            for py, px in zip(mask_y, mask_x):
                y, x = py, px
                # shift submatrix if the pixel is too close to the edge to be
                # centered within it
                if y - 3 < 0:
                    y = 3
                if y + 4 >= height:
                    y = height - 4
                if x - 3 < 0:
                    x = 3
                if x + 4 >= width:
                    x = width - 4
                # take a strel-sized matrix around the pixel of interest
                # and mask for the pixels to compare against
                svals = np.multiply(old_img[y - 3:y + 4, x - 3:x + 4], strel)
                # labels of the non-background pixels in the neighbourhood
                cellvals = svals[svals != 0].astype(int)
                if cellvals.size == 0:
                    # after an edge shift the pixel itself can fall outside
                    # the diamond; with no votes at all, leave it unchanged
                    new_img[py, px] = old_img[py, px]
                    continue
                # count number of pixels that correspond to each
                # cell in proximity
                freq = np.bincount(cellvals)
                # find cell with most abundant pixels in vicinity
                top_cell = np.argmax(freq)
                # because argmax only returns the first index if there are
                # multiple matches, check for duplicates. if there are
                # duplicates, leave the value as the original, regardless
                # of whether the set of most abundant pixel values contains
                # the original pixel (likely on an edge) or not (likely on
                # a point where three cells come together - far less
                # common)
                if np.sum(freq == freq[top_cell]) > 1:
                    new_img[py, px] = old_img[py, px]
                else:
                    new_img[py, px] = top_cell
            counter += 1
            if np.array_equal(new_img, old_img):
                break
            old_img = np.copy(new_img)
        print(' number of iterations = ' + str(counter))
        return new_img
if __name__ == '__main__':
    # Run segmentation on all images in the working directory.
    # sys.argv contents:
    #     threshold: a number that will be used as the threshold to binarize
    #         the cytosolic fluorescence image.
    if len(sys.argv) < 2:
        sys.exit('usage: python CellSegment.py <threshold>')
    threshold = int(sys.argv[1])
    start_dir = os.getcwd()
    # match tif/tiff extensions regardless of case
    imlist = [f for f in os.listdir(start_dir)
              if f.lower().endswith(('.tif', '.tiff'))]
    print('final imlist:')
    print(imlist)
    for i in imlist:
        # image names are relative to start_dir; make sure nothing run for
        # the previous image has left us in a different directory
        os.chdir(start_dir)
        print('initializing segmenter object...')
        i_parser = CellSegmenter(i, threshold)
        print('initializing segmentation...')
        i_obj = i_parser.segment()
        i_obj.output_all()
# TODO LIST:
# ADD LOG ATTRIBUTE APPEND COMMANDS IN METHODS WHERE IT'S NOT IMPLEMENTED
# STOP THE PLOTS FROM BEING SHOWN WHEN THEY SHOULDN'T