https://github.com/galmetzer/dipole-normal-prop
Tip revision: 0887b893b153f4ce6b09e4dd485f9b042f15b706 authored by Gal Metzer on 08 September 2021, 21:53:11 UTC
Update README.md
Update README.md
Tip revision: 0887b89
pointcnn.py
import torch
import torch.nn.functional as F
from torch.nn import Linear as Lin
import torch.nn as nn
from torch_geometric.nn import fps, global_mean_pool, knn, knn_graph, knn_interpolate
from models.XConv import XConv
from torch_geometric.data.data import Data
import util as utils
class PointCNN(torch.nn.Module):
    """Encoder/decoder point network that emits two values per input point.

    The model wraps a ``PointCnnEncoderPool`` and a ``PointCnnDecoderPool``.
    Before encoding, every point cloud in the batch is centred on its mean
    and passed through ``utils.rotate_to_principle_components``.
    """

    def __init__(self, opts, input_feat, num_classes):
        """Build the encoder and decoder.

        Args:
            opts: options object; only ``opts.pool`` is read here and is
                forwarded to the encoder as its pooling ratio.
            input_feat: number of channels per input point, forwarded to
                the encoder.
            num_classes: stored on the instance; not otherwise used by
                this class.
        """
        super(PointCNN, self).__init__()
        self.opts = opts
        self.num_classes = num_classes
        self.input_features = input_feat
        self.encoder = PointCnnEncoderPool(self.input_features, pooling_ratio=self.opts.pool)
        # The decoder's final linear layer produces this many channels.
        self.output_channels = 2
        self.decoder = PointCnnDecoderPool(self.output_channels)

    def forward(self, data_in: torch.Tensor, batch_index=None):
        """Normalise the input clouds and run them through the network.

        Args:
            data_in: 2-D tensor with one row per point. The first three
                columns are used as the point position; the full row is
                used as the point feature vector.
                NOTE: this tensor is modified in place (centred and
                rotated), so the caller's tensor changes as well.
            batch_index: optional 1-D integer tensor assigning each row of
                ``data_in`` to a cloud id in ``0..max``. When omitted, all
                points are treated as a single cloud with id 0.

        Returns:
            Tensor of shape ``(n_points, self.output_channels)``.
        """
        n_points = data_in.shape[0]
        if batch_index is None:
            batch_index = torch.zeros(n_points, dtype=torch.long, device=data_in.device)

        data = Data()
        # ``data.pos`` is a detached view of the first three columns of
        # ``data_in``: it shares storage with ``data.x``, so the in-place
        # normalisation below is reflected in the positions too.
        data.x, data.pos, data.batch = data_in, data_in[:, :3].detach(), batch_index

        for i in range(int(batch_index.max()) + 1):
            # Compute the membership mask once per cloud and reuse it.
            mask = batch_index == i
            # Centre the cloud's positions on their mean.
            data.x[mask, :3] -= data.x[mask, :3].mean(dim=0)
            # Presumably aligns the cloud with its principal axes; see
            # utils.rotate_to_principle_components for the exact contract.
            data.x[mask] = utils.rotate_to_principle_components(data.x[mask])

        latent = self.encoder(data)
        x = self.decoder(*latent)
        return x.view(n_points, self.output_channels)
class PointCnnEncoderPool(nn.Module):
    """Four-stage XConv encoder with two point-subsampling steps.

    Channel widths go ``input_feat -> 32 -> 64 -> 128 -> 256``. After the
    first and second convolutions the point set is subsampled with ``fps``
    using ``pooling_ratio``.
    """

    def __init__(self, input_feat, pooling_ratio=0.5):
        """Create the four XConv layers.

        Args:
            input_feat: number of channels per input point.
            pooling_ratio: ``ratio`` argument handed to ``fps`` at each
                subsampling step.
        """
        super().__init__()
        self.pooling_ratio = pooling_ratio

        self.conv1 = XConv(input_feat, 32, dim=3, kernel_size=16, hidden_channels=None)
        self.conv2 = XConv(32, 64, dim=3, kernel_size=20, hidden_channels=None, dilation=1)
        self.conv3 = XConv(64, 128, dim=3, kernel_size=20, hidden_channels=None, dilation=1)
        self.conv4 = XConv(128, 256, dim=3, kernel_size=20, hidden_channels=None, dilation=1)

    def _subsample(self, feats, coords, batch_ids):
        """Pick points with ``fps`` and gather all three tensors at them."""
        keep = fps(coords, batch_ids, ratio=self.pooling_ratio, random_start=True)
        return feats[keep], coords[keep], batch_ids[keep]

    def forward(self, data_in):
        """Encode a batch of point clouds.

        Args:
            data_in: object exposing ``x`` (features), ``pos`` (positions)
                and ``batch`` (per-point cloud id).

        Returns:
            A 3-tuple:
            ``(features, positions, batch)`` at the coarsest resolution,
            ``(positions_mid, positions_full)`` and
            ``(batch_mid, batch_full)`` -- copies of the positions and
            batch vectors taken before each subsampling step, ordered
            coarse-to-fine.
        """
        feats, coords, batch_ids = data_in.x, data_in.pos, data_in.batch

        # Full resolution: remember the layout, convolve, then subsample.
        full_pos, full_batch = coords.clone(), batch_ids.clone()
        feats = F.relu(self.conv1(feats, coords.clone(), batch_ids))
        feats, coords, batch_ids = self._subsample(feats, coords, batch_ids)

        # Intermediate resolution: same pattern.
        mid_pos, mid_batch = coords.clone(), batch_ids.clone()
        feats = F.relu(self.conv2(feats, coords.clone(), batch_ids))
        feats, coords, batch_ids = self._subsample(feats, coords, batch_ids)

        # Coarsest resolution: two convolutions, no further subsampling.
        feats = F.relu(self.conv3(feats, coords.clone(), batch_ids))
        feats = F.relu(self.conv4(feats, coords, batch_ids))

        return (feats, coords, batch_ids), (mid_pos, full_pos), (mid_batch, full_batch)
class PointCnnDecoderPool(nn.Module):
    """Three-stage XConv decoder with two interpolation-based upsamplings.

    Channel widths go ``256 -> 128 -> 64 -> 32`` and a final linear layer
    maps the 32 channels to ``input_feat`` outputs per point.
    """

    def __init__(self, input_feat):
        """Create the three XConv layers and the output linear layer.

        Args:
            input_feat: number of output channels of the final linear
                layer (the name is kept for interface compatibility).
        """
        super().__init__()
        self.conv1 = XConv(256, 128, dim=3, kernel_size=20, hidden_channels=None)
        self.conv2 = XConv(128, 64, dim=3, kernel_size=20, hidden_channels=None, dilation=1)
        self.conv3 = XConv(64, 32, dim=3, kernel_size=20, hidden_channels=None, dilation=1)
        self.lin4 = Lin(32, input_feat)

    def forward(self, data_in, poss, batchs):
        """Decode coarse features back onto the stored finer point sets.

        Args:
            data_in: ``(features, positions, batch)`` tuple at the coarsest
                resolution.
            poss: indexable of target positions; entries 0 and 1 are used,
                in that order.
            batchs: indexable of batch vectors matching ``poss``.

        Returns:
            Output of the final linear layer, one row per point of
            ``poss[1]``.
        """
        feats, coords, batch_ids = data_in

        feats = F.relu(self.conv1(feats, coords.clone(), batch_ids))

        # Each remaining stage first interpolates the features onto the next
        # stored point set and then convolves at that resolution.
        for level, conv in enumerate((self.conv2, self.conv3)):
            target_pos = poss[level]
            target_batch = batchs[level]
            feats = knn_interpolate(
                feats, coords, target_pos,
                batch_x=batch_ids, batch_y=target_batch, k=16, num_workers=1)
            coords, batch_ids = target_pos, target_batch
            feats = F.relu(conv(feats, coords.clone(), batch_ids))

        return self.lin4(feats)