https://github.com/brownvc/deep-synth
Revision b800e11290b763b58e7d3b30329769a7b77cd12a authored by kwang-ether on 14 June 2019, 23:53:57 UTC, committed by kwang-ether on 14 June 2019, 23:53:57 UTC
1 parent 79eaa7f
Tip revision: b800e11290b763b58e7d3b30329769a7b77cd12a authored by kwang-ether on 14 June 2019, 23:53:57 UTC
remove csv
remove csv
Tip revision: b800e11
location_train.py
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from models import *
from location_dataset import LocationDataset
import numpy as np
import math
import utils
# ---- Command-line configuration -------------------------------------------
parser = argparse.ArgumentParser(description='Location Training with Auxillary Tasks')
# Dataset subdirectory under data/ (e.g. "bedroom").
parser.add_argument('--data-dir', type=str, default="bedroom", metavar='S')
parser.add_argument('--num-workers', type=int, default=6, metavar='N')
parser.add_argument('--last-epoch', type=int, default=-1, metavar='N') #If positive, use saved epoch
# Number of scenes used for training; validation starts right after this index.
parser.add_argument('--train-size', type=int, default=6400, metavar='N')
parser.add_argument('--save-dir', type=str, default="train/bedroom", metavar='S')
parser.add_argument('--ablation', type=str, default=None, metavar='S')
parser.add_argument('--lr', type=float, default=0.001, metavar='N')
parser.add_argument('--eps', type=float, default=1e-6, metavar='N')
# Probability that a training sample is an auxiliary (non-positive) example.
parser.add_argument('--p-auxiliary', type=float, default=0.0, metavar='N')
parser.add_argument('--use-count', action='store_true', default=False) #Use category count if true
parser.add_argument('--no-penalty', action='store_true', default=False) #If True, L_Global is not used
parser.add_argument('--progressive-p', action='store_true', default=False) #If True, start with lower p_auxiliary and gradually increase
args = parser.parse_args()

save_dir = args.save_dir
utils.ensuredir(save_dir)
batch_size = 16

# One line per category in the frequency file; the -2 presumably drops
# non-category lines -- TODO confirm against final_categories_frequency format.
with open(f"data/{args.data_dir}/final_categories_frequency", "r") as f:
    lines = f.readlines()
num_categories = len(lines)-2

# Input-channel count depends on the ablation mode: the full model gets one
# channel per category plus 9 extra channels; "basic"/"depth" use fixed sizes.
if args.ablation is None:
    num_input_channels = num_categories+9
elif args.ablation == "basic":
    num_input_channels = 7
elif args.ablation == "depth":
    num_input_channels = 2
else:
    raise NotImplementedError

# NOTE(review): opened with mode 'w', so resuming via --last-epoch truncates
# the previous run's log -- consider 'a' if history should be kept.
logfile = open(f"{save_dir}/log_location.txt", 'w')
def LOG(msg):
    """Echo *msg* to stdout and append it, newline-terminated, to the shared
    log file, flushing immediately so the log survives an abrupt kill."""
    print(msg)
    logfile.write(f"{msg}\n")
    logfile.flush()
LOG('Building model...')
# Backbone CNN; use_fc=False so it emits 2048-d conv features instead of logits.
# Head output size is num_categories+3 -- the 3 extra outputs are presumably the
# auxiliary-task classes; TODO confirm against LocationDataset's target encoding.
model = resnet101(num_classes=num_categories+3, num_input_channels=num_input_channels, use_fc=False)
if args.use_count:
    # Classifier head also sees the per-category count vector of existing objects.
    fc = FullyConnected(2048 + num_categories, num_categories+3)
else:
    fc = FullyConnected(2048, num_categories+3)
cross_entropy = nn.CrossEntropyLoss()
softmax = nn.Softmax(dim=1)

LOG('Converting to CUDA...')
model.cuda()
fc.cuda()
cross_entropy.cuda()
softmax.cuda()

LOG('Building dataset...')
train_dataset = LocationDataset(
    data_root_dir = utils.get_data_root_dir(),
    data_dir = args.data_dir,
    scene_indices = (0, args.train_size),
    p_auxiliary = args.p_auxiliary,
    ablation = args.ablation
)
#Size of validation set is 160 by default
validation_dataset = LocationDataset(
    data_root_dir = utils.get_data_root_dir(),
    data_dir = args.data_dir,
    scene_indices = (args.train_size, args.train_size+160),
    seed = 42,  # fixed seed so validation sampling is reproducible across runs
    p_auxiliary = 0, #Only tests positive examples in validation
    ablation = args.ablation
)

LOG('Building data loader...')
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size = batch_size,
    num_workers = args.num_workers,
    shuffle = True
)
validation_loader = torch.utils.data.DataLoader(
    validation_dataset,
    batch_size = batch_size,
    num_workers = 0,
    shuffle = True
)

LOG('Building optimizer...')
# Single Adam optimizer over both the backbone and the classifier head.
optimizer = optim.Adam(list(model.parameters())+list(fc.parameters()),
    lr = args.lr,
    betas = (0.9,0.999),
    eps = args.eps
)

# Resume from a saved checkpoint when --last-epoch >= 0, otherwise start fresh.
if args.last_epoch < 0:
    load = False
    starting_epoch = 0
else:
    load = True
    last_epoch = args.last_epoch

if load:
    LOG('Loading saved models...')
    model.load_state_dict(torch.load(f"{save_dir}/location_{last_epoch}.pt"))
    fc.load_state_dict(torch.load(f"{save_dir}/location_fc_{last_epoch}.pt"))
    # NOTE(review): only the latest optimizer state is kept (single backup file),
    # so resuming from an older epoch pairs old weights with the newest optimizer state.
    optimizer.load_state_dict(torch.load(f"{save_dir}/location_optim_backup.pt"))
    starting_epoch = last_epoch + 1

current_epoch = starting_epoch
num_seen = 0  # examples processed in the current epoch; reset every 10000 in train()

# Progressive curriculum: when --progressive-p is set, the share of auxiliary
# examples ramps up with the (possibly resumed) epoch number. The un-elif'd
# cascade works because later conditions overwrite earlier assignments.
if args.progressive_p:
    if current_epoch <= 30:
        train_dataset.p_auxiliary = 0.0
    if current_epoch > 30:
        train_dataset.p_auxiliary = 0.5
    if current_epoch > 60:
        train_dataset.p_auxiliary = 0.7
    if current_epoch > 90:
        train_dataset.p_auxiliary = 0.9
    if current_epoch > 120:
        train_dataset.p_auxiliary = 0.95

model.train()
LOG(f'=========================== Epoch {current_epoch} ===========================')
def train():
global num_seen, current_epoch
for batch_idx, (data, t, existing, penalty) \
in enumerate(train_loader):
data, t = data.cuda(), t.cuda()
existing, penalty = existing.cuda(), penalty.cuda()
optimizer.zero_grad()
o_conv = model(data)
if args.use_count:
o_conv = torch.cat([o_conv, existing], 1)
o = fc(o_conv)
loss = cross_entropy(o,t)
if not args.no_penalty:
o_s = softmax(o)[:,0:num_categories]
l_penalty = (o_s * penalty).sum()
if not args.no_penalty:
loss += l_penalty
loss.backward()
optimizer.step()
num_seen += batch_size
if num_seen % 800 == 0:
LOG(f'Examples {num_seen}/10000')
if num_seen % 10000 == 0:
LOG('Validating')
validate()
model.train()
fc.train()
num_seen = 0
current_epoch += 1
LOG(f'=========================== Epoch {current_epoch} ===========================')
LOG(f'{train_dataset.p_auxiliary}')
if current_epoch % 5 == 0:
torch.save(model.state_dict(), f"{save_dir}/location_{current_epoch}.pt")
torch.save(fc.state_dict(), f"{save_dir}/location_fc_{current_epoch}.pt")
torch.save(optimizer.state_dict(), f"{save_dir}/location_optim_backup.pt")
if args.progressive_p:
if current_epoch <= 30:
train_dataset.p_auxiliary = 0.0
if current_epoch > 30:
train_dataset.p_auxiliary = 0.5
if current_epoch > 60:
train_dataset.p_auxiliary = 0.7
if current_epoch > 90:
train_dataset.p_auxiliary = 0.9
if current_epoch > 120:
train_dataset.p_auxiliary = 0.95
def validate():
    """Evaluate on the validation set; log mean cross-entropy loss and
    mean per-batch accuracy.

    Switches model/fc to eval mode (caller restores train mode). Fixes from
    review: averages over the actual batch count instead of a hard-coded 10
    (which was only correct for 160 scenes at batch size 16), drops the
    pointless optimizer.zero_grad() during no-grad evaluation, and replaces
    the deprecated `.data` tensor access.
    """
    model.eval()
    fc.eval()
    total_loss = 0.0
    total_accuracy = 0.0
    num_batches = 0
    with torch.no_grad():
        for data, t, existing, penalty in validation_loader:
            data, t = data.cuda(), t.cuda()
            existing = existing.cuda()
            o_conv = model(data)
            if args.use_count:
                o_conv = torch.cat([o_conv, existing], 1)
            o = fc(o_conv)
            total_loss += cross_entropy(o, t).item()
            # NOTE: denominator is batch_size, so a final partial batch would
            # be under-weighted -- preserved from the original accounting.
            predictions = softmax(o).argmax(dim=1)
            total_accuracy += (predictions == t).sum().item() / batch_size
            num_batches += 1
    if num_batches:
        LOG(f'Loss: {total_loss/num_batches}, Accuracy: {total_accuracy/num_batches}')
    else:
        LOG('Loss: nan, Accuracy: nan')
# Train indefinitely: train() handles per-epoch validation, logging, and
# checkpointing internally, so the process runs until manually stopped.
while True:
    train()

Computing file changes ...