https://github.com/galmetzer/dipole-normal-prop
Raw File
Tip revision: 0887b893b153f4ce6b09e4dd485f9b042f15b706 authored by Gal Metzer on 08 September 2021, 21:53 UTC
Update README.md
Tip revision: 0887b89
pointcnn.py
import torch
import torch.nn.functional as F
from torch.nn import Linear as Lin
import torch.nn as nn
from torch_geometric.nn import fps, global_mean_pool, knn, knn_graph, knn_interpolate
from models.XConv import XConv
from torch_geometric.data.data import Data
import util as utils


class PointCNN(torch.nn.Module):
    """Per-point network built from a PointCNN (XConv) encoder/decoder with FPS pooling.

    Every input point is mapped to ``self.output_channels`` (= 2) values.

    Args:
        opts: options object; only ``opts.pool`` (pooling ratio forwarded to the
            encoder) is read here.
        input_feat: number of feature columns per point in the tensor given to
            :meth:`forward`.
        num_classes: stored as ``self.num_classes``; not used by this class.
    """

    def __init__(self, opts, input_feat, num_classes):
        super(PointCNN, self).__init__()
        self.opts = opts
        self.num_classes = num_classes
        self.input_features = input_feat
        self.encoder = PointCnnEncoderPool(self.input_features, pooling_ratio=self.opts.pool)

        # NOTE(review): output width is hard-coded and independent of num_classes.
        self.output_channels = 2
        self.decoder = PointCnnDecoderPool(self.output_channels)

    def forward(self, data_in: torch.Tensor, batch_index=None):
        """Run the network on a (possibly batched) point cloud.

        Args:
            data_in: tensor with one row per point; its first three columns are
                used as coordinates. The tensor is not modified.
            batch_index: optional long tensor with one entry per point giving the
                cloud each point belongs to; ``None`` treats all points as one cloud.

        Returns:
            Tensor of shape ``(n_points, self.output_channels)``.
        """
        n_points = data_in.shape[0]
        if batch_index is None:
            batch_index = torch.zeros(n_points, dtype=torch.long, device=data_in.device)

        # The normalization below writes in place; doing it on a copy keeps the
        # caller's tensor intact (and avoids in-place ops on a grad-requiring leaf).
        x = data_in.clone()
        for i in range(int(batch_index.max()) + 1):
            mask = batch_index == i
            # Center this cloud's coordinates, then rotate it via the project helper.
            x[mask, :3] -= x[mask, :3].mean(dim=0)
            x[mask] = utils.rotate_to_principle_components(x[mask])

        data = Data()
        # pos is taken after normalization so coordinates agree with the features.
        data.x, data.pos, data.batch = x, x[:, :3].detach(), batch_index

        latent = self.encoder(data)
        out = self.decoder(*latent)

        return out.view(n_points, self.output_channels)


class PointCnnEncoderPool(nn.Module):
    """XConv encoder with two farthest-point-sampling (FPS) downsampling steps.

    Args:
        input_feat: number of input feature channels per point.
        pooling_ratio: ``ratio`` passed to each ``fps`` call.
    """

    def __init__(self, input_feat, pooling_ratio=0.5):
        super(PointCnnEncoderPool, self).__init__()
        self.pooling_ratio = pooling_ratio
        self.conv1 = XConv(input_feat, 32, dim=3, kernel_size=16, hidden_channels=None)
        self.conv2 = XConv(32, 64, dim=3, kernel_size=20, hidden_channels=None, dilation=1)
        self.conv3 = XConv(64, 128, dim=3, kernel_size=20, hidden_channels=None, dilation=1)
        self.conv4 = XConv(128, 256, dim=3, kernel_size=20, hidden_channels=None, dilation=1)

    def _pool(self, feats, coords, batch_ids):
        """Keep the FPS-selected subset (random start) of features, coordinates and batch ids."""
        keep = fps(coords, batch_ids, ratio=self.pooling_ratio, random_start=True)
        return feats[keep], coords[keep], batch_ids[keep]

    def forward(self, data_in):
        """Encode an object exposing ``.x``, ``.pos`` and ``.batch``.

        Returns:
            ``((x, pos, batch), (pos2, pos1), (batch2, batch1))``: the deepest level
            (after two FPS steps), then copies of the coordinates and batch vectors
            of the first-pooled level and the full-resolution level, in that order.
        """
        feats, coords, batch_ids = data_in.x, data_in.pos, data_in.batch

        # NOTE(review): coords are cloned before conv1-conv3 but not conv4;
        # presumably a guard against XConv touching pos in place - confirm.

        # Full resolution.
        full_pos, full_batch = coords.clone(), batch_ids.clone()
        feats = F.relu(self.conv1(feats, coords.clone(), batch_ids))

        # First FPS level.
        feats, coords, batch_ids = self._pool(feats, coords, batch_ids)
        half_pos, half_batch = coords.clone(), batch_ids.clone()
        feats = F.relu(self.conv2(feats, coords.clone(), batch_ids))

        # Second FPS level: conv3 and conv4 both operate here.
        feats, coords, batch_ids = self._pool(feats, coords, batch_ids)
        feats = F.relu(self.conv3(feats, coords.clone(), batch_ids))
        feats = F.relu(self.conv4(feats, coords, batch_ids))

        deepest = (feats, coords, batch_ids)
        return deepest, (half_pos, full_pos), (half_batch, full_batch)


class PointCnnDecoderPool(nn.Module):
    """XConv decoder that climbs back up the pooling levels via k-NN interpolation.

    Args:
        input_feat: despite its name, the OUTPUT width - the number of channels
            produced by the final linear layer.
    """

    def __init__(self, input_feat):
        super(PointCnnDecoderPool, self).__init__()
        self.conv1 = XConv(256, 128, dim=3, kernel_size=20, hidden_channels=None)
        self.conv2 = XConv(128, 64, dim=3, kernel_size=20, hidden_channels=None, dilation=1)
        self.conv3 = XConv(64, 32, dim=3, kernel_size=20, hidden_channels=None, dilation=1)
        self.lin4 = Lin(32, input_feat)

    def forward(self, data_in, poss, batchs):
        """Decode coarse features back onto the target point sets.

        Args:
            data_in: ``(x, pos, batch)`` at the coarsest level.
            poss: target coordinates for the two upsampling steps, in order.
            batchs: batch vectors matching ``poss`` entry by entry.

        Returns:
            Output of ``lin4`` on the features interpolated onto ``poss[1]``
            (one row per point, assuming XConv keeps the point count).
        """
        feats, coords, batch_ids = data_in
        feats = F.relu(self.conv1(feats, coords.clone(), batch_ids))

        # Each stage: interpolate onto the next finer point set, then convolve there.
        for level, conv in enumerate((self.conv2, self.conv3)):
            target_pos = poss[level]
            target_batch = batchs[level]
            feats = knn_interpolate(feats, coords, target_pos, batch_x=batch_ids, batch_y=target_batch, k=16, num_workers=1)
            coords, batch_ids = target_pos, target_batch
            feats = F.relu(conv(feats, coords.clone(), batch_ids))

        return self.lin4(feats)




back to top