https://github.com/galmetzer/dipole-normal-prop
Raw File
Tip revision: 0887b893b153f4ce6b09e4dd485f9b042f15b706 authored by Gal Metzer on 08 September 2021, 21:53 UTC
Update README.md
Tip revision: 0887b89
field_utils.py
import torch
import numpy as np
import util


def measure_mean_potential(pc: torch.Tensor):
    """
    Average the dipole potential of ``pc`` over the sample grid from ``util.gen_grid()``.

    Args:
        pc: dipole sources in the layout expected by ``potential``
            (positions in columns 0:3, dipole moments in columns 3:).

    Returns:
        0-d tensor: mean of ``potential(pc, grid)`` over all grid positions.
    """
    sample_grid = util.gen_grid()
    # Evaluate on the same device as the sources.
    sample_grid = sample_grid.to(pc.device)
    grid_values = potential(pc, sample_grid)
    return grid_values.mean()


def potential(sources, means, eps=1e-5, recursive=True, max_pts=15000):
    """
    Calculate the summed dipole potential of ``sources`` at every position in ``means``.

    For each source/measurement pair the term is ``p . R / |R|^3`` with
    ``R = source_position - measurement_position`` (note the direction of ``R``).

    Args:
        sources: (S, 6) tensor; columns 0:3 are the source positions, columns 3: the dipole moments
        means: (M, >=3) tensor; columns 0:3 are the positions to calculate the potential at
        eps: not used by this function; accepted so the signature mirrors ``field_grad``
        recursive: if True, split ``sources``/``means`` in halves until neither exceeds
            ``max_pts``, bounding the size of the (S, M, 3) intermediate tensor
        max_pts: chunk size for the recursive split; values below 1 are treated as 1

    Returns:
        torch.Tensor of shape (M,): the potential at each measurement position.
        A source that coincides with a measurement position (zero distance) contributes 0
        for that position instead of turning the whole entry into NaN.
    """
    if max_pts < 1:
        # With max_pts < 1 a length-1 chunk would still be "too big", and splitting it
        # at mid == 0 would recurse forever.
        max_pts = 1

    if recursive:
        def break_by_means():
            # Disjoint measurement chunks -> concatenate the per-chunk results.
            mid = means.shape[0] // 2
            return torch.cat([potential(sources, means[:mid], eps, recursive, max_pts),
                              potential(sources, means[mid:], eps, recursive, max_pts)], dim=0)

        def break_by_sources():
            # The potential is linear in the sources -> add the partial sums.
            mid = sources.shape[0] // 2
            return potential(sources[:mid], means, eps, recursive, max_pts) \
                + potential(sources[mid:], means, eps, recursive, max_pts)

        if sources.shape[0] > max_pts and means.shape[0] > max_pts:
            if sources.shape[0] > means.shape[0]:
                return break_by_sources()
            return break_by_means()

        if sources.shape[0] > max_pts:
            return break_by_sources()

        if means.shape[0] > max_pts:
            return break_by_means()

    p = sources[:, 3:]
    R = sources[:, None, :3] - means[None, :, :3]  # (S, M, 3)

    phi = (p[:, None, :] * R).sum(dim=-1) / R.norm(dim=-1) ** 3  # (S, M)
    # Coincident pairs give 0/0 = NaN. Drop those terms individually so they do not
    # poison the contributions of all other sources at that measurement position.
    phi = torch.where(torch.isfinite(phi), phi, torch.zeros_like(phi))
    phi_total = phi.sum(dim=0)

    # Safety net for overflow in the summation.
    phi_total[~torch.isfinite(phi_total)] = 0
    return phi_total


def field_grad(sources, means, eps=1e-5, recursive=True, max_pts=15000):
    """
    Calculate the summed dipole field gradient of ``sources`` at every position in ``means``.

    Per pair the dipole field ``E = (3 (p . R_hat) R_hat - p) / (|R|^3 + eps)`` is computed with
    ``R = source_position - measurement_position``. The summed field is negated (field = -gradient).

    NOTE(review): differentiating directly, the result equals the gradient of
    ``p . (r - r_s) / |r - r_s|^3``, which is the negative of what ``potential`` in this file
    returns. Confirm the intended sign convention before relying on the two together.

    Args:
        sources: (S, 6) tensor; columns 0:3 are the source positions, columns 3: the dipole moments
        means: (M, >=3) tensor; columns 0:3 are the positions to calculate the gradient at
        eps: added to ``|R|^3`` in the denominator to soften near-singular pairs
        recursive: if True, split ``sources``/``means`` in halves until neither exceeds
            ``max_pts``, bounding the size of the (S, M, 3) intermediate tensors
        max_pts: chunk size for the recursive split; values below 1 are treated as 1

    Returns:
        torch.Tensor of shape (M, 3): the gradient at each measurement position.
        A source that coincides with a measurement position (zero distance) contributes 0
        for that position instead of zeroing the whole entry.
    """
    if max_pts < 1:
        # With max_pts < 1 a length-1 chunk would still be "too big", and splitting it
        # at mid == 0 would recurse forever.
        max_pts = 1

    if recursive:
        def break_by_means():
            # Disjoint measurement chunks -> concatenate the per-chunk results.
            mid = means.shape[0] // 2
            return torch.cat([field_grad(sources, means[:mid], eps, recursive, max_pts),
                              field_grad(sources, means[mid:], eps, recursive, max_pts)], dim=0)

        def break_by_sources():
            # The field is linear in the sources -> add the partial sums.
            mid = sources.shape[0] // 2
            return field_grad(sources[:mid], means, eps, recursive, max_pts) \
                + field_grad(sources[mid:], means, eps, recursive, max_pts)

        if sources.shape[0] > max_pts and means.shape[0] > max_pts:
            if sources.shape[0] > means.shape[0]:
                return break_by_sources()
            return break_by_means()

        if sources.shape[0] > max_pts:
            return break_by_sources()

        if means.shape[0] > max_pts:
            return break_by_means()

    p = sources[:, 3:]
    R = sources[:, None, :3] - means[None, :, :3]  # (S, M, 3)
    dist = R.norm(dim=-1)  # (S, M)
    R_unit = R / dist[:, :, None]
    E = 3 * (p[:, None, :] * R_unit).sum(dim=-1)[:, :, None] * R_unit - p[:, None, :]
    E = E / (dist ** 3 + eps)[:, :, None]
    # Coincident pairs make R_unit = 0/0 = NaN. Drop those pairs individually so they do not
    # wipe out the contributions of all other sources at that measurement position.
    finite = torch.isfinite(E).all(dim=-1, keepdim=True)
    E = torch.where(finite, E, torch.zeros_like(E))
    E_total = E.sum(dim=0) * -1  # field=(-1)*grad -> flip the sign to get the gradient instead of -gradient
    # Safety net for overflow in the summation.
    E_total[~torch.isfinite(E_total)] = 0
    return E_total


def reference_field(pc1, pc2):
    """
    Use the dipole field of ``pc1`` to assign or orient normals for the points of ``pc2``.

    Case 1: ``pc2`` has exactly 3 columns (positions only).
        The gradient from ``field_grad(pc1, pc2)`` is normalised to unit length and appended as
        normals. Points where the gradient vanishes get a zero vector. A new tensor
        ``cat([pc2, normals], dim=1)`` is returned.

    Case 2: otherwise.
        The normals in ``pc2[:, 3:]`` are flipped **in place** wherever their dot product with
        the gradient is negative. A zero dot product keeps the normal. ``pc2`` itself is returned.
    """
    with torch.no_grad():
        grad = field_grad(pc1, pc2, recursive=True)

        if pc2.shape[1] != 3:
            alignment = (grad * pc2[:, 3:]).sum(dim=-1)
            # +1 where aligned (or orthogonal), -1 where anti-aligned.
            flip = (alignment >= 0).float() * 2 - 1
            pc2[:, 3:] = pc2[:, 3:] * flip[:, None]
            return pc2

        norms = grad.norm(dim=-1)
        nonzero = norms != 0
        grad[nonzero, :] = grad[nonzero, :] / norms[nonzero, None]
        return torch.cat([pc2, grad], dim=1)


def strongest_field_propagation_reps(input_pc, reps, diffuse=False, weights=None):
    """
    Greedily orient normals patch by patch, using representative points for each patch.

    The starting patch is the one whose first entry of ``util.pca_eigen_values`` has the
    smallest magnitude (presumably the flattest patch); its orientation is kept as is.
    Its field is evaluated at all other representative points. Then the loop repeats:
      * pop the remaining patch whose representatives have the largest ``|sum(E . n)|``;
      * flip the patch (representatives and its ``rest`` points) if that sum is negative;
      * add its field to ``E``.
    Finally, every point that is not a representative of any patch is oriented individually
    against the field of all representative points.

    Args:
        input_pc: point tensor; columns 0:3 are positions, columns 3: the normals
            (presumably (N, 6) -- TODO confirm). Its normals are modified in place.
        reps: sequence of ``(rep, rest)`` index pairs, one per patch. ``rep`` indexes the
            representative points used for the propagation. ``rest`` indexes further points
            that are flipped together with ``rep``. Iterated a second time when ``diffuse``
            is True, so it should be a re-iterable sequence.
        diffuse: if True, each newly oriented patch also updates the field at the already
            oriented representatives. All representatives are then re-oriented against the
            final field before the last pass.
        weights: optional per-point weights. They are clamped to [0.1, 1] and multiplied into
            the normals (the dipole moments) during propagation, then divided back out at the end.

    Returns:
        None. The result is written into ``input_pc``'s storage, because ``detach()`` shares
        storage and all edits below are in place.
    """
    # detach() shares storage, so the in-place edits on ``pts`` reach the caller's tensor.
    input_pc = input_pc.detach()
    with torch.no_grad():
        pts = input_pc

        if weights is not None:
            # factor in the weights for each point by scaling the normals
            weights = weights.clamp(0.1, 1)
            pts[:, 3:] = pts[:, 3:] * weights[:, None]
        device = input_pc.device

        remaining = []

        # Accumulated field at every point (only the rows selected by the masks are updated).
        E = torch.zeros_like(pts[:, :3])
        patches = []
        # Masks over all points; they only ever mark representative points.
        # Created on torch's default device, not necessarily pts.device.
        oriented_pts_mask = torch.zeros(len(pts)).bool()
        non_oriented_pts_mask = torch.zeros(len(pts)).bool()
        for rep, rest in reps:
            remaining.append((rep, rest))
            patches.append(rep)
            non_oriented_pts_mask[rep] = True

        # find the flattest patch to start with
        # (the patch whose first util.pca_eigen_values entry has the smallest magnitude)
        curv = [util.pca_eigen_values(pts[patch]) for patch in patches]
        min_index = np.array([curv[i][0] for i in range(len(patches))])
        min_index = np.abs(min_index)
        min_index = np.argmin(min_index)

        # calculate the field from the initial patch (its orientation is taken as the reference)
        start_patch, rest_patch = remaining.pop(min_index)
        oriented_pts_mask[start_patch] = True
        non_oriented_pts_mask[start_patch] = False
        E[non_oriented_pts_mask] = field_grad(pts[oriented_pts_mask], pts[non_oriented_pts_mask])

        # prop orientation as long as there are remaining unoriented patches
        while len(remaining) > 0:
            # calculate the interaction between the field and all remaining patches
            interaction = [(E[patch] * pts[patch, 3:]).sum(dim=-1).sum() for patch, rest in remaining]

            # orient the patch with the strongest interaction (largest magnitude, either sign)
            max_interaction_index = torch.tensor(interaction).abs().argmax().item()
            patch, rest_patch = remaining.pop(max_interaction_index)
            # print(f'{patch_index}')
            if interaction[max_interaction_index] < 0:
                # Net anti-alignment with the field: flip the whole patch, including its rest points.
                pts[patch, 3:] *= -1
                pts[rest_patch, 3:] *= -1
            oriented_pts_mask[patch] = True
            non_oriented_pts_mask[patch] = False

            if diffuse:
                # add the effect of the current patch to *all* other patches
                patch_mask = torch.logical_or(oriented_pts_mask, non_oriented_pts_mask)
                patch_mask[patch] = False
                dE = field_grad(pts[patch], pts[patch_mask])
                E[patch_mask] = E[patch_mask] + dE
            else:
                # add the effect of the current patch only to the *remaining* patches
                dE = field_grad(pts[patch], pts[non_oriented_pts_mask])
                E[non_oriented_pts_mask] = E[non_oriented_pts_mask] + dE

        if diffuse:
            # Re-orient each representative point individually against the final field.
            # Note: an interaction of exactly 0 also flips the normal (sign = -1).
            for rep, rest in reps:
                interactions = (E[rep] * pts[rep, 3:]).sum(dim=-1)
                sign = (interactions > 0).float() * 2 - 1
                pts[rep, 3:] = pts[rep, 3:] * sign[:, None]

        # All representatives are oriented now; orient every other point (the rest points and any
        # point not listed in ``reps``) by the field of the representatives.
        # An interaction of exactly 0 flips the normal.
        E = field_grad(pts[oriented_pts_mask], pts[~oriented_pts_mask])
        interactions = (E * pts[~oriented_pts_mask, 3:]).sum(dim=-1)
        sign = (interactions > 0).float() * 2 - 1
        pts[~oriented_pts_mask, 3:] = pts[~oriented_pts_mask, 3:] * sign[:, None]

        # pts was never moved off input_pc.device, so this is a no-op.
        pts = pts.to(device)

        if weights is not None:
            # divide the weights back out (restores the normal lengths from before the weighting)
            pts[:, 3:] = pts[:, 3:] / weights[:, None]


def strongest_field_propagation(pts, patches, all_patches, diffuse=False, weights=None):
    """
    Greedily orient normals patch by patch, writing the result into ``pts`` in place.

    The starting patch is the one in ``all_patches`` whose first entry of
    ``util.pca_eigen_values`` has the smallest magnitude (presumably the flattest patch);
    its orientation is kept as is. Its field is evaluated at every point outside that patch.
    Then the loop repeats:
      * pop the remaining patch with the largest ``|sum(E . n)|``;
      * flip it if that sum is negative;
      * add its field to ``E``.

    Args:
        pts: point tensor; columns 0:3 are positions, columns 3: the normals
            (presumably (N, 6) -- TODO confirm). Modified in place.
        patches: used only when ``diffuse`` is True. For each element, ``patch[1]`` is taken
            as an index set whose points are re-oriented individually against the final field
            (presumably (id, indices) pairs -- confirm with the caller).
        all_patches: indexable sequence of index sets, one per patch, used for the propagation.
        diffuse: if True, each newly oriented patch updates the field at *all* points outside
            that patch (not just the unoriented ones), and ``patches`` are re-oriented at the end.
        weights: optional per-point weights. They are clamped to [0.1, 1] and multiplied into
            the normals (the dipole moments) during propagation, then divided back out at the end.

    Returns:
        None. Points not covered by ``all_patches`` are never flipped in the non-diffuse mode.
    """
    with torch.no_grad():
        if weights is not None:
            # factor in the weights for each point by scaling the normals
            weights = weights.clamp(0.1, 1)
            pts[:, 3:] = pts[:, 3:] * weights[:, None]
        device = pts.device

        # initialize remaining
        remaining = []
        # Masks over all points, created on torch's default device (not necessarily pts.device).
        pts_mask = torch.zeros(len(pts)).bool()  # points of already-oriented patches
        zeros = torch.zeros(len(pts)).bool()  # template for per-patch masks in diffuse mode
        # Accumulated field at every point.
        E = torch.zeros_like(pts[:, :3])
        for i in range(len(all_patches)):
            remaining.append((i, all_patches[i]))

        # find the flattest patch to start with
        # (the patch whose first util.pca_eigen_values entry has the smallest magnitude)
        curv = [util.pca_eigen_values(pts[patch]) for patch in all_patches]
        min_index = np.array([curv[i][0] for i in range(len(all_patches))])
        min_index = np.abs(min_index)
        min_index = np.argmin(min_index)

        # calculate the field from the initial patch (its orientation is taken as the reference)
        _, start_patch = remaining.pop(min_index)
        pts_mask[start_patch] = True
        E[~pts_mask] = field_grad(pts[pts_mask], pts[~pts_mask])

        # prop orientation as long as there are remaining unoriented patches
        while len(remaining) > 0:
            # calculate the interaction between the field and all remaining patches
            interaction = [(E[patch] * pts[patch, 3:]).sum(dim=-1).sum() for i, patch in remaining]

            # orient the patch with the strongest interaction (largest magnitude, either sign)
            max_interaction_index = torch.tensor(interaction).abs().argmax().item()
            patch_index, patch = remaining.pop(max_interaction_index)
            # print(f'{patch_index}')
            if interaction[max_interaction_index] < 0:
                # Net anti-alignment with the field: flip the whole patch.
                pts[patch, 3:] *= -1
            pts_mask[patch] = True

            if diffuse:
                # add the effect of the current patch to *all* other patches
                patch_mask = zeros.clone()
                patch_mask[patch] = True
                dE = field_grad(pts[patch], pts[~patch_mask])
                E[~patch_mask] = E[~patch_mask] + dE
            else:
                # add the effect of the current patch only to the *remaining* patches
                dE = field_grad(pts[patch], pts[~pts_mask])
                E[~pts_mask] = E[~pts_mask] + dE

        if diffuse:
            # Re-orient each point of every patch individually against the final field.
            # Note: an interaction of exactly 0 also flips the normal (sign = -1).
            for patch in patches:
                patch = patch[1]
                interactions = (E[patch] * pts[patch, 3:]).sum(dim=-1)
                sign = (interactions > 0).float() * 2 - 1
                pts[patch, 3:] = pts[patch, 3:] * sign[:, None]

        # pts was never moved off its original device, so this is a no-op.
        pts = pts.to(device)

        if weights is not None:
            # divide the weights back out (restores the normal lengths from before the weighting)
            pts[:, 3:] = pts[:, 3:] / weights[:, None]


def strongest_field_propagation_points(pts: torch.Tensor, diffuse=False, starting_point=0) -> torch.Tensor:
    """
    Orient normals point by point by greedily propagating the dipole field.

    Starting from ``pts[starting_point]`` (whose normal is kept as is), the unvisited point
    whose normal has the largest ``|E . n|`` with the field accumulated so far is visited next.
    Its normal is flipped if that interaction is negative. Its own dipole field
    (``field_grad`` with ``eps=1e-6``) is then added to every other point.

    Args:
        pts: (N, >=3) tensor; columns 0:3 are positions, columns 3: the normals.
            Its normals are updated in place.
        diffuse: if True, finally re-orient every normal against the full accumulated field.
            An interaction of exactly 0 also flips the normal.
        starting_point: index of the seed point.

    Returns:
        ``pts`` itself (the caller's tensor), with the oriented normals written into it.
    """
    with torch.no_grad():
        target = pts
        num_pts = pts.shape[0]
        if num_pts == 0:
            return target

        # Prefer the GPU when one exists (the original always called .cuda()), but fall back
        # to the input's own device instead of crashing on CPU-only machines.
        work_device = torch.device('cuda') if torch.cuda.is_available() else pts.device
        pts = pts.to(work_device)
        indx = torch.arange(num_pts, device=pts.device)

        E = torch.zeros_like(pts[:, :3])
        visited = torch.zeros(num_pts, dtype=torch.bool, device=pts.device)
        visited[starting_point] = True
        others = indx != starting_point
        E[others] += field_grad(pts[starting_point:(starting_point + 1)], pts[others, :3], eps=1e-6)

        # prop orientation as long as there are remaining unoriented points
        while not visited.all():
            unvisited = ~visited
            # interaction between the accumulated field and each unvisited normal
            interaction = (E[unvisited] * pts[unvisited, 3:]).sum(dim=-1)

            # visit the point with the strongest interaction (largest magnitude, either sign)
            max_interaction_index = interaction.abs().argmax()
            pts_index = indx[unvisited][max_interaction_index]
            if interaction[max_interaction_index] < 0:
                # normal is anti-aligned with the field -> flip it
                pts[pts_index, 3:] *= -1
            visited[pts_index] = True

            # add the newly oriented dipole's field to every other point
            others = indx != pts_index
            E[others] += field_grad(pts[pts_index:(pts_index + 1)], pts[others, :3], eps=1e-6)

        if diffuse:
            interactions = (E * pts[:, 3:]).sum(dim=-1)
            sign = (interactions > 0).float() * 2 - 1
            pts[:, 3:] = pts[:, 3:] * sign[:, None]

        # If the work was done on a copy on another device, write the result back so the
        # caller's tensor is actually updated. Previously that result was silently discarded.
        if pts is not target:
            target.copy_(pts)
        return target
back to top