Raw File
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from continue_dataset import ShouldContinueDataset
import math
import models
from models import *
import os
import os.path
import numpy as np
import utils

"""
Train the network that predicts whether we should continue adding objects
"""
# Command-line interface of the trainer.
parser = argparse.ArgumentParser(description='Should Continue')

# One row per value-taking option: (flag, type, default, metavar).
# Rows are registered in this order so the generated --help is unchanged.
_VALUE_OPTIONS = (
    ('--data-dir', str, "bedroom", 'S'),        # dataset folder (read from data/<data-dir>/...)
    ('--num-workers', int, 6, 'N'),             # worker processes of the training DataLoader
    ('--last-epoch', int, -1, 'N'),             # negative: fresh run; otherwise checkpoint epoch to resume
    ('--train-size', int, 6400, 'N'),           # training scenes are (0, train-size); validation follows
    ('--save-dir', str, "train/bedroom", 'S'),  # where the log and checkpoints are written
    ('--ablation', str, None, 'S'),             # None, "basic" or "depth"
    ('--lr', float, 0.001, 'N'),                # Adam learning rate
    ('--eps', float, 1e-6, 'N'),                # Adam epsilon
)
for _flag, _type, _default, _metavar in _VALUE_OPTIONS:
    parser.add_argument(_flag, type=_type, default=_default, metavar=_metavar)

# Boolean switch: also feed the `existing` tensor to a separate FullyConnected head.
parser.add_argument('--use-count', action='store_true', default=False)
args = parser.parse_args()

# Output directory (created if needed) and fixed optimisation settings.
save_dir = args.save_dir
utils.ensuredir(save_dir)
learning_rate = args.lr
batch_size = 16

# The category count is derived from the length of the frequency file.
# NOTE(review): the "- 2" presumably skips two non-category lines - confirm
# against the file format.
with open(f"data/{args.data_dir}/final_categories_frequency", "r") as f:
    lines = list(f)
num_categories = len(lines) - 2

# Number of input channels fed to the network for each supported ablation.
_channels_by_ablation = {
    None: num_categories + 8,
    "basic": 6,
    "depth": 1,
}
if args.ablation not in _channels_by_ablation:
    raise NotImplementedError
num_input_channels = _channels_by_ablation[args.ablation]

# Log file for this run; opened in write mode, so an existing log is truncated.
logfile = open(f"{save_dir}/log_continue.txt", 'w')
def LOG(msg):
    """Print ``msg`` to stdout and append it, newline-terminated, to the log file.

    The file is flushed after every message so the log on disk stays current
    while the (endless) training loop is running.
    """
    print(msg)
    entry = msg + '\n'
    logfile.write(entry)
    logfile.flush()

LOG('Building model...')
if args.use_count:
    # Backbone built with use_fc=False; a separate FullyConnected head consumes
    # the backbone output concatenated with the `existing` tensor.
    # NOTE(review): 2048 is presumably the backbone's output width when
    # use_fc=False - confirm in models.resnet101.
    model = models.resnet101(num_classes=2, num_input_channels=num_input_channels, use_fc=False)
    fc = FullyConnected(2048+num_categories, 2)
else:
    # Use the ablation-aware channel count here as well. It equals
    # num_categories+8 when no ablation is selected, so the default
    # configuration is unchanged, while --ablation basic/depth now builds a
    # network whose input width matches the selected ablation.
    model = models.resnet101(num_classes=2, num_input_channels=num_input_channels)

# Two-class objective; `softmax` turns the logits into probabilities.
loss = nn.CrossEntropyLoss()
softmax = nn.Softmax(dim=1)

LOG('Converting to CUDA...')
model.cuda()
if args.use_count:
    fc.cuda()
loss.cuda()
softmax.cuda()

LOG('Building dataset...')

def _make_dataset(scene_indices, **extra):
    """Build a ShouldContinueDataset over ``scene_indices`` with the shared settings."""
    return ShouldContinueDataset(
        data_root_dir = utils.get_data_root_dir(),
        data_dir = args.data_dir,
        scene_indices = scene_indices,
        num_per_epoch = 1,
        complete_prob = 0.5,
        ablation = args.ablation,
        **extra
    )

# Training uses scenes (0, train_size); validation uses the following 160
# scene indices and is given a fixed seed.
train_dataset = _make_dataset((0, args.train_size))
validation_dataset = _make_dataset(
    (args.train_size, args.train_size+160),
    seed = 42
)

LOG('Building data loader...')
# Both loaders shuffle; only the training loader uses worker processes.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, num_workers=args.num_workers, shuffle=True
)
validation_loader = torch.utils.data.DataLoader(
    validation_dataset, batch_size=batch_size, num_workers=0, shuffle=True
)

LOG('Building optimizer...')
# The FullyConnected head `fc` is only created when --use-count is given, so
# its parameters are added conditionally; referencing it unconditionally
# raises NameError in the default configuration.
trainable_params = list(model.parameters())
if args.use_count:
    trainable_params += list(fc.parameters())

optimizer = optim.Adam(
    trainable_params,
    lr = learning_rate,
    betas = (0.9,0.999),
    eps = args.eps
)

# A negative --last-epoch starts a fresh run at epoch 0; any other value
# resumes from the checkpoint written for that epoch.
load = args.last_epoch >= 0
starting_epoch = 0

if load:
    last_epoch = args.last_epoch
    LOG('Loading saved models...')
    backbone_state = torch.load(f"{save_dir}/continue_{last_epoch}.pt")
    model.load_state_dict(backbone_state)
    if args.use_count:
        head_state = torch.load(f"{save_dir}/continue_fc_{last_epoch}.pt")
        fc.load_state_dict(head_state)
    # The optimizer state lives in a single backup file that is not keyed by epoch.
    optimizer_state = torch.load(f"{save_dir}/continue_optim_backup.pt")
    optimizer.load_state_dict(optimizer_state)
    starting_epoch = last_epoch + 1

# Global counters updated by train(): examples seen in the current epoch and
# the epoch number itself.
num_seen = 0
current_epoch = starting_epoch

model.train()
LOG(f'=========================== Epoch {current_epoch} ===========================')

def train():
    """Run one pass over ``train_loader``, validating and checkpointing as it goes.

    Updates the module-level counters ``num_seen`` and ``current_epoch``.
    An "epoch" here is 10000 counted examples, not one pass over the loader:
    every time the counter reaches 10000 the model is validated, the counter
    is reset, the epoch number is incremented, and on every 10th epoch the
    model (plus the ``fc`` head when ``--use-count`` is set) and the optimizer
    state are saved to ``save_dir``.
    """
    global num_seen, current_epoch
    for data, target, existing in train_loader:
        data, target = data.cuda(), target.cuda()

        optimizer.zero_grad()
        if args.use_count:
            # Backbone output concatenated with the `existing` tensor feeds
            # the head. Plain tensors are used directly: the autograd
            # Variable wrapper is deprecated and no longer needed.
            existing = existing.cuda()
            o_conv = model(data)
            o_conv = torch.cat([o_conv, existing], 1)
            output = fc(o_conv)
        else:
            output = model(data)

        loss_val = loss(output, target)
        loss_val.backward()
        optimizer.step()

        # Counted in whole batches so the modulo checks below line up with
        # multiples of batch_size.
        num_seen += batch_size
        if num_seen % 800 == 0:
            LOG(f'Examples {num_seen}/10000')
        if num_seen % 10000 == 0:
            LOG('Validating')
            validate()
            # validate() switches to eval mode; restore training mode. The
            # head only exists with --use-count, so guard the call instead
            # of raising NameError in the default configuration.
            model.train()
            if args.use_count:
                fc.train()
            num_seen = 0
            current_epoch += 1
            LOG(f'=========================== Epoch {current_epoch} ===========================')
            if current_epoch % 10 == 0:
                torch.save(model.state_dict(), f"{save_dir}/continue_{current_epoch}.pt")
                if args.use_count:
                    torch.save(fc.state_dict(), f"{save_dir}/continue_fc_{current_epoch}.pt")
                torch.save(optimizer.state_dict(), f"{save_dir}/continue_optim_backup.pt")

def validate():
    """Evaluate on ``validation_loader`` and log the mean loss and accuracy.

    Puts the model (and the ``fc`` head when ``--use-count`` is set) into
    eval mode and leaves it there; the caller is responsible for switching
    back to training mode. Loss is averaged over the batches actually seen
    and accuracy over the examples actually seen, so the figures stay correct
    when the validation set is not exactly ten full batches.
    """
    model.eval()
    # The head only exists with --use-count; calling it unconditionally
    # raises NameError in the default configuration.
    if args.use_count:
        fc.eval()

    total_loss = 0.0
    num_batches = 0
    num_correct = 0
    num_examples = 0

    with torch.no_grad():
        for data, target, existing in validation_loader:
            data, target = data.cuda(), target.cuda()

            if args.use_count:
                existing = existing.cuda()
                o_conv = model(data)
                o_conv = torch.cat([o_conv, existing], 1)
                output = fc(o_conv)
            else:
                output = model(data)

            total_loss += loss(output, target).item()
            num_batches += 1

            # argmax over logits equals argmax over softmax probabilities,
            # so the class prediction is taken from the logits directly.
            predictions = output.argmax(dim=1)
            num_correct += (predictions == target).sum().item()
            num_examples += target.size(0)

    # max(..., 1) keeps an empty validation set from dividing by zero.
    mean_loss = total_loss / max(num_batches, 1)
    accuracy = num_correct / max(num_examples, 1)
    LOG(f'Loss: {mean_loss}, Accuracy: {accuracy}')

# Train indefinitely: each train() call is one pass over train_loader, and
# there is no stopping condition in this script, so the run ends only when
# the process is interrupted. Checkpoints are written from inside train().
while True:
    train()
back to top