# auto_encoder.py
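"""Mesh-to-voxel autoencoder.

Encodes each mesh (vertices + adjacency) into a latent vector with MeshEncoder,
decodes the latent into a voxel grid with Decoder, trains both with an MSE loss
against ground-truth voxels, and finally caches the latent code of every
training mesh under data/latent/.
"""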
import argparse, os
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from utils import *
from layers import *
from models import *
from voxel import voxel2obj

# Training settings
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=40, help='Random seed.')
parser.add_argument('--epochs', type=int, default=150,
                    help='Number of epochs to train.')
parser.add_argument('--lr', type=float, default=0.0001,
                    help='Initial learning rate.')
parser.add_argument('--exp_id', type=str, default='Test',
                    help='The experiment name')


args = parser.parse_args()
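
# fixed hyperparameters not exposed on the command line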
batch_size = 16
latent_length = 50
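
# seed all RNGs (NumPy, Torch CPU and CUDA) for reproducibility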
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)

# data settings
images = 'data/images/*'
voxels = 'data/voxels/'
meshes = 'data/mesh_info/'
checkpoint_dir = 'checkpoint/' + args.exp_id + '/'
save_dir = 'plots/' + args.exp_id + '/'
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(save_dir, exist_ok=True)


# load data
train_data = Voxel_loader(images, meshes, voxels, set_type='train')
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=8, collate_fn=train_data.collate)

valid_data = Voxel_loader(images, meshes, voxels, set_type='valid')
valid_loader = DataLoader(valid_data, batch_size=batch_size, shuffle=False, num_workers=8, collate_fn=valid_data.collate)
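# note: the custom collate keeps 'verts' and 'adjs' as per-sample lists (meshes differ in vertex count),
# so they are encoded one mesh at a time in the loops below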

# load models
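# MeshEncoder maps (vertices, adjacency) to a latent vector; Decoder maps the latent back to a voxel grid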
encoder_mesh = MeshEncoder(latent_length)
decoder = Decoder(latent_length)

# move models to the GPU and set up the optimizer
decoder.cuda()
encoder_mesh.cuda()
params = list(decoder.parameters()) + list(encoder_mesh.parameters())
optimizer = optim.Adam(params, lr=args.lr)


class Engine:
    """Bundles the train / validate / checkpoint loop and tracks the best validation loss."""

    def __init__(self):
        self.best = 1000
        self.epoch = 0
        self.train_losses = []
        self.valid_losses = []

    def train(self):
        decoder.train()
        encoder_mesh.train()
        total_loss = 0
        iteration = 0 
        for batch in tqdm(train_loader): 
            voxel_gt = batch['voxels'].cuda()
            optimizer.zero_grad()    
            
            latent = None
            # pass each mesh through the encoder and stack the per-mesh latents into a batch
            for mesh, adj in zip(batch['verts'], batch['adjs']):
                mesh = mesh.cuda()
                adj = adj.cuda()
                if latent is None:
                    latent = encoder_mesh(mesh, adj).unsqueeze(0)
                else:
                    latent = torch.cat((latent, encoder_mesh(mesh, adj).unsqueeze(0)))

            # pass through decoder
            voxel_pred = decoder(latent)

            # compute MSE between predicted and ground-truth voxels and optimize
            loss = torch.mean((voxel_pred - voxel_gt) ** 2)
            loss.backward()
            optimizer.step()
            track_loss = loss.item()
            total_loss += track_loss

            # print info occasionally
            if iteration % 20 == 0:
                message = f'Train Loss: Epoch: {self.epoch}, loss: {track_loss}, best: {self.best}'
                tqdm.write(message)
            iteration += 1 

        self.train_losses.append(total_loss / float(iteration))

    @torch.no_grad()
    def validate(self):
        decoder.eval()
        encoder_mesh.eval()
        total_loss = 0
        iteration = 0
        for batch in tqdm(valid_loader):
            voxel_gt = batch['voxels'].cuda()
            
            latent = None
            # pass each mesh through the encoder and stack the per-mesh latents into a batch
            for mesh, adj in zip(batch['verts'], batch['adjs']):
                mesh = mesh.cuda()
                adj = adj.cuda()
                if latent is None:
                    latent = encoder_mesh(mesh, adj).unsqueeze(0)
                else:
                    latent = torch.cat((latent, encoder_mesh(mesh, adj).unsqueeze(0)))

            # pass through decoder
            voxel_pred = decoder(latent)

            # compute MSE between predicted and ground-truth voxels (no optimization step here)
            loss = torch.mean((voxel_pred - voxel_gt) ** 2)
            track_loss = loss.item()
            total_loss += track_loss

            # print info occasionally
            if iteration % 20 == 0:
                message = f'Valid Loss: Epoch: {self.epoch}, new: {total_loss / float(iteration + 1 )}, cur: {self.best}'
                tqdm.write(message)
            iteration += 1 

        self.valid_losses.append(total_loss / float(iteration))

    def save(self): 
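        # checkpoint only when the latest validation loss matches or beats the best so far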
        if self.valid_losses[-1] <= self.best:
            self.best = self.valid_losses[-1] 
            torch.save(decoder.state_dict(), checkpoint_dir + 'decoder')
            torch.save(encoder_mesh.state_dict(), checkpoint_dir + 'encoder')


trainer = Engine()

for epoch in range(args.epochs):
    trainer.epoch = epoch
    trainer.train()
    trainer.validate()
    trainer.save()

print('Saving latent codes of all models')
encoder_mesh.load_state_dict(torch.load(checkpoint_dir + 'encoder'))
encoder_mesh.eval()
if not os.path.exists('data/latent/'):
    os.makedirs('data/latent/')
with torch.no_grad():
    for batch in tqdm(train_loader):
        for v, a, n in zip(batch['verts'], batch['adjs'], batch['names']):
            latent = encoder_mesh(v.cuda(), a.cuda())
            np.save('data/latent/' + n + '_latent', latent.data.cpu().numpy())

        
     