# https://github.com/EdwardSmith1884/GEOMetrics
# Raw File
# Tip revision: a39d4a45dfd33c257ff0f68069a5a3072bda7071 authored by Edward Smith on 20 October 2020, 21:41:38 UTC
# for another project
# Tip revision: a39d4a4
# models.py
import torch.nn as nn
import torch 
import torch.nn.functional as F
from layers import *
import torchvision.models as models






class VGG(nn.Module):
	"""VGG-style convolutional feature extractor with four output taps.

	The network is a chain of eighteen identical stages, each being
	Conv2d(kernel 3, padding 1) -> BatchNorm2d -> ReLU. Stages 3, 6, 9, 12
	and 15 use stride 2 and are the only ones that change the channel count
	(16 -> 32 -> 64 -> 128 -> 256 -> 512).

	Args:
		in_channels: number of channels of the input image. Defaults to 4,
			the value that used to be hard-coded, so existing callers and
			checkpoints are unaffected.
	"""

	# Stages whose outputs are returned by forward (1-based stage indices).
	_TAP_STAGES = (8, 11, 14, 18)

	def __init__(self, in_channels=4):
		super(VGG, self).__init__()

		# (input channels, output channels, stride) for layer1 .. layer18.
		# Stages are created in this exact order so parameter registration
		# (and therefore state_dict keys and seeded initialisation) is stable.
		stage_specs = (
			(in_channels, 16, 1), (16, 16, 1),
			(16, 32, 2), (32, 32, 1), (32, 32, 1),
			(32, 64, 2), (64, 64, 1), (64, 64, 1),
			(64, 128, 2), (128, 128, 1), (128, 128, 1),
			(128, 256, 2), (256, 256, 1), (256, 256, 1),
			(256, 512, 2), (512, 512, 1), (512, 512, 1), (512, 512, 1),
		)
		for index, (c_in, c_out, stride) in enumerate(stage_specs, 1):
			stage = nn.Sequential(
				nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=stride),
				nn.BatchNorm2d(c_out),
				nn.ReLU(inplace=True))
			setattr(self, 'layer%d' % index, stage)

	def forward(self, tensor):
		"""Run the stack and return the four tapped feature maps.

		Args:
			tensor: image batch accepted by Conv2d with ``in_channels`` channels.

		Returns:
			Tuple ``(A, B, C, D)``: the outputs of layer8, layer11, layer14 and
			layer18, with 64, 128, 256 and 512 channels respectively. Each tap
			sits after one more stride-2 stage than the previous one.
		"""
		taps = []
		x = tensor
		for index in range(1, 19):
			x = getattr(self, 'layer%d' % index)(x)
			if index in self._TAP_STAGES:
				# The following stage starts with a Conv2d, so the tapped
				# tensor is not overwritten by a later in-place ReLU.
				taps.append(x)
		A, B, C, D = taps
		return A, B, C, D







class MeshDeformationBlock(nn.Module):
	"""Residual stack of graph-convolution layers producing vertex features and coordinates.

	Thirteen ``Image_ZERON_GCNGCN`` layers (gc1 .. gc13) are arranged into
	seven averaged residual stages; a final layer (gc15) maps the resulting
	features to ``output_features`` values. There is no gc14; the attribute
	names are kept as-is so saved parameters keep their keys.

	Args:
		input_features: size of the concatenated (features + pooled) input.
		hidden: width of every intermediate layer.
		output_features: width of the final coordinate layer.
	"""

	def __init__(self, input_features, hidden = 192, output_features = 3):
		super(MeshDeformationBlock, self).__init__()
		# Layers are created in the order gc1, gc2 .. gc13, gc15.
		self.gc1 = Image_ZERON_GCNGCN(input_features, hidden)
		for index in range(2, 14):
			setattr(self, 'gc%d' % index, Image_ZERON_GCNGCN(hidden, hidden))
		self.gc15 = Image_ZERON_GCNGCN(hidden, output_features)
		self.hidden = hidden

	def forward(self, features, pooled , adj):
		"""Apply the residual stages and predict coordinates.

		Args:
			features: per-vertex features, concatenated with ``pooled`` on dim 1.
			pooled: additional per-vertex features (presumably pooled image
				features -- confirm against the caller).
			adj: adjacency information handed unchanged to every layer.

		Returns:
			Tuple ``(features, coords)``: the output of the last residual
			stage and the output of gc15 applied to it.
		"""
		full_features = torch.cat((features, pooled), dim = 1)

		# Stage 1: the skip path is the first `hidden` columns of the input.
		# NOTE(review): this presumably requires input_features >= hidden so the
		# slice matches the layer output width -- confirm.
		update = self.gc1(full_features, adj, F.relu)
		update = self.gc2(update, adj, F.relu)
		features = (full_features[:, :self.hidden] + update).div_(2)

		# Stages 2-6: two layers each, averaged with the stage input.
		layer_pairs = (
			(self.gc3, self.gc4),
			(self.gc5, self.gc6),
			(self.gc7, self.gc8),
			(self.gc9, self.gc10),
			(self.gc11, self.gc12),
		)
		for first, second in layer_pairs:
			update = second(first(features, adj, F.relu), adj, F.relu)
			features = (features + update).div_(2)

		# Stage 7: a single layer on the residual branch.
		update = self.gc13(features, adj, F.relu)
		features = (features + update).div_(2)

		# Coordinates use an identity activation.
		coords = self.gc15(features, adj, lambda x: x)
		return features, coords

class BatchMeshDeformationBlock(nn.Module):
	"""Batched residual graph-convolution stack with batch normalisation.

	Same stage layout as the non-batched block: gc1 .. gc13 form seven
	averaged residual stages and gc15 produces the coordinates. Every layer
	gc1 .. gc13 is followed by its own ``BatchNorm1d(verts)`` and a ReLU.
	``bn14`` is constructed but not used by forward; it is kept so the set
	of registered parameters and buffers is unchanged.

	Args:
		input_features: size of the concatenated (features + pooled) input.
		verts: number of features given to each ``BatchNorm1d``.
		hidden: width of every intermediate layer.
		output_features: width of the final coordinate layer.
	"""

	def __init__(self, input_features,verts, hidden = 192, output_features = 3):
		super(BatchMeshDeformationBlock, self).__init__()
		# Graph layers first (gc1, gc2 .. gc13, gc15), then the norms (bn1 .. bn14).
		self.gc1 = Batch_Image_ZERON_GCNGCN(input_features, hidden)
		for index in range(2, 14):
			setattr(self, 'gc%d' % index, Batch_Image_ZERON_GCNGCN(hidden, hidden))
		self.gc15 = Batch_Image_ZERON_GCNGCN(hidden, output_features)
		self.hidden = hidden

		for index in range(1, 15):
			setattr(self, 'bn%d' % index, nn.BatchNorm1d(verts))

	def _gc_bn_relu(self, index, x, adj):
		"""Apply layer ``gc<index>`` (identity activation), then ``bn<index>``, then ReLU."""
		graph_layer = getattr(self, 'gc%d' % index)
		norm_layer = getattr(self, 'bn%d' % index)
		out = graph_layer(x, adj, lambda x: x)
		return F.relu(norm_layer(out))

	def forward(self, features, pooled , adj):
		"""Apply the residual stages and predict coordinates.

		Args:
			features: per-vertex features, concatenated with ``pooled`` on the
				last dimension; the stage-1 skip path indexes three dimensions.
			pooled: additional per-vertex features (presumably pooled image
				features -- confirm against the caller).
			adj: adjacency information handed unchanged to every layer.

		Returns:
			Tuple ``(features, coords)``: the output of the last residual
			stage and the output of gc15 applied to it.
		"""
		full_features = torch.cat((features, pooled), dim = -1)

		# Stage 1: the skip path is the first `hidden` channels of the input.
		update = self._gc_bn_relu(1, full_features, adj)
		update = self._gc_bn_relu(2, update, adj)
		features = (full_features[:, :, :self.hidden] + update).div_(2)

		# Stages 2-6: layers (3,4), (5,6), (7,8), (9,10), (11,12).
		for index in (3, 5, 7, 9, 11):
			update = self._gc_bn_relu(index, features, adj)
			update = self._gc_bn_relu(index + 1, update, adj)
			features = (features + update).div_(2)

		# Stage 7: a single layer on the residual branch.
		update = self._gc_bn_relu(13, features, adj)
		features = (features + update).div_(2)

		# Coordinates: identity activation, no normalisation.
		coords = self.gc15(features, adj, lambda x: x)
		return features, coords

class MeshEncoder(nn.Module):
	"""Encode a mesh into a latent vector with a chain of graph-convolution layers.

	Sixteen ``ZERON_GCN`` layers widen the per-vertex features from 3 to 300
	channels; ``GCNMax`` then reduces them to ``latent_length`` values.

	Args:
		latent_length: output width of the final reduction layer.
	"""

	# (attribute name, input width, output width), in construction and call order.
	_GCN_PLAN = (
		('h1', 3, 60),
		('h21', 60, 60),
		('h22', 60, 60),
		('h23', 60, 60),
		('h24', 60, 120),
		('h3', 120, 120),
		('h4', 120, 120),
		('h41', 120, 150),
		('h5', 150, 200),
		('h6', 200, 210),
		('h7', 210, 250),
		('h8', 250, 300),
		('h81', 300, 300),
		('h9', 300, 300),
		('h10', 300, 300),
		('h11', 300, 300),
	)

	def __init__(self, latent_length):
		super(MeshEncoder, self).__init__()
		for name, width_in, width_out in self._GCN_PLAN:
			setattr(self, name, ZERON_GCN(width_in, width_out))
		self.reduce = GCNMax(300, latent_length)

	def resnet( self, features, res):
		"""Add ``res`` onto the leading ``res.shape[1]`` columns (dim 1) of ``features``.

		Not called by ``forward`` in this class.

		Returns:
			The merged tensor twice, as a 2-tuple of the same object.
		"""
		width = res.shape[1]
		merged = torch.cat((features[:, :width] + res, features[:, width:]), dim = 1)
		return merged, merged

	def forward(self, positions,  adj, play = False):
		"""Return the latent code for the given vertex positions.

		Args:
			positions: input to the first layer (built with input width 3).
			adj: adjacency information handed unchanged to every layer.
			play: accepted for interface compatibility; not read here.
		"""
		features = positions
		for name, _, _ in self._GCN_PLAN:
			features = getattr(self, name)(features, adj, F.elu)
		return self.reduce(features, adj, F.elu)


class Decoder(nn.Module):
	"""Decode a latent vector into a 32x32x32 grid of values in [0, 1].

	A linear layer expands the latent code to 512 values, viewed as a
	64-channel 2x2x2 volume. Four stride-2 transposed convolutions (each
	followed by BatchNorm3d and ELU) bring it to 32x32x32, and a final 3x3x3
	convolution reduces it to one channel before the sigmoid.

	Args:
		latent_length: width of the latent vector.
	"""

	def __init__(self, latent_length):
		super(Decoder, self).__init__()
		self.fully = torch.nn.Sequential(torch.nn.Linear(latent_length, 512))

		# (input channels, output channels) of the four upsampling stages.
		stages = []
		for c_in, c_out in ((64, 64), (64, 64), (64, 32), (32, 8)):
			stages.append(torch.nn.ConvTranspose3d(c_in, c_out, 4, stride=2, padding=(1, 1, 1)))
			stages.append(nn.BatchNorm3d(c_out))
			stages.append(nn.ELU(inplace=True))
		stages.append(nn.Conv3d(8, 1, (3, 3, 3), stride=1, padding=(1, 1, 1)))
		self.model = torch.nn.Sequential(*stages)

	def forward(self, latent):
		"""Return a tensor of shape (-1, 32, 32, 32) with sigmoid-activated values."""
		seed_volume = self.fully(latent).view(-1, 64, 2, 2, 2)
		# The single output channel is folded away by the reshape.
		logits = self.model(seed_volume).reshape(-1, 32, 32, 32)
		return torch.sigmoid(logits)


class BatchMeshEncoder(nn.Module):
	"""Batched mesh encoder: a chain of graph-convolution layers followed by a reduction.

	Sixteen ``BatchZERON_GCN`` layers widen the per-vertex features from 3 to
	300 channels; ``BatchGCNMax`` then reduces them to ``latent_length`` values.

	Args:
		latent_length: output width of the final reduction layer.
	"""

	# (attribute name, input width, output width), in construction and call order.
	_GCN_PLAN = (
		('h1', 3, 60),
		('h21', 60, 60),
		('h22', 60, 60),
		('h23', 60, 60),
		('h24', 60, 120),
		('h3', 120, 120),
		('h4', 120, 120),
		('h41', 120, 150),
		('h5', 150, 200),
		('h6', 200, 210),
		('h7', 210, 250),
		('h8', 250, 300),
		('h81', 300, 300),
		('h9', 300, 300),
		('h10', 300, 300),
		('h11', 300, 300),
	)

	def __init__(self, latent_length):
		super(BatchMeshEncoder, self).__init__()
		for name, width_in, width_out in self._GCN_PLAN:
			setattr(self, name, BatchZERON_GCN(width_in, width_out))
		self.reduce = BatchGCNMax(300, latent_length)

	def resnet( self, features, res):
		"""Add ``res`` onto the leading ``res.shape[1]`` entries along dim 1 of ``features``.

		Not called by ``forward`` in this class.

		NOTE(review): the slicing is on dim 1, exactly as in the non-batched
		encoder. If batched inputs carry the batch on dim 0, dim 1 may not be
		the channel axis here -- confirm before using this helper.

		Returns:
			The merged tensor twice, as a 2-tuple of the same object.
		"""
		width = res.shape[1]
		merged = torch.cat((features[:, :width] + res, features[:, width:]), dim = 1)
		return merged, merged

	def forward(self, positions,  adj, play = False):
		"""Return the latent code for the given vertex positions.

		Args:
			positions: input to the first layer (built with input width 3).
			adj: adjacency information handed unchanged to every layer.
			play: accepted for interface compatibility; not read here.
		"""
		features = positions
		for name, _, _ in self._GCN_PLAN:
			features = getattr(self, name)(features, adj, F.elu)
		return self.reduce(features, adj, F.elu)
# back to top