import torch
import numpy as np
import util
def measure_mean_potential(pc: torch.Tensor):
    """
    Average the dipole potential induced by ``pc`` over the sample grid.

    The measurement positions come from ``util.gen_grid()`` and are moved to
    ``pc``'s device before evaluating ``potential``.
    """
    sample_positions = util.gen_grid().to(pc.device)
    values = potential(pc, sample_positions)
    return values.mean()
def potential(sources, means, eps=1e-5, recursive=True, max_pts=15000):
    """
    Calculate dipole potential.

    For every measurement position ``m`` the per-source terms
    ``(p_s . R) / |R| ** 3`` with ``R = x_s - m`` are summed over all sources.
    Source/position pairs whose term is not finite (e.g. a source lying exactly
    on the measurement position, which gives 0/0) contribute 0 instead of
    wiping out the whole sum. This also makes the result independent of how the
    work is chunked by ``max_pts``.

    Args:
        sources: tensor with source positions in columns ``:3`` and dipole
            moments in columns ``3:`` (3 moment components expected).
        means: tensor with measurement positions in columns ``:3``; any extra
            columns are ignored.
        eps: unused by this function (no softening is applied); only forwarded
            to recursive calls, kept for signature parity with ``field_grad``.
        recursive: if True, split the work whenever ``sources`` or ``means``
            has more than ``max_pts`` rows, to bound the (S, M, 3) intermediate.
        max_pts: row limit per side before splitting (values below 1 are
            treated as 1).

    Returns:
        1-D torch.Tensor with one potential value per row of ``means``.
    """
    # Never split a side below one row: with max_pts < 1 a single row would be
    # "split" into an empty half and itself, recursing forever.
    limit = max(max_pts, 1)
    n_src, n_pos = sources.shape[0], means.shape[0]
    if recursive:
        # Split the larger oversized side first; potentials are additive over
        # sources and independent across measurement positions.
        if n_src > limit and (n_pos <= limit or n_src > n_pos):
            mid = n_src // 2
            return (potential(sources[:mid], means, eps, recursive, max_pts)
                    + potential(sources[mid:], means, eps, recursive, max_pts))
        if n_pos > limit:
            mid = n_pos // 2
            return torch.cat([potential(sources, means[:mid], eps, recursive, max_pts),
                              potential(sources, means[mid:], eps, recursive, max_pts)], dim=0)
    p = sources[:, 3:]
    # (S, M, 3) offsets pointing from each measurement position to each source
    R = sources[:, None, :3] - means[None, :, :3]
    phi = (p[:, None, :] * R).sum(dim=-1) / R.norm(dim=-1) ** 3
    # Drop singular pairs *before* summing so one bad pair cannot poison the
    # potential of the whole measurement position.
    phi = phi.masked_fill(~torch.isfinite(phi), 0)
    phi_total = phi.sum(dim=0)
    # A sum of finite terms can still overflow; keep the final guard.
    phi_total[~torch.isfinite(phi_total)] = 0
    return phi_total
def field_grad(sources, means, eps=1e-5, recursive=True, max_pts=15000):
    """
    Calculate dipole field i.e. potential gradient.

    For every source/measurement pair, with ``R = x_s - m`` and ``u = R / |R|``,
    the term ``(3 (p . u) u - p) / (|R| ** 3 + eps)`` is computed. The terms are
    summed over sources and the sign is flipped (field = -gradient, so the
    negation yields the gradient). Pairs whose term is not finite (e.g. a
    source lying exactly on the measurement position, where ``u`` is 0/0)
    contribute 0 instead of zeroing the whole result for that position. This
    also makes the result independent of how the work is chunked by ``max_pts``.

    Args:
        sources: tensor with source positions in columns ``:3`` and dipole
            moments in columns ``3:`` (3 moment components expected).
        means: tensor with positions to calculate the gradient at in columns
            ``:3``; any extra columns are ignored.
        eps: added to ``|R| ** 3`` in the denominator.
        recursive: if True, split the work whenever ``sources`` or ``means``
            has more than ``max_pts`` rows, to bound the (S, M, 3) intermediates.
        max_pts: row limit per side before splitting (values below 1 are
            treated as 1).

    Returns:
        torch.Tensor of shape (len(means), 3) with the gradient at each
        measurement position.
    """
    # Never split a side below one row: with max_pts < 1 a single row would be
    # "split" into an empty half and itself, recursing forever.
    limit = max(max_pts, 1)
    n_src, n_pos = sources.shape[0], means.shape[0]
    if recursive:
        # Split the larger oversized side first; the field is additive over
        # sources and independent across measurement positions.
        if n_src > limit and (n_pos <= limit or n_src > n_pos):
            mid = n_src // 2
            return (field_grad(sources[:mid], means, eps, recursive, max_pts)
                    + field_grad(sources[mid:], means, eps, recursive, max_pts))
        if n_pos > limit:
            mid = n_pos // 2
            return torch.cat([field_grad(sources, means[:mid], eps, recursive, max_pts),
                              field_grad(sources, means[mid:], eps, recursive, max_pts)], dim=0)
    p = sources[:, 3:]
    # (S, M, 3) offsets pointing from each measurement position to each source
    R = sources[:, None, :3] - means[None, :, :3]
    dist = R.norm(dim=-1, keepdim=True)  # (S, M, 1), computed once
    R_unit = R / dist
    E = 3 * (p[:, None, :] * R_unit).sum(dim=-1, keepdim=True) * R_unit - p[:, None, :]
    E = E / (dist ** 3 + eps)
    # Drop singular pairs *before* summing so one bad pair cannot zero the
    # field of the whole measurement position.
    E = E.masked_fill(~torch.isfinite(E), 0)
    E_total = -E.sum(dim=0)  # field=(-1)*grad -> flip the sign to get the gradient
    # A sum of finite terms can still overflow; keep the final guard.
    E_total[~torch.isfinite(E_total)] = 0
    return E_total
def reference_field(pc1, pc2):
    """
    Use the dipole gradient induced by ``pc1`` as a reference for ``pc2``.

    The gradient comes from ``field_grad(pc1, pc2)``. Its handling depends on
    how many columns ``pc2`` has:

    - If ``pc2`` has exactly 3 columns (positions only), the gradient is
      normalized to unit length wherever it is non-zero. It is then appended as
      3 new columns, and a new tensor is returned.
    - Otherwise, every normal in ``pc2[:, 3:]`` whose dot product with the
      gradient is negative is flipped. A zero dot product keeps the sign.
      ``pc2`` is modified in place and returned.

    Runs without autograd tracking.
    """
    with torch.no_grad():
        grad = field_grad(pc1, pc2, recursive=True)
        if pc2.shape[1] != 3:
            alignment = (grad * pc2[:, 3:]).sum(dim=-1)
            flip = (alignment >= 0).float() * 2 - 1
            pc2[:, 3:] = pc2[:, 3:] * flip[:, None]
            return pc2
        norms = grad.norm(dim=-1)
        nonzero = norms != 0
        grad[nonzero, :] = grad[nonzero, :] / norms[nonzero, None]
        return torch.cat([pc2, grad], dim=1)
def strongest_field_propagation_reps(input_pc, reps, diffuse=False, weights=None):
    """
    Orient normals patch-by-patch, propagating only through representative points.

    ``reps`` is iterated as ``(rep, rest)`` pairs of index collections into
    ``input_pc``. Only the ``rep`` points take part in the greedy propagation:

    1. The seed is the rep patch whose ``util.pca_eigen_values(...)[0]`` has
       the smallest magnitude (intended as the flattest patch). Its field seeds
       ``E`` on the other rep points.
    2. Repeatedly, the remaining rep patch with the largest absolute summed
       interaction ``sum(E . n)`` is oriented next. When the interaction is
       negative, both its ``rep`` and ``rest`` normals are flipped. Its field
       is then added to ``E``.
    3. Finally, every point not in any rep patch is re-signed so that its
       interaction with ``field_grad`` of all rep points is positive (zero
       interaction flips).

    The work happens in place on ``input_pc.detach()``, which shares storage
    with ``input_pc``, so the caller's tensor is modified. Nothing is returned.

    Args:
        input_pc: point tensor with positions in ``[:, :3]`` and normals in
            ``[:, 3:]`` (presumably (N, 6); TODO confirm).
        reps: iterable of ``(rep, rest)`` index pairs. When ``diffuse`` is True
            it is iterated twice, so it must not be a one-shot iterator.
        diffuse: if True, each newly oriented patch adds its field to all other
            rep points (not only the still-unoriented ones). Rep normals are
            also re-signed against the accumulated field after propagation.
        weights: optional per-point weights. They are clamped to [0.1, 1] and
            scale the normals during propagation, then divided back out.
    """
    # detach() shares storage with the caller's tensor, so every in-place edit
    # of pts below is visible to the caller.
    input_pc = input_pc.detach()
    with torch.no_grad():
        pts = input_pc
        if weights is not None:
            # factor in the weights for each point by scaling the normals
            weights = weights.clamp(0.1, 1)
            pts[:, 3:] = pts[:, 3:] * weights[:, None]
        device = input_pc.device
        remaining = []
        E = torch.zeros_like(pts[:, :3])
        patches = []
        # Boolean masks over all points: rep points already oriented, and rep
        # points still waiting to be oriented.
        oriented_pts_mask = torch.zeros(len(pts)).bool()
        non_oriented_pts_mask = torch.zeros(len(pts)).bool()
        for rep, rest in reps:
            remaining.append((rep, rest))
            patches.append(rep)
            non_oriented_pts_mask[rep] = True
        # find the flattest patch to start with: the rep patch whose first
        # util.pca_eigen_values entry has the smallest magnitude
        curv = [util.pca_eigen_values(pts[patch]) for patch in patches]
        min_index = np.array([curv[i][0] for i in range(len(patches))])
        min_index = np.abs(min_index)
        min_index = np.argmin(min_index)
        # calculate the field from the initial patch
        # NOTE: the seed's rest points are not flipped here; like all non-rep
        # points they are re-signed by the final pass below.
        start_patch, rest_patch = remaining.pop(min_index)
        oriented_pts_mask[start_patch] = True
        non_oriented_pts_mask[start_patch] = False
        E[non_oriented_pts_mask] = field_grad(pts[oriented_pts_mask], pts[non_oriented_pts_mask])
        # prop orientation as long as there are remaining unoriented patches
        while len(remaining) > 0:
            # calculate the interaction between the field and all remaining patches
            interaction = [(E[patch] * pts[patch, 3:]).sum(dim=-1).sum() for patch, rest in remaining]
            # orient the patch with the strongest interaction
            max_interaction_index = torch.tensor(interaction).abs().argmax().item()
            patch, rest_patch = remaining.pop(max_interaction_index)
            # print(f'{patch_index}')
            if interaction[max_interaction_index] < 0:
                # flip the rep points and their associated rest points together
                pts[patch, 3:] *= -1
                pts[rest_patch, 3:] *= -1
            oriented_pts_mask[patch] = True
            non_oriented_pts_mask[patch] = False
            if diffuse:
                # add the effect of the current patch to *all* other rep points
                patch_mask = torch.logical_or(oriented_pts_mask, non_oriented_pts_mask)
                patch_mask[patch] = False
                dE = field_grad(pts[patch], pts[patch_mask])
                E[patch_mask] = E[patch_mask] + dE
            else:
                # add the effect of the current patch only to the *remaining* patches
                dE = field_grad(pts[patch], pts[non_oriented_pts_mask])
                E[non_oriented_pts_mask] = E[non_oriented_pts_mask] + dE
        if diffuse:
            # re-sign every rep point so its interaction with the accumulated
            # field is positive (zero interaction flips); reps is iterated again
            for rep, rest in reps:
                interactions = (E[rep] * pts[rep, 3:]).sum(dim=-1)
                sign = (interactions > 0).float() * 2 - 1
                pts[rep, 3:] = pts[rep, 3:] * sign[:, None]
        # orient every point outside the rep patches using the field of all rep points
        E = field_grad(pts[oriented_pts_mask], pts[~oriented_pts_mask])
        interactions = (E * pts[~oriented_pts_mask, 3:]).sum(dim=-1)
        sign = (interactions > 0).float() * 2 - 1
        pts[~oriented_pts_mask, 3:] = pts[~oriented_pts_mask, 3:] * sign[:, None]
    # effectively a no-op: pts never leaves input_pc's device
    pts = pts.to(device)
    if weights is not None:
        # undo the weighted scaling (restores unit normals if they were unit on input)
        pts[:, 3:] = pts[:, 3:] / weights[:, None]
def strongest_field_propagation(pts, patches, all_patches, diffuse=False, weights=None):
    """
    Orient normals patch-by-patch by greedy dipole-field propagation, in place.

    1. The seed is the patch of ``all_patches`` whose
       ``util.pca_eigen_values(...)[0]`` has the smallest magnitude (intended
       as the flattest patch). Its field seeds ``E`` on every other point.
    2. Repeatedly, the remaining patch with the largest absolute summed
       interaction ``sum(E . n)`` is oriented next: its normals are flipped
       when the interaction is negative. Its field is then added to ``E``.

    ``pts`` is modified in place and nothing is returned.

    Args:
        pts: point tensor with positions in ``[:, :3]`` and normals in
            ``[:, 3:]`` (presumably (N, 6); TODO confirm).
        patches: only used when ``diffuse`` is True. Each element is indexed
            with ``[1]`` to obtain the point indices of a patch that is
            re-signed against the final field.
        all_patches: indexable sequence of index collections, one per patch.
            Every patch takes part in the propagation.
        diffuse: if True, each oriented patch adds its field to every point
            outside itself (including already-oriented points). The patches in
            ``patches`` are also re-signed after propagation.
        weights: optional per-point weights. They are clamped to [0.1, 1] and
            scale the normals during propagation, then divided back out.
    """
    with torch.no_grad():
        if weights is not None:
            # factor in the weights for each point by scaling the normals
            weights = weights.clamp(0.1, 1)
            pts[:, 3:] = pts[:, 3:] * weights[:, None]
        device = pts.device
        # initialize remaining: (patch number, point indices) pairs
        remaining = []
        # marks points belonging to already-oriented patches
        pts_mask = torch.zeros(len(pts)).bool()
        # all-False template, cloned per iteration in diffuse mode
        zeros = torch.zeros(len(pts)).bool()
        E = torch.zeros_like(pts[:, :3])
        for i in range(len(all_patches)):
            remaining.append((i, all_patches[i]))
        # find the flattest patch to start with: the patch whose first
        # util.pca_eigen_values entry has the smallest magnitude
        curv = [util.pca_eigen_values(pts[patch]) for patch in all_patches]
        min_index = np.array([curv[i][0] for i in range(len(all_patches))])
        min_index = np.abs(min_index)
        min_index = np.argmin(min_index)
        # calculate the field from the initial patch
        _, start_patch = remaining.pop(min_index)
        pts_mask[start_patch] = True
        E[~pts_mask] = field_grad(pts[pts_mask], pts[~pts_mask])
        # prop orientation as long as there are remaining unoriented patches
        while len(remaining) > 0:
            # calculate the interaction between the field and all remaining patches
            interaction = [(E[patch] * pts[patch, 3:]).sum(dim=-1).sum() for i, patch in remaining]
            # orient the patch with the strongest interaction
            max_interaction_index = torch.tensor(interaction).abs().argmax().item()
            # patch_index is not used below
            patch_index, patch = remaining.pop(max_interaction_index)
            # print(f'{patch_index}')
            if interaction[max_interaction_index] < 0:
                pts[patch, 3:] *= -1
            pts_mask[patch] = True
            if diffuse:
                # add the effect of the current patch to every point outside it
                patch_mask = zeros.clone()
                patch_mask[patch] = True
                dE = field_grad(pts[patch], pts[~patch_mask])
                E[~patch_mask] = E[~patch_mask] + dE
            else:
                # add the effect of the current patch only to the *remaining* points
                dE = field_grad(pts[patch], pts[~pts_mask])
                E[~pts_mask] = E[~pts_mask] + dE
        if diffuse:
            # re-sign each listed patch so its interaction with the accumulated
            # field is positive (zero interaction flips)
            for patch in patches:
                patch = patch[1]
                interactions = (E[patch] * pts[patch, 3:]).sum(dim=-1)
                sign = (interactions > 0).float() * 2 - 1
                pts[patch, 3:] = pts[patch, 3:] * sign[:, None]
    # effectively a no-op: pts never leaves its original device
    pts = pts.to(device)
    if weights is not None:
        # undo the weighted scaling (restores unit normals if they were unit on input)
        pts[:, 3:] = pts[:, 3:] / weights[:, None]
def strongest_field_propagation_points(pts: torch.Tensor, diffuse=False, starting_point=0):
    """
    Orient normals point-by-point by greedy dipole-field propagation.

    The normal of ``pts[starting_point]`` is kept as given, and its field seeds
    the accumulated field ``E``. Repeatedly, the unvisited point whose normal
    has the largest absolute interaction ``E . n`` is visited. Its normal is
    flipped when the interaction is negative. Its own field
    (``field_grad`` with ``eps=1e-6``) is then added to every other point.
    A point never contributes to its own ``E``.

    Args:
        pts: point tensor with positions in ``[:, :3]`` and normals in
            ``[:, 3:]`` (presumably (N, 6); TODO confirm).
        diffuse: if True, after propagation every normal is re-signed so its
            interaction with the accumulated field is positive (zero
            interaction flips).
        starting_point: index of the seed point.

    Returns:
        The oriented point tensor on the input's original device. When the work
        runs on the input's own device, ``pts`` is also modified in place (as
        before). Otherwise the input is left untouched, so use the return value.
        An empty input is returned unchanged.
    """
    device = pts.device
    if pts.shape[0] == 0:
        return pts
    with torch.no_grad():
        # Prefer the GPU as before, but fall back to the input's device instead
        # of failing in ``.cuda()`` on machines without CUDA.
        if torch.cuda.is_available():
            pts = pts.cuda()
        indx = torch.arange(pts.shape[0], device=pts.device)
        E = torch.zeros_like(pts[:, :3])
        visited = torch.zeros_like(pts[:, 0]).bool()
        visited[starting_point] = True
        # seed the field from the starting point onto every other point
        others = indx != starting_point
        E[others] += field_grad(pts[starting_point:(starting_point + 1)],
                                pts[others, :3], eps=1e-6)
        # prop orientation as long as there are remaining unoriented points
        while not visited.all():
            unvisited = ~visited
            # interaction between the accumulated field and every unvisited normal
            interaction = (E[unvisited] * pts[unvisited, 3:]).sum(dim=-1)
            # orient the point with the strongest interaction
            max_interaction_index = interaction.abs().argmax()
            pts_index = int(indx[unvisited][max_interaction_index])
            if interaction[max_interaction_index] < 0:
                pts[pts_index, 3:] *= -1
            visited[pts_index] = True
            # add the new point's field to every other point (never to itself)
            others = indx != pts_index
            E[others] += field_grad(pts[pts_index:(pts_index + 1)],
                                    pts[others, :3], eps=1e-6)
        if diffuse:
            interactions = (E * pts[:, 3:]).sum(dim=-1)
            sign = (interactions > 0).float() * 2 - 1
            pts[:, 3:] = pts[:, 3:] * sign[:, None]
    return pts.to(device)