from data import *
from utils.augmentations import SSDAugmentation
from utils.logger import logger
from utils.helper import check_dir, resume_checkpoint, save_checkpoint
from layers.modules import MultiBoxLoss, ImageLevelLoss, HalfMultiBoxLoss
from ssd import build_ssd
# from ssd_consistency import build_ssd_con
from csd import build_ssd_con
import os
import sys
import time
import torch
from torch.autograd import Variable
from torch.utils.data import RandomSampler
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torch.nn.init as init
import torch.utils.data as data
import numpy as np
import argparse
import math
import copy
import shutil
from tensorboardX import SummaryWriter


def str2bool(v):
    return v.lower() in ("yes", "true", "t", "1")


parser = argparse.ArgumentParser(
    description='Single Shot MultiBox Detector Training With Pytorch')
train_set = parser.add_mutually_exclusive_group()
parser.add_argument('--dataset', default='VOC07', choices=['VOC07', 'VOC12', "COCO"],
                    type=str, help='VOC07, VOC12 or COCO')
parser.add_argument('--model', default=300, choices=[300, 512], type=int,
                    help='300 or 512')
parser.add_argument('--dataset_root', default=VOC_ROOT,
                    help='Dataset root directory path')
parser.add_argument('--basenet', default='vgg16_reducedfc.pth',
                    help='Pretrained base model')
parser.add_argument('--batch_size', default=32, type=int,
                    help='Batch size for training')
parser.add_argument('--resume', default=True, type=str2bool,  # None 'weights/ssd300_COCO_80000.pth'
                    help='Checkpoint state_dict file to resume training from')
parser.add_argument('--start_iter', default=0, type=int,
                    help='Resume training at this iter')
parser.add_argument('--image_level_supervision', default=1, type=float,
                    help='weight of the image-level supervision loss (0 disables it)')
parser.add_argument('--num_workers', default=4, type=int,
                    help='Number of workers used in dataloading')
parser.add_argument('--cuda', default=True, type=str2bool,
                    help='Use CUDA to train model')
parser.add_argument('--lr', '--learning-rate', default=1e-5, type=float,
                    help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float,
                    help='Momentum value for optim')
parser.add_argument('--weight_decay', default=5e-4, type=float,
                    help='Weight decay for SGD')
parser.add_argument('--gamma', default=0.1, type=float,
                    help='Gamma update for SGD')
parser.add_argument('--ramp_num', default=2, type=float,
                    help='ramp_num')
parser.add_argument('--half_weight', default=1, type=float,
                    help='half_weight')
parser.add_argument('--save_folder', default='weights/',
                    help='Directory for saving checkpoint models')
parser.add_argument('--supervise_percent', default=0.5, type=float,
                    help='supervise_percent')
parser.add_argument('--supervise_num', default=0, type=int,
                    help='supervise_num; takes precedence over supervise_percent')
parser.add_argument('--semi', default=True, type=str2bool,
                    help='semi')
parser.add_argument('--label', default=False, type=str2bool,
                    help='label')
parser.add_argument('--super_iter', default=1, type=int,
                    help='super_iter')
parser.add_argument('--add_anno', default="", type=str,
                    help='')
parser.add_argument('--batch_super_times', default=1, type=int,
                    help='batch_super_times')
args = parser.parse_args()

if torch.cuda.is_available():
    if args.cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    if not args.cuda:
        logger.info("WARNING: It looks like you have a CUDA device, but aren't " +
                    "using CUDA.\nRun with --cuda for optimal training speed.")
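        # torch.set_default_tensor_type controls the type used for tensors created without an
        # explicit device, so this branch keeps newly created tensors on the CPU when CUDA is
        # available but --cuda is false.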
        torch.set_default_tensor_type('torch.FloatTensor')
else:
    torch.set_default_tensor_type('torch.FloatTensor')

if not os.path.exists(args.save_folder):
    os.mkdir(args.save_folder)

image_level_loss_weight = args.image_level_supervision
image_level_str = str(image_level_loss_weight)
if image_level_loss_weight == 1.0:
    image_level_str = "1"
if image_level_loss_weight == 0.0:
    image_level_str = "0"
semi_loss_weight = float(int(args.semi))
supervised = str(args.supervise_percent)
if args.supervise_num:
    supervised = str(args.supervise_num)
suffix = "_conf" if args.semi else ""
label_name = "_label" if args.label else ""
half_weight = args.half_weight

# checkpoint file
checkpoint_file = os.path.join(
    args.save_folder,
    "FINETUNE/data_{}_imagelevel_{}_lr_{}_ramp_{}_halfw_{}_superpercent_{}_super_time_{}_batch_{}{}{}{}"
    .format(args.dataset, image_level_str, args.lr, args.ramp_num, half_weight, supervised,
            args.batch_super_times, args.batch_size, suffix, label_name, args.add_anno))
check_dir(checkpoint_file)
tb_log = os.path.join(checkpoint_file, "log")
if os.path.exists(tb_log) and args.resume is False:
    logger.info("tb_log exists but resume is set to false. Remove tb_log? y/[n]")
    continue_flag = input()
    if continue_flag == "y":
        shutil.rmtree(tb_log)
    else:
        logger.info("exit")
        exit()
tb_writer = SummaryWriter(tb_log)


def train():
    if args.dataset == 'COCO':
        args.dataset_root = COCO17_ROOT
        if args.model == 512:
            cfg = coco512
        elif args.model == 300:
            cfg = coco300
        else:
            raise ValueError("unsupported model")
        supervised_dataset = COCO17Detection(root=args.dataset_root,
                                             supervise_percent=args.supervise_percent,
                                             supervise_num=args.supervise_num,
                                             only_supervise=True,
                                             extracting_label=args.label,
                                             finetune=True,
                                             add_anno=args.add_anno,
                                             transform=SSDAugmentation(cfg['min_dim'], MEANS))
    if args.supervise_percent == 0:
        supervised_dataset.set_only_supervise(False)

    supervised_batch = args.batch_size
    batch_sampler = VOCBatchSampler(RandomSampler(supervised_dataset),
                                    batch_size=supervised_batch,
                                    drop_last=True,
                                    super_threshold=supervised_dataset.supervise_num,
                                    supervise_percent=supervised_dataset.supervise_percent,
                                    only_supervise=True,
                                    batch_super_times=args.batch_super_times)

    finish_flag = True
    while finish_flag:
        ssd_net = build_ssd_con('train', cfg['min_dim'], cfg['num_classes'])
        net = ssd_net
        if args.cuda:
            net = torch.nn.DataParallel(ssd_net)
            cudnn.benchmark = True

        init_model = os.path.join(args.save_folder, "FINETUNE", "ckpt_400000.pth")
        checkpoint_model = torch.load(init_model)
        model_state_dict = checkpoint_model["model_state_dict"]
        optimizer_state_dict = checkpoint_model["optimizer_state_dict"]
        args.start_iter = 0
        net.load_state_dict(model_state_dict)
        if args.cuda:
            net = net.cuda()

        optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum,
                              weight_decay=args.weight_decay)
        optimizer.load_state_dict(optimizer_state_dict)
        criterion = MultiBoxLoss(cfg['num_classes'], 0.5, True, 0, True, 3, 0.5, False, args.cuda)
        image_level_criterion = ImageLevelLoss(cfg['num_classes'], args.cuda)
        half_criterion = HalfMultiBoxLoss(cfg['num_classes'], 0.5, True, 0, True, 3, 0.5, False, args.cuda)
        conf_consistency_criterion = torch.nn.KLDivLoss(size_average=False, reduce=False).cuda()
        adjust_learning_rate(optimizer, args.gamma, 0)

        net.train()
        # loss counters
        loc_loss = 0
        conf_loss = 0
        epoch = 0
        supervised_flag = 0
        logger.info('Loading the dataset...')
        logger.info('Training SSD on: {}'.format(args.dataset))
        logger.info('Using the specified args:')
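        # Each batch drawn below can mix three kinds of images: fully labeled ones (semis flag 1,
        # scored with MultiBoxLoss), partially labeled ones (flag 0 but non-empty targets, scored
        # with HalfMultiBoxLoss against det_res), and unlabeled ones, which only contribute the
        # image-level and flip-consistency terms.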
        logger.info(args)

        step_index = 0
        total_un_iter_num = 0
        # unsupervised_batch = args.batch_size - supervised_batch
        # data_shuffle = 0

        supervised_dataset.set_only_supervise((not args.semi))
        batch_sampler.set_only_supervise((not args.semi))
        supervised_data_loader = data.DataLoader(supervised_dataset,
                                                 batch_sampler=batch_sampler,
                                                 num_workers=4,
                                                 collate_fn=finetune_detection_collate,
                                                 pin_memory=True)
        batch_iterator = iter(supervised_data_loader)

        scale = args.batch_size / 32
        for iteration in range(args.start_iter, int(200000 / scale)):
            # # modify learning rate
            # if iteration in cfg['lr_steps']:
            #     # step_index += 1
            #     step_index = cfg["lr_steps"].index(int(iteration * scale)) + 1
            #     adjust_learning_rate(optimizer, args.gamma, step_index)

            # get data
            try:
                images, targets, semis, image_level_target, det_res = next(batch_iterator)
            except StopIteration:
                batch_iterator = iter(supervised_data_loader)
                images, targets, semis, image_level_target, det_res = next(batch_iterator)
                print("$$$$$$$$$$$$$$$$$$change data")
            dd = np.array(semis)

            if args.cuda:
                images = Variable(images.cuda())
                targets = [Variable(ann.cuda(), volatile=True) for ann in targets]
                det_res = [Variable(ann.cuda(), volatile=True) for ann in det_res]
                image_level_target = Variable(image_level_target.cuda())
            else:
                images = Variable(images)
                targets = [Variable(ann, volatile=True) for ann in targets]
                det_res = [Variable(ann, volatile=True) for ann in det_res]
                image_level_target = Variable(image_level_target)

            # forward
            t0 = time.time()
            # import IPython; IPython.embed(); exit()
            out, conf, conf_flip, loc, loc_flip = net(images)
            loss_image_level, margin, average_num = image_level_criterion(out[1], image_level_target)

            sup_image_binary_index = np.zeros([len(semis), 1])
            half_image_index = []
            full_target = []
            half_target = []
            half_det = []
            for super_image in range(len(semis)):
                if(int(semis[super_image]) == 1):
                    sup_image_binary_index[super_image] = 1
                else:
                    sup_image_binary_index[super_image] = 0
                    if len(targets[super_image]) > 0:
                        half_target.append(targets[super_image])
                        half_image_index.append(super_image)
                        half_det.append(det_res[super_image])
                # if(int(semis[len(semis)-1-super_image])==0):
                #     del targets[len(semis)-1-super_image]
            sup_image_index = np.where(sup_image_binary_index == 1)[0]
            unsup_image_index = np.where(sup_image_binary_index == 0)[0]
            half_image_index = np.array(half_image_index)
            for bi in sup_image_index:
                full_target.append(targets[bi])
            print("sup, unsup, half", (dd == 1).sum(), (dd == 0).sum(), len(half_target))

            loc_data, conf_data, priors = out

            if (len(sup_image_index) != 0):
                full_loc_data = loc_data[sup_image_index, :, :]
                full_conf_data = conf_data[sup_image_index, :, :]
                output = (full_loc_data, full_conf_data, priors)

            loss_l = Variable(torch.cuda.FloatTensor([0]))
            loss_c = Variable(torch.cuda.FloatTensor([0]))
            if(len(sup_image_index) != 0):
                loss_l, loss_c = criterion(output, full_target)

            if (len(half_det) != 0):
                half_loc_data = loc_data[half_image_index, :, :]
                half_conf_data = conf_data[half_image_index, :, :]
                half_output = (half_loc_data, half_conf_data, priors)

            half_loss_l = Variable(torch.cuda.FloatTensor([0]))
            half_loss_c = Variable(torch.cuda.FloatTensor([0]))
            if(len(half_target) != 0):
                half_loss_l, half_loss_c = half_criterion(half_output, half_target, half_det)

            # sampling = True
            if(args.semi is True):
                logger.info("semi")
                conf_class = conf[:, :, 1:].clone()
                background_score = conf[:, :, 0].clone()
                each_val, each_index = torch.max(conf_class, dim=2)
                mask_val = each_val > background_score
                mask_val = mask_val.data

                mask_conf_index = mask_val.unsqueeze(2).expand_as(conf)
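                # mask_val keeps only the priors whose best foreground score beats the background
                # score; the flip-consistency loss below is computed on this subset of predictions.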
                mask_loc_index = mask_val.unsqueeze(2).expand_as(loc)

                conf_mask_sample = conf.clone()
                loc_mask_sample = loc.clone()
                conf_sampled = conf_mask_sample[mask_conf_index].view(-1, cfg["num_classes"])
                loc_sampled = loc_mask_sample[mask_loc_index].view(-1, 4)

                conf_mask_sample_flip = conf_flip.clone()
                loc_mask_sample_flip = loc_flip.clone()
                conf_sampled_flip = conf_mask_sample_flip[mask_conf_index].view(-1, cfg["num_classes"])
                loc_sampled_flip = loc_mask_sample_flip[mask_loc_index].view(-1, 4)

            # print("mask_val.sum: ", mask_val.sum())
            if(args.semi and mask_val.sum() > 0):
                ## JSD: symmetric KL between predictions for the image and its horizontal flip
                conf_sampled_flip = conf_sampled_flip + 1e-7
                conf_sampled = conf_sampled + 1e-7
                consistency_conf_loss_a = conf_consistency_criterion(conf_sampled.log(),
                                                                     conf_sampled_flip.detach()).sum(-1).mean()
                consistency_conf_loss_b = conf_consistency_criterion(conf_sampled_flip.log(),
                                                                     conf_sampled.detach()).sum(-1).mean()
                consistency_conf_loss = consistency_conf_loss_a + consistency_conf_loss_b

                ## LOC LOSS: the x offset changes sign under a horizontal flip, hence the "+" below
                consistency_loc_loss_x = torch.mean(torch.pow(loc_sampled[:, 0] + loc_sampled_flip[:, 0], exponent=2))
                consistency_loc_loss_y = torch.mean(torch.pow(loc_sampled[:, 1] - loc_sampled_flip[:, 1], exponent=2))
                consistency_loc_loss_w = torch.mean(torch.pow(loc_sampled[:, 2] - loc_sampled_flip[:, 2], exponent=2))
                consistency_loc_loss_h = torch.mean(torch.pow(loc_sampled[:, 3] - loc_sampled_flip[:, 3], exponent=2))
                consistency_loc_loss = torch.div(
                    consistency_loc_loss_x + consistency_loc_loss_y +
                    consistency_loc_loss_w + consistency_loc_loss_h, 4)
            else:
                # logger.info("not semi")
                consistency_conf_loss = Variable(torch.cuda.FloatTensor([0]))
                consistency_loc_loss = Variable(torch.cuda.FloatTensor([0]))

            consistency_loss = torch.div(consistency_conf_loss, 2) + consistency_loc_loss
            ramp_weight = rampweight(iteration)
            consistency_loss = torch.mul(consistency_loss, ramp_weight)

            # print("sup_image_index", len(sup_image_index))
            if(supervised_flag >= 1):
                loss = loss_l + loss_c + consistency_loss + \
                       image_level_loss_weight * loss_image_level + \
                       half_weight * (half_loss_l + half_loss_c)
            else:
                if(len(sup_image_index) == 0):
                    loss = consistency_loss + image_level_loss_weight * loss_image_level + \
                           half_weight * (half_loss_l + half_loss_c)
                    # continue
                else:
                    loss = loss_l + loss_c + consistency_loss + \
                           image_level_loss_weight * loss_image_level + \
                           half_weight * (half_loss_l + half_loss_c)

            if(loss.data > 0):
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            t1 = time.time()

            if(len(sup_image_index) == 0):
                loss_l.data = Variable(torch.cuda.FloatTensor([0]))
                loss_c.data = Variable(torch.cuda.FloatTensor([0]))
                loss_image_level.data = Variable(torch.cuda.FloatTensor([0]))
                logger.info("len(sup_image_index)==0")
            else:
                loc_loss += loss_l.data  # [0]
                conf_loss += loss_c.data  # [0]

            if iteration % 10 == 0:
                logger.info('timer: %.4f sec.' % (t1 - t0))
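                # t1 - t0 covers the forward pass, the loss computation and, when the combined
                # loss is positive, the backward pass and optimizer step.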
                logger.info('iter ' + repr(iteration) + ' || Loss: %.4f || consistency_loss : %.4f ||'
                            % (loss.data, consistency_loss.data))
                logger.info('loss: %.4f , loss_c: %.4f , loss_l: %.4f , loss_con: %.4f, lr : %.4f, super_len : %d\n'
                            % (loss.data, loss_c.data, loss_l.data, consistency_loss.data,
                               float(optimizer.param_groups[0]['lr']), len(sup_image_index)))

            if iteration % 100 == 0:
                tb_writer.add_scalar("Train/loss", loss.data, iteration)
                tb_writer.add_scalar("Train/loss_c", loss_c.data, iteration)
                tb_writer.add_scalar("Train/loss_l", loss_l.data, iteration)
                tb_writer.add_scalar("Train/half_loss_c", half_loss_c.data, iteration)
                tb_writer.add_scalar("Train/half_loss_l", half_loss_l.data, iteration)
                tb_writer.add_scalar("Train/loss_image_level", loss_image_level.data, iteration)
                tb_writer.add_scalar("Train/loss_con", consistency_loss.data, iteration)
                tb_writer.add_scalar("Train/lr", float(optimizer.param_groups[0]['lr']), iteration)
                tb_writer.add_scalar("Train/ramp_weight", float(ramp_weight), iteration)
                tb_writer.add_scalar("Train/sup_len", float(len(sup_image_index)), iteration)
                tb_writer.add_scalar("Train/margin", float(margin), iteration)
                tb_writer.add_scalar("Train/average_num", float(average_num), iteration)

            if(float(loss) > 10000):
                print("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")
                # import IPython; IPython.embed(); exit()
                break

            if iteration != 0 and (iteration + 1) % 1000 == 0:
                # logger.info('Saving state, iter:', iteration)
                # torch.save(ssd_net.state_dict(), 'weights/ssd300_COCO_' + repr(iteration+1) + '.pth')
                save_checkpoint((iteration + 1), net, optimizer, checkpoint_file)
                # torch.save(ssd_net.state_dict(), args.save_folder + '' + args.dataset + '.pth')
                logger.info('-------------------------------\n')
                logger.info(loss.data)
                logger.info('-------------------------------')

            if((iteration + 1) == 200000):
                finish_flag = False


def rampweight(iteration):
    scale = args.batch_size / 32
    if args.dataset == "COCO":
        scale = scale / 4 * 1.2
        scale = scale * 2
    ramp_up_end = 32000 / scale
    ramp_down_start = 100000 / scale
    ramp_end = 120000 / scale
    step = 20000 / scale

    if(iteration < ramp_up_end):
        # exponential ramp-up of the consistency weight
        ramp_weight = math.exp(-5 * math.pow((1 - iteration / ramp_up_end), 2))
    elif(iteration > ramp_down_start):
        # exponential ramp-down towards the end of training
        ramp_weight = math.exp(-12.5 * math.pow((1 - (ramp_end - iteration) / step), 2))
    else:
        ramp_weight = 1

    if(iteration == 0):
        ramp_weight = 0
    return ramp_weight / args.ramp_num


def adjust_learning_rate(optimizer, gamma, step):
    """Sets the learning rate to the initial LR decayed by 10 at every
       specified step
    # Adapted from PyTorch Imagenet example:
    # https://github.com/pytorch/examples/blob/master/imagenet/main.py
    """
    lr = args.lr * (gamma ** step)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def xavier(param):
    init.xavier_uniform(param)


def weights_init(m):
    if isinstance(m, nn.Conv2d):
        xavier(m.weight.data)
        m.bias.data.zero_()


if __name__ == '__main__':
    train()
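# Rough shape of the consistency ramp for the default flags (--batch_size 32, --ramp_num 2) with
# --dataset COCO, the only dataset handled by train(): scale = (32 / 32) / 4 * 1.2 * 2 = 0.6, so
# ramp_up_end ~ 53333, ramp_down_start ~ 166667, ramp_end = 200000 and step ~ 33333. The returned
# weight is 0 at iteration 0, rises to 0.5 by ~53k iterations, stays at 0.5 until ~167k, then
# decays back towards 0 as training approaches the 200k-iteration cap.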