https://github.com/lauziere/MHHT
Tip revision: f7e35a2e3ef398191b9e49a57f80d514d8f880c1 authored by lauziere on 01 June 2022, 00:33:05 UTC
Update README.md
Update README.md
Tip revision: f7e35a2
util.py
import numpy as np
import os
import time
import pandas as pd
from murty import *
from MSC import *
from BTP import *
from plotting import *
import pdb
def Embryo_graph(n):
    """Build the adjacency matrix and edge list of the n-node embryo graph.

    Node ``i`` is linked to nodes ``i-1`` and ``i-2``; odd-numbered nodes are
    additionally linked to node ``i-3``.  Links that would point to a negative
    index are skipped, so every entry is stored as ``adj[i, j] = 1`` with
    ``j < i``.

    Parameters
    ----------
    n : int
        Number of nodes.

    Returns
    -------
    adj : ndarray of int, shape (n, n)
        Adjacency matrix with ones on the links described above.
    edge_list : ndarray, shape (m, 2)
        ``(row, column)`` indices of the non-zero entries of ``adj`` in
        row-major order.
    """
    adj = np.zeros((n, n), 'int')
    nodes = np.arange(n)

    # Every node links to its two predecessors (when they exist).
    for offset in (1, 2):
        sources = nodes[offset:]
        adj[sources, sources - offset] = 1

    # Odd nodes also link three steps back; node 3 is the first one that can.
    odd_sources = nodes[3::2]
    adj[odd_sources, odd_sources - 3] = 1

    # NOTE: no explicit adj[-1, -2] assignment is needed: that entry is the
    # "i-1" link of the last node and is already set above for n >= 2, while
    # writing it unconditionally raised IndexError for n < 2.

    edge_list = np.transpose(np.nonzero(adj))

    return adj, edge_list
def Embryo(in_arr, out_arr, edge_list):
    """Return the summed absolute change in edge length between two point sets.

    For every ``(start, end)`` row of ``edge_list`` the Euclidean length of the
    edge is measured in ``in_arr`` and in ``out_arr``; the cost is the sum of
    the absolute differences of those lengths.
    """
    tails, heads = edge_list[:, 0], edge_list[:, 1]

    # Edge lengths before (in_arr) and after (out_arr).
    lengths_before = np.linalg.norm(in_arr[heads] - in_arr[tails], axis=1)
    lengths_after = np.linalg.norm(out_arr[heads] - out_arr[tails], axis=1)

    return np.abs(lengths_after - lengths_before).sum()
def Posture(in_arr, out_arr, edge_list, inv_cov):
    """Return ``sqrt(d^T inv_cov d)`` where ``d`` holds the edge-length changes.

    ``d[k]`` is the length of edge ``k`` (row ``k`` of ``edge_list``) measured
    in ``out_arr`` minus its length measured in ``in_arr``.
    """
    tails, heads = edge_list[:, 0], edge_list[:, 1]

    lengths_before = np.linalg.norm(in_arr[heads] - in_arr[tails], axis=1)
    lengths_after = np.linalg.norm(out_arr[heads] - out_arr[tails], axis=1)
    length_change = lengths_after - lengths_before

    # Quadratic form in the length changes, then its square root.
    quad_form = np.linalg.multi_dot([length_change, inv_cov, length_change])
    return quad_form ** .5
def Movement(in_arr, out_arr, inv_cov):
    """Return ``sqrt(v^T inv_cov v)`` for the flattened displacement ``v``.

    ``v`` is ``out_arr - in_arr`` flattened in row-major order; the two arrays
    are expected to list the same points in the same order.
    """
    displacement = (out_arr - in_arr).flatten()
    quad_form = np.linalg.multi_dot([displacement, inv_cov, displacement])
    return quad_form ** .5
def PostureMovement(in_arr, out_arr, edge_list, inv_cov1, inv_cov2):
    """Return the sum of the ``Posture`` cost (using ``inv_cov1``) and the
    ``Movement`` cost (using ``inv_cov2``) for the same pair of point sets."""
    return (Posture(in_arr, out_arr, edge_list, inv_cov1)
            + Movement(in_arr, out_arr, inv_cov2))
def Intersection_test(data, skinnyFactor=0.55):
    """Penalise configurations in which non-adjacent segments cross each other.

    The first 18 rows of ``data`` are read as 9 interleaved pairs (even rows
    are named "L", odd rows "R" below).  Consecutive pairs bound 8 segments.
    For every ordered pair of segments whose indices differ by at least 2,
    four edge rays of one segment are tested against two triangles of the
    other; the u/v/t computation follows the Moller-Trumbore ray-triangle
    formulation, and a hit is counted when ``0 <= t <= 1``.

    Parameters
    ----------
    data : ndarray, at least 18 rows by 3 columns
        Point coordinates; only ``data[:18]`` is used (a copy is taken).
    skinnyFactor : float, optional
        Each triangle vertex is its point moved toward the partner point of
        the same pair by ``skinnyFactor / 2`` of their separation.

    Returns
    -------
    float
        ``1e6`` times the number of detected ray/triangle intersections.
    """
    # The index layout below (8 rows per segment pair: 2 triangles x 4 rays)
    # is only self-consistent for 18 cells / 8 segments, so this stays fixed.
    numCells = 18
    numSegs = numCells // 2 - 1
    rowsPerSeg = 8 * numSegs

    origALL = np.empty((rowsPerSeg * numSegs, 3))
    vecDirALL = np.empty((rowsPerSeg * numSegs, 3))

    # Segment that owns the ray / the triangle for each flat row; only pairs
    # of segments at least two apart are compared.
    raySeg = np.repeat(np.arange(numSegs), rowsPerSeg)
    triSeg = np.tile(np.repeat(np.arange(numSegs), numSegs), numSegs)
    relComps = np.abs(triSeg - raySeg) >= 2

    coords = data[:numCells].copy()
    coordsL = coords[::2]
    coordsR = coords[1::2]

    # Triangle vertices: each pair pulled toward its own midpoint.
    vertCoordsL = coordsL - (coordsL - coordsR)*skinnyFactor / 2
    vertCoordsR = coordsR - (coordsR - coordsL)*skinnyFactor / 2

    vertCoordsLR = np.empty((numCells, 3))
    vertCoordsLR[::2] = vertCoordsL
    vertCoordsLR[1::2] = vertCoordsR

    # np.repeat expects integer repeat counts, hence the integer divisions.
    vert0Seq = np.repeat(vertCoordsL[:numSegs], numSegs, axis=0)
    vert1Seq = np.repeat(vertCoordsR[:numSegs], numSegs, axis=0)
    vert2Seq = np.repeat(vertCoordsLR[2:], numSegs // 2, axis=0)

    vecDirLR = coordsR - coordsL
    vecDirAPL = np.diff(coordsL, axis=0)
    vecDirAPR = np.diff(coordsR, axis=0)

    # Ray origins, cycling with period 4: L[s], R[s], L[s], L[s+1].
    origALL[::2, :] = np.repeat(coordsL[:numSegs, :], rowsPerSeg // 2, axis=0)
    origALL[1::4, :] = np.repeat(coordsR[:numSegs, :], rowsPerSeg // 4, axis=0)
    origALL[3::4, :] = np.repeat(coordsL[1:numSegs+1, :], rowsPerSeg // 4, axis=0)

    # Matching ray directions: along L, along R, across pair s, across pair s+1.
    vecDirALL[::4, :] = np.repeat(vecDirAPL[:numSegs, :], rowsPerSeg // 4, axis=0)
    vecDirALL[1::4, :] = np.repeat(vecDirAPR[:numSegs, :], rowsPerSeg // 4, axis=0)
    vecDirALL[2::4, :] = np.repeat(vecDirLR[:numSegs, :], rowsPerSeg // 4, axis=0)
    vecDirALL[3::4, :] = np.repeat(vecDirLR[1:numSegs+1, :], rowsPerSeg // 4, axis=0)

    vert0ALL = np.tile(vert0Seq, (numSegs, 1))
    vert1ALL = np.tile(vert1Seq, (numSegs, 1))
    vert2ALL = np.tile(vert2Seq, (numSegs, 1))

    orig = origALL[relComps, :]
    vecDir = vecDirALL[relComps, :]
    vert0 = vert0ALL[relComps, :]
    vert1 = vert1ALL[relComps, :]
    vert2 = vert2ALL[relComps, :]

    edge1 = vert1 - vert0
    edge2 = vert2 - vert0
    tve = orig - vert0
    pve = np.cross(vecDir, edge2, axis=1)
    det1 = np.sum(edge1*pve, axis=1)

    # Rays parallel to the triangle plane (zero determinant) can never hit;
    # an infinite divisor keeps the division below well defined for them.
    angOK = np.abs(det1) > 0
    det1[np.invert(angOK)] = np.inf

    u = np.sum(tve*pve, axis=1)/det1

    v = np.inf + np.zeros(u.shape)
    t = v.copy()

    ok = np.logical_and.reduce((angOK, u >= 0, u <= 1.0))

    if not np.any(ok):
        # No candidate survived the first barycentric test.
        intersections = ok
    else:
        qve = np.cross(tve[ok, :], edge1[ok, :], axis=1)
        v[ok] = np.sum(vecDir[ok, :]*qve, axis=1)/det1[ok]
        t[ok] = np.sum(edge2[ok, :]*qve, axis=1)/det1[ok]

        ok = np.logical_and.reduce((ok, v >= 0, u+v <= 1.0))
        # Only hits within the finite edge (0 <= t <= 1) count.
        intersections = np.logical_and.reduce((ok, t >= 0, t <= 1.0))

    int_cost = 1e6*intersections.sum()

    return int_cost
class MHHT:
    """Base tracker: configuration, data arrays, cost evaluation and frame loop.

    Subclasses supply ``search(j)``, which ``update`` calls and which is
    expected to fill ``final_arrays`` and, when it finds a path cheaper than
    ``best_cost``, to set ``best_cost`` and ``final_order``.
    """

    def __init__(self, config):
        """Read settings from the ``config`` mapping, create the results
        folder and load the annotation, prediction and weight arrays."""
        # Read in config
        self.Dataset = config['Dataset']
        self.Q = config['Q']
        self.Detection = config['Detection']
        self.Interpolation = config['Interpolation']
        self.Interpolation_cost = config['Interpolation_cost']
        self.cost_threshold = config['Cost_threshold']
        self.print_interval = config['Print_interval']
        self.d = config['d']
        self.Model = config['Model']
        self.K = config['K']
        self.N = config['N']
        self.Solver = config['Solver']
        self.StartFrame = config['StartFrame']
        self.EndFrame = config['EndFrame']
        self.InitialFrame = config['InitialFrame']
        self.CurrentFrame = self.StartFrame

        self.out_path = os.path.join(config['Home'], 'Results', self.Dataset)
        self.track_path = os.path.join(self.out_path, 'output')

        # makedirs also creates a missing 'Results' parent, which a plain
        # os.mkdir would fail on.
        os.makedirs(self.out_path, exist_ok=True)

        self.output_name = 'Tracks_' + str(self.StartFrame) + '_' + str(self.EndFrame) + '_' + self.Model + '_' + \
                           str(self.K) + '_' + str(self.N) + '_' + self.Detection + '_' + self.Interpolation + '_' + \
                           str(self.d) + '.npy'

        self.n = 21 if self.Q else 19

        # Build graph:
        self.adj, self.edge_list = Embryo_graph(self.n)

        # Initialize arrays
        # NOTE(review): allow_pickle=True executes pickled objects on load;
        # only point these paths at trusted files.
        self.Annotations = np.load(config['Annotation_path'], allow_pickle=True)
        self.Predictions = np.load(config['Prediction_path'], allow_pickle=True)

        # One slot per frame in [StartFrame, EndFrame]; slot 0 is seeded from
        # the annotation of StartFrame.
        self.Tracks = np.zeros((self.EndFrame - self.StartFrame + 1, self.n, 3))
        self.Tracks[0] = self.Annotations[self.StartFrame - self.InitialFrame]
        self.Costs = np.zeros(self.EndFrame - self.StartFrame + 1)

        self.Posture_weights = np.load(config['Posture_weights_path'])
        self.Movement_weights = np.load(config['Movement_weights_path'])

    def get_unary_cost(self, out_array_interp, C, out_perm):
        """Sum the assignment costs ``C[i, out_perm[i]]`` plus the
        intersection penalty of ``out_array_interp``.

        Entries equal to ``self.d`` are replaced by ``self.Interpolation_cost``
        before summing; ``C`` itself is not modified (fancy indexing copies).
        """
        costs = C[np.arange(self.n), out_perm]
        costs[costs == self.d] = self.Interpolation_cost

        intersection_cost = Intersection_test(out_array_interp)

        unary_cost = costs.sum() + intersection_cost

        return unary_cost

    def get_total_assignment_cost(self, j, C, out_perm):
        """Return the cost of moving from ``inter_arrays[j]`` to
        ``inter_arrays[j+1]`` under the configured ``Model``.

        Raises
        ------
        ValueError
            If ``self.Model`` is not one of the supported model names.
        """
        edge_list = self.edge_list
        in_arr = self.inter_arrays[j]
        out_arr = self.inter_arrays[j+1]

        unary_cost = self.get_unary_cost(out_arr, C, out_perm)

        if self.Model == 'GNN' or self.Model == 'MHT':
            # These models use the unary term only.
            graphical_cost = 0
            total_assignment_cost = unary_cost
        elif self.Model == 'Embryo':
            graphical_cost = Embryo(in_arr, out_arr, edge_list)
            total_assignment_cost = unary_cost + graphical_cost
        elif self.Model == 'Movement':
            graphical_cost = Movement(in_arr, out_arr, self.Movement_weights)
            total_assignment_cost = unary_cost + graphical_cost
        elif self.Model == 'Posture':
            graphical_cost = Posture(in_arr, out_arr, edge_list, self.Posture_weights)
            total_assignment_cost = unary_cost + graphical_cost
        elif self.Model == 'Posture-Movement':
            graphical_cost = PostureMovement(in_arr, out_arr, edge_list, self.Posture_weights, self.Movement_weights)
            total_assignment_cost = unary_cost + graphical_cost
        else:
            # Previously this fell through to an UnboundLocalError.
            raise ValueError('Unknown Model: {!r}'.format(self.Model))

        return total_assignment_cost

    def update(self):
        """Reset the per-frame search state, run ``search(0)`` and return the
        chosen next-frame array together with the best path cost."""
        j = 0

        # These will update
        self.n_frames_remaining = self.EndFrame - self.CurrentFrame
        self.N_scan = min([self.N, self.n_frames_remaining])

        self.path = np.zeros((self.N_scan+1, self.n, 3))
        self.final_arrays = np.zeros((self.K, self.n, 3))
        self.inter_arrays = np.zeros((self.N_scan+1, self.n, 3))
        self.inter_arrays[0] = self.Tracks[self.CurrentFrame - self.StartFrame]

        self.final_path = []
        self.cost = []
        self.order = []
        # Integer dtype: final_order[0] indexes final_arrays below, and a
        # float index raises IndexError when search never overwrites it.
        self.final_order = np.zeros(self.n, dtype=int)
        self.best_cost = 1e8

        self.search(j)

        return self.final_arrays[self.final_order[0]], self.best_cost

    def track(self):
        """Track every frame from StartFrame to EndFrame, printing progress,
        then save ``Tracks`` as ``output_name`` inside ``out_path``."""
        for t in range(self.StartFrame, self.EndFrame):

            st = time.time()
            tracks, cost = self.update()
            self.CurrentFrame += 1

            self.Tracks[t+1 - self.StartFrame] = tracks
            self.Costs[t+1 - self.StartFrame] = cost
            rt = time.time() - st

            print('Frame', self.CurrentFrame)
            print('Cost', cost)
            print('Runtime:', np.round(rt,2), 'seconds \n')

        np.save(os.path.join(self.out_path, self.output_name), self.Tracks)

    def track_notebook(self):
        """Interactive variant of ``track``: shows progress, plots and asks
        for a correction on costly or periodic frames, and writes one CSV of
        tracks per step into ``track_path``."""
        print_interval = self.print_interval
        print_frames = np.arange(self.StartFrame + print_interval, self.EndFrame + print_interval, print_interval)

        # The CSV folder is not created anywhere else; to_csv needs it.
        os.makedirs(self.track_path, exist_ok=True)

        for t in range(self.StartFrame, self.EndFrame):

            update_progress(self.StartFrame, t+1, self.EndFrame, label="Tracking {} of {}".format(t+1, self.EndFrame))

            st = time.time()
            tracks, cost = self.update()
            rt = time.time() - st

            print('Frame', self.CurrentFrame + 1)
            print('Cost', cost)
            print('Runtime:', np.round(rt,2), 'seconds \n')

            # Review frames whose cost reaches the threshold, plus every
            # print_interval-th frame.
            if cost >= self.cost_threshold or t+1 in print_frames:
                plot_3d_overlay(self.Tracks[t - self.StartFrame], tracks, errors={})
                tracks = correction(t+1, self.out_path, self.Tracks[t - self.StartFrame], tracks)

            self.CurrentFrame += 1
            self.Tracks[t+1 - self.StartFrame] = tracks
            self.Costs[t+1 - self.StartFrame] = cost

            # Write out tracks.csv
            this_track_path = os.path.join(self.track_path, 'tracks_' + str(t) + '.csv')
            tracks_df = pd.DataFrame(tracks, columns = ['x','y','z'])
            tracks_df.to_csv(this_track_path)
class MHHT_Murty(MHHT):
    """Tracker variant whose per-frame hypotheses come from ``Murty`` applied
    to the matrix built by ``Murty_mat``."""

    def __init__(self, config):
        super().__init__(config)

    def interpolate(self, in_array_aug, out_array, chosen_cols):
        """Return the (n, 3) next-frame array implied by ``chosen_cols``.

        A point whose assigned column is below ``out_array.shape[0]`` takes
        that detection; any other point is treated as missed and filled in
        according to ``self.Interpolation``: 'Last' reuses its previous
        position, 'Graph' averages the positions predicted by carrying over
        its previous offset from each detected graph neighbour (falling back
        to the previous position when no neighbour was detected).
        """
        graph = self.adj
        num_detections = out_array.shape[0]

        detected = np.argwhere(chosen_cols < num_detections)[:, 0]
        undetected = np.argwhere(chosen_cols >= num_detections)[:, 0]

        if self.Interpolation == 'Last':
            # Use last point.
            fill_values = {idx: in_array_aug[idx] for idx in undetected}

        elif self.Interpolation == 'Graph':
            fill_values = {}
            for idx in undetected:
                previous = in_array_aug[idx]

                # Neighbours in either direction of the adjacency matrix.
                linked = np.concatenate([np.nonzero(graph[idx, :])[0],
                                         np.nonzero(graph[:, idx])[0]])
                anchors = np.intersect1d(detected, linked)

                if anchors.shape[0] > 0:
                    # One prediction per detected neighbour: its new position
                    # minus its previous offset from the missed point.
                    estimates = np.empty((anchors.shape[0], 3))
                    estimates[:] = out_array[chosen_cols[anchors]] - (in_array_aug[anchors] - previous)
                    fill_values[idx] = estimates.mean(axis=0)
                else:
                    fill_values[idx] = in_array_aug[idx]

        # Update out_array
        filled = np.empty((self.n, 3))
        for idx in range(self.n):
            col = chosen_cols[idx]
            if col >= num_detections:
                # Missed: use the interpolated position.
                filled[idx] = fill_values[idx]
            else:
                filled[idx] = out_array[col]

        return filled

    def search(self, j):
        """Depth-first search over hypotheses for scan depth ``j``.

        Each hypothesis is interpolated into ``inter_arrays[j+1]``, costed,
        and pushed on ``self.cost`` / ``self.order``.  While the running total
        stays below ``best_cost`` the search recurses until depth ``N_scan``,
        where the total and hypothesis order are recorded as the new best.
        """
        previous_points = self.inter_arrays[j]
        detections = self.Predictions[self.CurrentFrame - self.InitialFrame + j + 1]

        C = Murty_mat(previous_points, detections, self.d, self.d)
        _, _, cols_K = Murty(C, self.K)

        hypotheses = get_unique_hyps(cols_K, self.n)
        depth = j + 1

        for k in range(hypotheses.shape[0]):
            assignment = hypotheses[k]

            candidate = self.interpolate(previous_points, detections, assignment)
            self.inter_arrays[depth] = candidate

            if j == 0:
                # First-level candidates are what update() finally returns.
                self.final_arrays[k] = candidate

            self.cost.append(self.get_total_assignment_cost(j, C, assignment))
            self.order.append(k)

            running_total = sum(self.cost)
            if running_total < self.best_cost:
                if depth < self.N_scan:
                    self.search(depth)
                elif depth == self.N_scan:
                    self.best_cost = running_total
                    self.final_order = self.order.copy()

            # Backtrack before trying the next hypothesis.
            self.cost.pop()
            self.order.pop()
class MHHT_MSC(MHHT):
    """Tracker variant whose per-frame hypotheses come from ``Murty_MSC``
    applied to the matrix built by ``Murty_mat_MSC``."""

    def __init__(self, config):
        super().__init__(config)

    def interpolate(self, in_array_aug, out_array, chosen_cols):
        """Return the (n, 3) next-frame array implied by ``chosen_cols``.

        A point whose assigned column is below ``out_array.shape[0]`` takes
        that detection; any other point is treated as missed and filled in
        according to ``self.Interpolation``: 'Last' reuses its previous
        position, 'Graph' averages the positions predicted by carrying over
        its previous offset from each detected graph neighbour (falling back
        to the previous position when no neighbour was detected).
        """
        graph = self.adj
        num_detections = out_array.shape[0]

        detected = np.argwhere(chosen_cols < num_detections)[:, 0]
        undetected = np.argwhere(chosen_cols >= num_detections)[:, 0]

        if self.Interpolation == 'Last':
            # Use last point.
            fill_values = {idx: in_array_aug[idx] for idx in undetected}

        elif self.Interpolation == 'Graph':
            fill_values = {}
            for idx in undetected:
                previous = in_array_aug[idx]

                # Neighbours in either direction of the adjacency matrix.
                linked = np.concatenate([np.nonzero(graph[idx, :])[0],
                                         np.nonzero(graph[:, idx])[0]])
                anchors = np.intersect1d(detected, linked)

                if anchors.shape[0] > 0:
                    # One prediction per detected neighbour: its new position
                    # minus its previous offset from the missed point.
                    estimates = np.empty((anchors.shape[0], 3))
                    estimates[:] = out_array[chosen_cols[anchors]] - (in_array_aug[anchors] - previous)
                    fill_values[idx] = estimates.mean(axis=0)
                else:
                    fill_values[idx] = in_array_aug[idx]

        # Update out_array
        filled = np.empty((self.n, 3))
        for idx in range(self.n):
            col = chosen_cols[idx]
            if col >= num_detections:
                # Missed: use the interpolated position.
                filled[idx] = fill_values[idx]
            else:
                filled[idx] = out_array[col]

        return filled

    def search(self, j):
        """Depth-first search over the ``K`` hypotheses for scan depth ``j``.

        Each hypothesis is interpolated into ``inter_arrays[j+1]``, costed,
        and pushed on ``self.cost`` / ``self.order``.  While the running total
        stays below ``best_cost`` the search recurses until depth ``N_scan``,
        where the total and hypothesis order are recorded as the new best.
        """
        previous_points = self.inter_arrays[j]
        detections = self.Predictions[self.CurrentFrame - self.InitialFrame + j + 1]

        C = Murty_mat_MSC(previous_points, detections, self.d)
        _, _, cols_K = Murty_MSC(C, self.K)

        depth = j + 1

        # All K returned solutions are visited (no de-duplication here).
        for k in range(self.K):
            assignment = cols_K[k]

            candidate = self.interpolate(previous_points, detections, assignment)
            self.inter_arrays[depth] = candidate

            if j == 0:
                # First-level candidates are what update() finally returns.
                self.final_arrays[k] = candidate

            self.cost.append(self.get_total_assignment_cost(j, C, assignment))
            self.order.append(k)

            running_total = sum(self.cost)
            if running_total < self.best_cost:
                if depth < self.N_scan:
                    self.search(depth)
                elif depth == self.N_scan:
                    self.best_cost = running_total
                    self.final_order = self.order.copy()

            # Backtrack before trying the next hypothesis.
            self.cost.pop()
            self.order.pop()
class MHHT_BTP(MHHT):
    """Tracker variant whose per-frame hypotheses come from ``MBest`` applied
    to the problem built by ``build_initial_all``."""

    def __init__(self, config):
        super().__init__(config)

    def interpolate(self, in_array_aug, out_array, chosen_cols):
        """Return the (n, 3) next-frame array implied by ``chosen_cols``.

        Column 0 marks a missed point; any other column ``c`` refers to
        detection ``c - 1`` of ``out_array``.  Missed points are filled in
        according to ``self.Interpolation``: 'Last' reuses the previous
        position, 'Graph' averages the positions predicted by carrying over
        the previous offset from each detected graph neighbour (falling back
        to the previous position when no neighbour was detected).
        """
        graph = self.adj

        detected = np.argwhere(chosen_cols != 0)[:, 0]
        undetected = np.argwhere(chosen_cols == 0)[:, 0]

        # Set the values back 1 due to missing nucleus offset:
        detection_idx = chosen_cols.copy()
        detection_idx[detected] += -1

        if self.Interpolation == 'Last':
            # Use last point.
            fill_values = {idx: in_array_aug[idx] for idx in undetected}

        elif self.Interpolation == 'Graph':
            fill_values = {}
            for idx in undetected:
                previous = in_array_aug[idx]

                # Neighbours in either direction of the adjacency matrix.
                linked = np.concatenate([np.nonzero(graph[idx, :])[0],
                                         np.nonzero(graph[:, idx])[0]])
                anchors = np.intersect1d(detected, linked)

                if anchors.shape[0] > 0:
                    # One prediction per detected neighbour: its new position
                    # minus its previous offset from the missed point.
                    estimates = np.empty((anchors.shape[0], 3))
                    estimates[:] = out_array[detection_idx[anchors]] - (in_array_aug[anchors] - previous)
                    fill_values[idx] = estimates.mean(axis=0)
                else:
                    fill_values[idx] = in_array_aug[idx]

        # Update out_array
        detected_set = set(detected.tolist())
        undetected_set = set(undetected.tolist())

        filled = np.empty((self.n, 3))
        for idx in range(self.n):
            if idx in detected_set:
                filled[idx] = out_array[detection_idx[idx]]
            elif idx in undetected_set:
                filled[idx] = fill_values[idx]

        return filled

    def search(self, j):
        """Depth-first search over the ``K`` hypotheses for scan depth ``j``.

        The solver output is reshaped to (n, detections + 1, K) and the
        arg-max over the middle axis gives each point's chosen column.  Each
        hypothesis is interpolated into ``inter_arrays[j+1]``, costed, and
        pushed on ``self.cost`` / ``self.order``.  While the running total
        stays below ``best_cost`` the search recurses until depth ``N_scan``,
        where the total and hypothesis order are recorded as the new best.
        """
        previous_points = self.inter_arrays[j]
        detections = self.Predictions[self.CurrentFrame - self.InitialFrame + j + 1]

        num_points = self.n
        num_detections = detections.shape[0]

        C, A, Aeq, b, beq = build_initial_all(previous_points, detections, self.d)
        x, xv = MBest(C, A, b, Aeq, beq, self.K)

        selections = np.argmax(np.reshape(x, (num_points, num_detections + 1, self.K)), axis=1)
        depth = j + 1

        for k in range(self.K):
            assignment = selections[:, k]

            candidate = self.interpolate(previous_points, detections, assignment)
            self.inter_arrays[depth] = candidate

            if j == 0:
                # First-level candidates are what update() finally returns.
                self.final_arrays[k] = candidate

            # Cost vector viewed as a (points, detections + 1) matrix.
            cost_matrix = C.copy().reshape((num_points, num_detections + 1))
            self.cost.append(self.get_total_assignment_cost(j, cost_matrix, assignment))
            self.order.append(k)

            running_total = sum(self.cost)
            if running_total < self.best_cost:
                if depth < self.N_scan:
                    self.search(depth)
                elif depth == self.N_scan:
                    self.best_cost = running_total
                    self.final_order = self.order.copy()

            # Backtrack before trying the next hypothesis.
            self.cost.pop()
            self.order.pop()