https://github.com/czq142857/NDC
Raw File
Tip revision: 9054e20f55013d031af9e3a2c91f5cab75837bc4 authored by Zhiqin Chen on 25 August 2022, 01:13:23 UTC
Update dataset.py
Tip revision: 9054e20
modelpc.py
import torch
import torch.nn as nn
import torch.nn.functional as F


#number of local pointnet input points, gathered via K nearest neighbors
KNN_num = 8


class pc_conv_first(nn.Module):
    def __init__(self, ef_dim):
        super(pc_conv_first, self).__init__()
        self.ef_dim = ef_dim
        self.linear_1 = nn.Linear(3, self.ef_dim)
        self.linear_2 = nn.Linear(self.ef_dim, self.ef_dim)

    def forward(self, KNN_xyz):
        output = KNN_xyz
        #[newpointnum*KNN_num,3]
        output = self.linear_1(output)
        output = F.leaky_relu(output, negative_slope=0.01, inplace=True)
        output = self.linear_2(output)
        #[newpointnum*KNN_num,ef_dim]
        output = output.view(-1,KNN_num,self.ef_dim)
        #[newpointnum,KNN_num,ef_dim]
        output = torch.max(output,1)[0]
        #[newpointnum,ef_dim]
        return output


class pc_conv(nn.Module):
    def __init__(self, ef_dim):
        super(pc_conv, self).__init__()
        self.ef_dim = ef_dim
        self.linear_1 = nn.Linear(self.ef_dim+3, self.ef_dim)
        self.linear_2 = nn.Linear(self.ef_dim, self.ef_dim)

    def forward(self, input, KNN_idx, KNN_xyz):
        output = input
        #[pointnum,ef_dim]
        output = output[KNN_idx]
        #[newpointnum*KNN_num,ef_dim]
        output = torch.cat([output,KNN_xyz],1)
        #[newpointnum*KNN_num,ef_dim+3]
        output = self.linear_1(output)
        output = F.leaky_relu(output, negative_slope=0.01, inplace=True)
        output = self.linear_2(output)
        #[newpointnum*KNN_num,ef_dim]
        output = output.view(-1,KNN_num,self.ef_dim)
        #[newpointnum,KNN_num,ef_dim]
        output = torch.max(output,1)[0]
        #[newpointnum,ef_dim]
        return output


class pc_resnet_block(nn.Module):
    def __init__(self, ef_dim):
        super(pc_resnet_block, self).__init__()
        self.ef_dim = ef_dim
        self.linear_1 = nn.Linear(self.ef_dim, self.ef_dim)
        self.linear_2 = nn.Linear(self.ef_dim, self.ef_dim)

    def forward(self, input):
        output = self.linear_1(input)
        output = F.leaky_relu(output, negative_slope=0.01, inplace=True)
        output = self.linear_2(output)
        output = output+input
        output = F.leaky_relu(output, negative_slope=0.01, inplace=True)
        return output


class resnet_block(nn.Module):
    def __init__(self, ef_dim):
        super(resnet_block, self).__init__()
        self.ef_dim = ef_dim
        self.pc_conv_1 = nn.Conv3d(self.ef_dim, self.ef_dim, 1, stride=1, padding=0, bias=True)
        self.pc_conv_2 = nn.Conv3d(self.ef_dim, self.ef_dim, 1, stride=1, padding=0, bias=True)

    def forward(self, input):
        output = self.pc_conv_1(input)
        output = F.leaky_relu(output, negative_slope=0.01, inplace=True)
        output = self.pc_conv_2(output)
        output = output+input
        output = F.leaky_relu(output, negative_slope=0.01, inplace=True)
        return output


class resnet_block_rec3(nn.Module):
    def __init__(self, ef_dim):
        super(resnet_block_rec3, self).__init__()
        self.ef_dim = ef_dim
        self.pc_conv_1 = nn.Conv3d(self.ef_dim, self.ef_dim, 3, stride=1, padding=1, bias=True)
        self.pc_conv_2 = nn.Conv3d(self.ef_dim, self.ef_dim, 1, stride=1, padding=0, bias=True)

    def forward(self, input):
        output = self.pc_conv_1(input)
        output = F.leaky_relu(output, negative_slope=0.01, inplace=True)
        output = self.pc_conv_2(output)
        output = output+input
        output = F.leaky_relu(output, negative_slope=0.01, inplace=True)
        return output


class local_pointnet(nn.Module):

    def __init__(self, out_bool, out_float):
        super(local_pointnet, self).__init__()
        self.ef_dim = 128
        self.out_bool = out_bool
        self.out_float = out_float
        
        self.pc_conv_0 = pc_conv_first(self.ef_dim)

        self.pc_res_1 = pc_resnet_block(self.ef_dim)
        self.pc_conv_1 = pc_conv(self.ef_dim)

        self.pc_res_2 = pc_resnet_block(self.ef_dim)
        self.pc_conv_2 = pc_conv(self.ef_dim)

        self.pc_res_3 = pc_resnet_block(self.ef_dim)
        self.pc_conv_3 = pc_conv(self.ef_dim)

        self.pc_res_4 = pc_resnet_block(self.ef_dim)
        self.pc_conv_4 = pc_conv(self.ef_dim)

        self.pc_res_5 = pc_resnet_block(self.ef_dim)
        self.pc_conv_5 = pc_conv(self.ef_dim)

        self.pc_res_6 = pc_resnet_block(self.ef_dim)
        self.pc_conv_6 = pc_conv(self.ef_dim)

        self.pc_res_7 = pc_resnet_block(self.ef_dim)

        self.conv_1 = nn.Conv3d(self.ef_dim, self.ef_dim, 3, stride=1, padding=1, bias=True)
        self.conv_2 = nn.Conv3d(self.ef_dim, self.ef_dim, 3, stride=1, padding=1, bias=True)
        self.conv_3 = nn.Conv3d(self.ef_dim, self.ef_dim, 3, stride=1, padding=1, bias=True)

        self.conv_4 = nn.Linear(self.ef_dim, self.ef_dim)
        self.conv_5 = nn.Linear(self.ef_dim, self.ef_dim)

        if self.out_bool:
            self.pc_conv_out_bool = nn.Linear(self.ef_dim, 3)
        if self.out_float:
            self.pc_conv_out_float = nn.Linear(self.ef_dim, 3)

    def forward(self, pc_KNN_idx,pc_KNN_xyz, voxel_xyz_int,voxel_KNN_idx,voxel_KNN_xyz):
        out = pc_KNN_xyz
        
        out = self.pc_conv_0(out)
        out = F.leaky_relu(out, negative_slope=0.01, inplace=True)

        out = self.pc_res_1(out)
        out = self.pc_conv_1(out, pc_KNN_idx, pc_KNN_xyz)
        out = F.leaky_relu(out, negative_slope=0.01, inplace=True)

        out = self.pc_res_2(out)
        out = self.pc_conv_2(out, pc_KNN_idx, pc_KNN_xyz)
        out = F.leaky_relu(out, negative_slope=0.01, inplace=True)

        out = self.pc_res_3(out)
        out = self.pc_conv_3(out, pc_KNN_idx, pc_KNN_xyz)
        out = F.leaky_relu(out, negative_slope=0.01, inplace=True)

        out = self.pc_res_4(out)
        out = self.pc_conv_4(out, pc_KNN_idx, pc_KNN_xyz)
        out = F.leaky_relu(out, negative_slope=0.01, inplace=True)

        out = self.pc_res_5(out)
        out = self.pc_conv_5(out, pc_KNN_idx, pc_KNN_xyz)
        out = F.leaky_relu(out, negative_slope=0.01, inplace=True)

        out = self.pc_res_6(out)
        out = self.pc_conv_6(out, voxel_KNN_idx, voxel_KNN_xyz)
        out = F.leaky_relu(out, negative_slope=0.01, inplace=True)

        out = self.pc_res_7(out)

        voxel_xyz_int_max = torch.max(voxel_xyz_int,0)[0]
        voxel_xyz_int_min = torch.min(voxel_xyz_int,0)[0]
        voxel_xyz_int_size = voxel_xyz_int_max-voxel_xyz_int_min+1
        voxel_xyz_int = voxel_xyz_int-voxel_xyz_int_min.view(1,-1)
        tmp_grid = torch.zeros(voxel_xyz_int_size[0],voxel_xyz_int_size[1],voxel_xyz_int_size[2],self.ef_dim, device=out.device)
        tmp_grid[voxel_xyz_int[:,0],voxel_xyz_int[:,1],voxel_xyz_int[:,2]] = out
        tmp_grid = tmp_grid.permute(3,0,1,2).unsqueeze(0)
        out = tmp_grid


        out = self.conv_1(out)
        out = F.leaky_relu(out, negative_slope=0.01, inplace=True)

        out = self.conv_2(out)
        out = F.leaky_relu(out, negative_slope=0.01, inplace=True)

        out = self.conv_3(out)
        
        out = out.squeeze(0).permute(1,2,3,0)
        out = out[voxel_xyz_int[:,0],voxel_xyz_int[:,1],voxel_xyz_int[:,2]]

        out = F.leaky_relu(out, negative_slope=0.01, inplace=True)

        out = self.conv_4(out)
        out = F.leaky_relu(out, negative_slope=0.01, inplace=True)
        
        out = self.conv_5(out)
        out = F.leaky_relu(out, negative_slope=0.01, inplace=True)


        if self.out_bool and self.out_float:
            out_bool = self.pc_conv_out_bool(out)
            out_float = self.pc_conv_out_float(out)
            return torch.sigmoid(out_bool), out_float
        elif self.out_bool:
            out_bool = self.pc_conv_out_bool(out)
            return torch.sigmoid(out_bool)
        elif self.out_float:
            out_float = self.pc_conv_out_float(out)
            return out_float


class local_pointnet_larger(nn.Module):

    def __init__(self, out_bool, out_float):
        super(local_pointnet_larger, self).__init__()
        self.ef_dim = 128
        self.out_bool = out_bool
        self.out_float = out_float
        
        self.pc_conv_0 = pc_conv_first(self.ef_dim)

        self.pc_res_1 = pc_resnet_block(self.ef_dim)
        self.pc_conv_1 = pc_conv(self.ef_dim)

        self.pc_res_2 = pc_resnet_block(self.ef_dim)
        self.pc_conv_2 = pc_conv(self.ef_dim)

        self.pc_res_3 = pc_resnet_block(self.ef_dim)
        self.pc_conv_3 = pc_conv(self.ef_dim)

        self.pc_res_4 = pc_resnet_block(self.ef_dim)
        self.pc_conv_4 = pc_conv(self.ef_dim)

        self.pc_res_5 = pc_resnet_block(self.ef_dim)
        self.pc_conv_5 = pc_conv(self.ef_dim)

        self.pc_res_6 = pc_resnet_block(self.ef_dim)
        self.pc_conv_6 = pc_conv(self.ef_dim)

        self.pc_res_7 = pc_resnet_block(self.ef_dim)

        self.res_1 = resnet_block_rec3(self.ef_dim)
        self.res_2 = resnet_block_rec3(self.ef_dim)
        self.res_3 = resnet_block_rec3(self.ef_dim)
        self.res_4 = resnet_block_rec3(self.ef_dim)
        self.res_5 = resnet_block_rec3(self.ef_dim)
        self.res_6 = resnet_block_rec3(self.ef_dim)
        self.res_7 = resnet_block_rec3(self.ef_dim)
        self.res_8 = resnet_block_rec3(self.ef_dim)

        self.linear_1 = nn.Linear(self.ef_dim, self.ef_dim)
        self.linear_2 = nn.Linear(self.ef_dim, self.ef_dim)

        if self.out_bool:
            self.linear_bool = nn.Linear(self.ef_dim, 3)
        if self.out_float:
            self.linear_float = nn.Linear(self.ef_dim, 3)

    def forward(self, pc_KNN_idx,pc_KNN_xyz, voxel_xyz_int,voxel_KNN_idx,voxel_KNN_xyz):
        out = pc_KNN_xyz
        
        out = self.pc_conv_0(out)
        out = F.leaky_relu(out, negative_slope=0.01, inplace=True)

        out = self.pc_res_1(out)
        out = self.pc_conv_1(out, pc_KNN_idx, pc_KNN_xyz)
        out = F.leaky_relu(out, negative_slope=0.01, inplace=True)

        out = self.pc_res_2(out)
        out = self.pc_conv_2(out, pc_KNN_idx, pc_KNN_xyz)
        out = F.leaky_relu(out, negative_slope=0.01, inplace=True)

        out = self.pc_res_3(out)
        out = self.pc_conv_3(out, pc_KNN_idx, pc_KNN_xyz)
        out = F.leaky_relu(out, negative_slope=0.01, inplace=True)

        out = self.pc_res_4(out)
        out = self.pc_conv_4(out, pc_KNN_idx, pc_KNN_xyz)
        out = F.leaky_relu(out, negative_slope=0.01, inplace=True)

        out = self.pc_res_5(out)
        out = self.pc_conv_5(out, pc_KNN_idx, pc_KNN_xyz)
        out = F.leaky_relu(out, negative_slope=0.01, inplace=True)

        out = self.pc_res_6(out)
        out = self.pc_conv_6(out, voxel_KNN_idx, voxel_KNN_xyz)
        out = F.leaky_relu(out, negative_slope=0.01, inplace=True)

        out = self.pc_res_7(out)

        voxel_xyz_int_max = torch.max(voxel_xyz_int,0)[0]
        voxel_xyz_int_min = torch.min(voxel_xyz_int,0)[0]
        voxel_xyz_int_size = voxel_xyz_int_max-voxel_xyz_int_min+1
        voxel_xyz_int = voxel_xyz_int-voxel_xyz_int_min.view(1,-1)
        tmp_grid = torch.zeros(voxel_xyz_int_size[0],voxel_xyz_int_size[1],voxel_xyz_int_size[2],self.ef_dim, device=out.device)
        tmp_grid[voxel_xyz_int[:,0],voxel_xyz_int[:,1],voxel_xyz_int[:,2]] = out
        tmp_grid = tmp_grid.permute(3,0,1,2).unsqueeze(0)
        out = tmp_grid

        out = self.res_1(out)
        out = self.res_2(out)
        out = self.res_3(out)
        out = self.res_4(out)
        out = self.res_5(out)
        out = self.res_6(out)
        out = self.res_7(out)
        out = self.res_8(out)
        
        out = out.squeeze(0).permute(1,2,3,0)
        out = out[voxel_xyz_int[:,0],voxel_xyz_int[:,1],voxel_xyz_int[:,2]]

        out = self.linear_1(out)
        out = F.leaky_relu(out, negative_slope=0.01, inplace=True)
        out = self.linear_2(out)
        out = F.leaky_relu(out, negative_slope=0.01, inplace=True)


        if self.out_bool and self.out_float:
            out_bool = self.linear_bool(out)
            out_float = self.linear_float(out)
            return torch.sigmoid(out_bool), out_float
        elif self.out_bool:
            out_bool = self.linear_bool(out)
            return torch.sigmoid(out_bool)
        elif self.out_float:
            out_float = self.linear_float(out)
            return out_float




def postprocessing(pred_output_bool):
    for t in range(2):

        #open edges
        gridedge_x_outedge_y_0 = pred_output_bool[:-1, :,   :-1, 0]
        gridedge_x_outedge_y_1 = pred_output_bool[:-1, :,   1: , 0]
        gridedge_x_outedge_z_0 = pred_output_bool[:-1, :-1, :  , 0]
        gridedge_x_outedge_z_1 = pred_output_bool[:-1, 1:,  :  , 0]
        gridedge_y_outedge_x_0 = pred_output_bool[:,   :-1, :-1, 1]
        gridedge_y_outedge_x_1 = pred_output_bool[:,   :-1, 1: , 1]
        gridedge_y_outedge_z_0 = pred_output_bool[:-1, :-1, :  , 1]
        gridedge_y_outedge_z_1 = pred_output_bool[1:,  :-1, :  , 1]
        gridedge_z_outedge_x_0 = pred_output_bool[:,   :-1, :-1, 2]
        gridedge_z_outedge_x_1 = pred_output_bool[:,   1:,  :-1, 2]
        gridedge_z_outedge_y_0 = pred_output_bool[:-1, :,   :-1, 2]
        gridedge_z_outedge_y_1 = pred_output_bool[1:,  :,   :-1, 2]
        outedge_x = gridedge_y_outedge_x_0+gridedge_y_outedge_x_1+gridedge_z_outedge_x_0+gridedge_z_outedge_x_1
        outedge_y = gridedge_x_outedge_y_0+gridedge_x_outedge_y_1+gridedge_z_outedge_y_0+gridedge_z_outedge_y_1
        outedge_z = gridedge_x_outedge_z_0+gridedge_x_outedge_z_1+gridedge_y_outedge_z_0+gridedge_y_outedge_z_1
        boundary_x_flag = (outedge_x==1).int()
        boundary_y_flag = (outedge_y==1).int()
        boundary_z_flag = (outedge_z==1).int()

        tmp_int = torch.zeros(pred_output_bool.size(), dtype=torch.int32, device=pred_output_bool.device)
        tmp_int[:,    :-1, :-1, 1] += boundary_x_flag
        tmp_int[:,    :-1, 1: , 1] += boundary_x_flag
        tmp_int[:,    :-1, :-1, 2] += boundary_x_flag
        tmp_int[:,    1:,  :-1, 2] += boundary_x_flag
        tmp_int[:-1, :,    :-1, 0] += boundary_y_flag
        tmp_int[:-1, :,    1: , 0] += boundary_y_flag
        tmp_int[:-1, :,    :-1, 2] += boundary_y_flag
        tmp_int[1:,  :,    :-1, 2] += boundary_y_flag
        tmp_int[:-1, :-1, :   , 0] += boundary_z_flag
        tmp_int[:-1, 1:,  :   , 0] += boundary_z_flag
        tmp_int[:-1, :-1, :   , 1] += boundary_z_flag
        tmp_int[1:,  :-1, :   , 1] += boundary_z_flag

        #create a quad if meet 3 open edges
        pred_output_bool = torch.max( pred_output_bool, (tmp_int>=3).int() )

        #delete a quad if meet 3 open edges
        pred_output_bool = torch.min( pred_output_bool, (tmp_int<3).int() )


    for t in range(1): #radical

        #open edges
        gridedge_x_outedge_y_0 = pred_output_bool[:-1, :,   :-1, 0]
        gridedge_x_outedge_y_1 = pred_output_bool[:-1, :,   1: , 0]
        gridedge_x_outedge_z_0 = pred_output_bool[:-1, :-1, :  , 0]
        gridedge_x_outedge_z_1 = pred_output_bool[:-1, 1:,  :  , 0]
        gridedge_y_outedge_x_0 = pred_output_bool[:,   :-1, :-1, 1]
        gridedge_y_outedge_x_1 = pred_output_bool[:,   :-1, 1: , 1]
        gridedge_y_outedge_z_0 = pred_output_bool[:-1, :-1, :  , 1]
        gridedge_y_outedge_z_1 = pred_output_bool[1:,  :-1, :  , 1]
        gridedge_z_outedge_x_0 = pred_output_bool[:,   :-1, :-1, 2]
        gridedge_z_outedge_x_1 = pred_output_bool[:,   1:,  :-1, 2]
        gridedge_z_outedge_y_0 = pred_output_bool[:-1, :,   :-1, 2]
        gridedge_z_outedge_y_1 = pred_output_bool[1:,  :,   :-1, 2]
        outedge_x = gridedge_y_outedge_x_0+gridedge_y_outedge_x_1+gridedge_z_outedge_x_0+gridedge_z_outedge_x_1
        outedge_y = gridedge_x_outedge_y_0+gridedge_x_outedge_y_1+gridedge_z_outedge_y_0+gridedge_z_outedge_y_1
        outedge_z = gridedge_x_outedge_z_0+gridedge_x_outedge_z_1+gridedge_y_outedge_z_0+gridedge_y_outedge_z_1
        boundary_x_flag = (outedge_x==1).int()
        boundary_y_flag = (outedge_y==1).int()
        boundary_z_flag = (outedge_z==1).int()

        tmp_int = torch.zeros(pred_output_bool.size(), dtype=torch.int32, device=pred_output_bool.device)
        tmp_int[:,    :-1, :-1, 1] += boundary_x_flag
        tmp_int[:,    :-1, 1: , 1] += boundary_x_flag
        tmp_int[:,    :-1, :-1, 2] += boundary_x_flag
        tmp_int[:,    1:,  :-1, 2] += boundary_x_flag
        tmp_int[:-1, :,    :-1, 0] += boundary_y_flag
        tmp_int[:-1, :,    1: , 0] += boundary_y_flag
        tmp_int[:-1, :,    :-1, 2] += boundary_y_flag
        tmp_int[1:,  :,    :-1, 2] += boundary_y_flag
        tmp_int[:-1, :-1, :   , 0] += boundary_z_flag
        tmp_int[:-1, 1:,  :   , 0] += boundary_z_flag
        tmp_int[:-1, :-1, :   , 1] += boundary_z_flag
        tmp_int[1:,  :-1, :   , 1] += boundary_z_flag

        #create a quad if meet 2 open edges, only if it helps close a hole, see below code
        pred_output_bool_backup = pred_output_bool
        pred_output_bool = torch.max( pred_output_bool, (tmp_int>=2).int() )

        #open edges
        gridedge_x_outedge_y_0 = pred_output_bool[:-1, :,   :-1, 0]
        gridedge_x_outedge_y_1 = pred_output_bool[:-1, :,   1: , 0]
        gridedge_x_outedge_z_0 = pred_output_bool[:-1, :-1, :  , 0]
        gridedge_x_outedge_z_1 = pred_output_bool[:-1, 1:,  :  , 0]
        gridedge_y_outedge_x_0 = pred_output_bool[:,   :-1, :-1, 1]
        gridedge_y_outedge_x_1 = pred_output_bool[:,   :-1, 1: , 1]
        gridedge_y_outedge_z_0 = pred_output_bool[:-1, :-1, :  , 1]
        gridedge_y_outedge_z_1 = pred_output_bool[1:,  :-1, :  , 1]
        gridedge_z_outedge_x_0 = pred_output_bool[:,   :-1, :-1, 2]
        gridedge_z_outedge_x_1 = pred_output_bool[:,   1:,  :-1, 2]
        gridedge_z_outedge_y_0 = pred_output_bool[:-1, :,   :-1, 2]
        gridedge_z_outedge_y_1 = pred_output_bool[1:,  :,   :-1, 2]
        outedge_x = gridedge_y_outedge_x_0+gridedge_y_outedge_x_1+gridedge_z_outedge_x_0+gridedge_z_outedge_x_1
        outedge_y = gridedge_x_outedge_y_0+gridedge_x_outedge_y_1+gridedge_z_outedge_y_0+gridedge_z_outedge_y_1
        outedge_z = gridedge_x_outedge_z_0+gridedge_x_outedge_z_1+gridedge_y_outedge_z_0+gridedge_y_outedge_z_1
        boundary_x_flag = (outedge_x==1).int()
        boundary_y_flag = (outedge_y==1).int()
        boundary_z_flag = (outedge_z==1).int()

        tmp_int = torch.zeros(pred_output_bool.size(), dtype=torch.int32, device=pred_output_bool.device)
        tmp_int[:,    :-1, :-1, 1] += boundary_x_flag
        tmp_int[:,    :-1, 1: , 1] += boundary_x_flag
        tmp_int[:,    :-1, :-1, 2] += boundary_x_flag
        tmp_int[:,    1:,  :-1, 2] += boundary_x_flag
        tmp_int[:-1, :,    :-1, 0] += boundary_y_flag
        tmp_int[:-1, :,    1: , 0] += boundary_y_flag
        tmp_int[:-1, :,    :-1, 2] += boundary_y_flag
        tmp_int[1:,  :,    :-1, 2] += boundary_y_flag
        tmp_int[:-1, :-1, :   , 0] += boundary_z_flag
        tmp_int[:-1, 1:,  :   , 0] += boundary_z_flag
        tmp_int[:-1, :-1, :   , 1] += boundary_z_flag
        tmp_int[1:,  :-1, :   , 1] += boundary_z_flag

        pred_output_bool = torch.min( pred_output_bool, (tmp_int<2).int() )
        pred_output_bool = torch.max( pred_output_bool, pred_output_bool_backup )

    return pred_output_bool



back to top