Revision f37c96d2758baff99b20ae9081eedd8a8155a597 authored by xchhuang on 25 July 2024, 12:30:08 UTC, committed by xchhuang on 25 July 2024, 12:30:08 UTC
1 parent d7b3a6d
iadb_bn.py
import torch
import torchvision
from torchvision import transforms
from diffusers import UNet2DModel
from torch.optim import Adam, AdamW
import argparse
import os
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image
import random
# import cv2
# import imageio
import platform
from scipy.stats import qmc
from scipy.stats import norm
import sys
import time
import yaml
import glob
sys.path.append('../')
sys.path.append('../../repo')
from bluenoise.get_noise_recent import get_noise_v2
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='celeba_small', help='dataset name')
parser.add_argument('--noise_type', type=str, default='gaussian', help='type of noise')
parser.add_argument('--optimizer_type', type=str, default='adamw', help='optimizer option')
parser.add_argument('--epochs', type=int, default=20, help='number of epochs')
parser.add_argument('--batch_size', type=int, default=64, help='batch size')
parser.add_argument('--res', type=int, default=64, help='resolution')
parser.add_argument('--train_or_test', type=str, default='train', help='train_or_test')
parser.add_argument('--checkpoint', type=str, default=None, help='checkpoint name')
parser.add_argument('--seed', type=int, default=0, help='seed')
parser.add_argument('--nb_steps', type=int, default=1000, help='nb_steps') # 128
parser.add_argument('--scheduler_alpha', type=str, default='linear', help='scheduler type')
parser.add_argument('--scheduler_gamma', type=str, default='linear', help='scheduler type')
parser.add_argument('--scheduler_param', type=float, default=0.02, help='scheduler parameter for scheduler_gamma')
parser.add_argument('--scheduler_param_s', type=float, default=0, help='scheduler parameter for scheduler_gamma')
parser.add_argument('--scheduler_param_e', type=float, default=3, help='scheduler parameter for scheduler_gamma')
parser.add_argument('--blue_noise_blur', type=float, default=None, help='blue noise blur')
parser.add_argument('--activation', type=str, default='silu', help='[silu, gelu, mish]')
parser.add_argument('--early_stopping_step', type=int, default=50, help='[200, 150, 100, 50]')
parser.add_argument('--split_step', type=int, default=900, help='experimentally chosen to be 600')
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
parser.add_argument('--mode_index', type=int, default=1, help='modes')
parser.add_argument('--reg_weight', type=float, default=1, help='weight of regularizer')
parser.add_argument('--alpha_min', type=float, default=0.0, help='min of alpha')
parser.add_argument('--grad_clip', type=float, default=None, help='grad norm clip')
parser.add_argument('--deterministic', type=int, default=1, help='deterministic or stochastic')
parser.add_argument("--resume_training", action="store_true", help="Whether to resume training")
parser.add_argument("--optimize_scheduler_param", action="store_true", help="Whether to optimize the scheduler_param")
parser.add_argument("--remap", action="store_true", help="remapping stratification across images")
parser.add_argument("--is_conditional", action="store_true", help="whether it is conditional image generation")
parser.add_argument('--conditional_type', type=str, default='superres', help='superres, coloring')
parser.add_argument("--fine_tune_mode_index", type=int, default=0, help="how to fine tune the model")
parser.add_argument("--skip", type=int, default=1, help="numbe of skipped steps")
parser.add_argument("--test_samples", type=int, default=10, help="numbe of generated samples")
parser.add_argument("--out_channel", type=int, default=6, help="out_channel is 3 or 6")
opt = parser.parse_args()
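# Example invocation (a hedged sketch; assumes the dataset folder under ./data/
# follows the torchvision ImageFolder layout used below):
#   python iadb_bn.py --dataset celeba_small --noise_type gaussianBN \
#       --scheduler_gamma sigmoid --train_or_test train --epochs 20 --batch_size 64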
dimension = 3
seed = opt.seed
torch.cuda.manual_seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
alpha_min = opt.alpha_min
if platform.system() == 'Windows':
# opt.batch_size = 1 # 1, 2, 4
pass
cov_mat_L = np.load('./bluenoise/cov_gaussianBN_L_res{:}_d{:}.npz'.format(64, dimension))['x'].astype(np.float32)
if opt.noise_type in ['gaussianRN']:
cov_mat_L = np.load('./bluenoise/cov_gaussianRN_L_res{:}_d{:}.npz'.format(64, dimension))['x'].astype(np.float32)
cov_mat_L = torch.from_numpy(cov_mat_L).to(device).detach()
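# cov_mat_L appears to be a precomputed factor L of a target noise covariance
# (blue noise for 'gaussianBN', red noise for 'gaussianRN'), so correlated noise
# can be drawn as L @ eps with white eps ~ N(0, I). Conceptual sketch only; the
# actual sampling is done by get_noise_v2 (imported from ../repo) and its
# tensor layout may differ:
#   eps = torch.randn(cov_mat_L.shape[-1], device=device)
#   correlated = cov_mat_L @ eps  # reshaped to (dimension, res, res) downstream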
def get_scheduler(x, scheduler):
# input: x in [0, T]; scheduler in ['linear', 'sigmoid', 'cosine'] (anything else raises NotImplementedError)
# output: x mapped into [0, 1]
scheduler = scheduler.lower()
array_type = None
if isinstance(x, np.ndarray):
array_type = 'numpy'
elif torch.is_tensor(x):
array_type = 'torch'
else:
array_type = 'float'
# do nothing
if scheduler == 'linear':
return x / opt.nb_steps
elif scheduler == 'sigmoid':
# if array_type in ['numpy', 'float']:
# x = np.power(x, opt.pow_index)
# elif array_type == 'torch':
# x = torch.pow(x, opt.pow_index)
# start = torch.zeros_like(x)
start = torch.ones_like(x) * opt.scheduler_param
end = torch.ones_like(x) * 3
clip_min = 1e-9
tau = 0.9 # 0.9 seems good; opt.scheduler_param
v_start = torch.nn.functional.sigmoid(start / tau)
v_end = torch.nn.functional.sigmoid(end / tau)
t = x / opt.nb_steps
output = torch.nn.functional.sigmoid((t * (end - start) + start) / tau)
output = (v_end - output) / (v_end - v_start)
output = torch.clamp(output, clip_min, 1)
x = 1 - output
elif scheduler == 'cosine':
start = torch.ones_like(x) * 0.2
end = torch.ones_like(x) * 1
clip_min = 1e-9
tau = opt.scheduler_param
v_start = torch.cos(start * np.pi / 2) ** (2 * tau)
v_end = torch.cos(end * np.pi / 2) ** (2 * tau)
t = x / opt.nb_steps
output = torch.cos((t * (end - start) + start) * np.pi / 2) ** (2 * tau)
output = (v_end - output) / (v_end - v_start)
output = torch.clamp(output, clip_min, 1.0)
x = 1 - output
else:
raise NotImplementedError
return x
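# Hedged sanity check for get_scheduler: every branch maps t=0 -> ~0 and
# t=nb_steps -> ~1. With the default linear schedule and nb_steps=1000:
#   get_scheduler(torch.tensor([0., 500., 1000.]), 'linear')
#   # -> tensor([0.0000, 0.5000, 1.0000])
# The sigmoid/cosine branches return 1 - (v_end - f(t)) / (v_end - v_start),
# i.e. f renormalized between its endpoint values and clamped to [clip_min, 1].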
def get_scheduler_gamma(x, scheduler, scheduler_params):
scheduler_param = scheduler_params[0]
scheduler_param_s = scheduler_params[1]
scheduler_param_e = scheduler_params[2]
# input: x in [0, T]; scheduler in ['linear', 'sigmoid', 'cosine'] (anything else raises NotImplementedError)
# output: x mapped into [0, 1]
scheduler = scheduler.lower()
array_type = None
if isinstance(x, np.ndarray):
array_type = 'numpy'
elif torch.is_tensor(x):
array_type = 'torch'
else:
array_type = 'float'
if scheduler == 'linear':
return x / opt.nb_steps
elif scheduler == 'sigmoid':
start = torch.ones_like(x) * scheduler_param_s #scheduler_param
end = torch.ones_like(x) * scheduler_param_e
clip_min = 1e-9
tau = scheduler_param # 0.9 seems good; scheduler_param
v_start = torch.nn.functional.sigmoid(start / tau)
v_end = torch.nn.functional.sigmoid(end / tau)
t = x / opt.nb_steps
output = torch.nn.functional.sigmoid((t * (end - start) + start) / tau)
output = (v_end - output) / (v_end - v_start)
output = torch.clamp(output, clip_min, 1)
x = 1 - output
elif scheduler == 'cosine':
start = torch.ones_like(x) * scheduler_param_s
end = torch.ones_like(x) * scheduler_param_e
clip_min = 1e-9
tau = scheduler_param # integer for now
# # print('tau:', tau, torch.cos(start * np.pi / 2))
v_start = torch.pow(torch.cos(start * np.pi / 2.0), (2.0 * tau))
v_end = torch.pow(torch.cos(end * np.pi / 2), (2 * tau))
# # print('output:', v_end, v_start)
t = x / opt.nb_steps
output = torch.pow(torch.cos((t * (end - start) + start) * np.pi / 2), (2 * tau))
output = (v_end - output) / (v_end - v_start)
output = torch.clamp(output, clip_min, 1.0)
x = 1 - output
# t = x / opt.nb_steps
# output = 0.5 * torch.cos(2*np.pi*scheduler_param*t) + 0.5
# x = output
else:
raise NotImplementedError
return x
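# get_scheduler_gamma mirrors get_scheduler but takes (tau, start, end) from
# scheduler_params so the gamma schedule can be tuned or learned. Hedged
# example with the sigmoid branch and nb_steps=1000:
#   params = torch.tensor([0.9, 0.0, 3.0])  # tau, start, end
#   get_scheduler_gamma(torch.tensor([0., 1000.]), 'sigmoid', params)
#   # -> approximately tensor([0., 1.]) (endpoints of the renormalized sigmoid)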
def get_model(inp_channel=3, out_channel=3):
# block_out_channels=(64, 64, 128, 128, 256, 256)
if opt.res in [64]:
block_out_channels=(128, 128, 256, 256, 512, 512)
down_block_types=(
"DownBlock2D", # a regular ResNet downsampling block
"DownBlock2D",
"DownBlock2D",
"DownBlock2D",
"AttnDownBlock2D", # a ResNet downsampling block with spatial self-attention
# "DownBlock2D",
"DownBlock2D",
)
up_block_types=(
"UpBlock2D", # a regular ResNet upsampling block
"AttnUpBlock2D", # a ResNet upsampling block with spatial self-attention
# "UpBlock2D", # a regular ResNet upsampling block
"UpBlock2D",
"UpBlock2D",
"UpBlock2D",
"UpBlock2D"
)
elif opt.res in [128]:
block_out_channels=(128, 128, 128, 256, 256, 512, 512)
down_block_types=(
"DownBlock2D", # a regular ResNet downsampling block
"DownBlock2D",
"DownBlock2D",
"DownBlock2D",
"DownBlock2D",
"AttnDownBlock2D", # a ResNet downsampling block with spatial self-attention
# "DownBlock2D",
"DownBlock2D",
)
up_block_types=(
"UpBlock2D", # a regular ResNet upsampling block
"AttnUpBlock2D", # a ResNet upsampling block with spatial self-attention
# "UpBlock2D", # a regular ResNet upsampling block
"UpBlock2D",
"UpBlock2D",
"UpBlock2D",
"UpBlock2D",
"UpBlock2D"
)
elif opt.res in [256]:
block_out_channels=(128, 128, 128, 128, 256, 256, 512, 512)
down_block_types=(
"DownBlock2D", # a regular ResNet downsampling block
"DownBlock2D",
"DownBlock2D",
"DownBlock2D",
"DownBlock2D",
"DownBlock2D",
"AttnDownBlock2D", # a ResNet downsampling block with spatial self-attention
# "DownBlock2D",
"DownBlock2D",
)
up_block_types=(
"UpBlock2D", # a regular ResNet upsampling block
"AttnUpBlock2D", # a ResNet upsampling block with spatial self-attention
# "UpBlock2D", # a regular ResNet upsampling block
"UpBlock2D",
"UpBlock2D",
"UpBlock2D",
"UpBlock2D",
"UpBlock2D",
"UpBlock2D"
)
else:
raise NotImplementedError
return UNet2DModel(block_out_channels=block_out_channels,out_channels=out_channel, in_channels=inp_channel, up_block_types=up_block_types, down_block_types=down_block_types, act_fn=opt.activation, add_attention=True)
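# Minimal usage sketch (hedged; uses only the diffusers UNet2DModel call
# pattern seen below): the model takes the blended image and a per-sample
# alpha in place of the usual timestep, and returns out_channel channels.
#   m = get_model(3, 3)  # requires opt.res == 64 here
#   x = torch.randn(2, 3, 64, 64)
#   a = torch.rand(2)
#   y = m(x, a, return_dict=False)[0]  # shape (2, 3, 64, 64)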
@torch.no_grad()
def sample_iadb(model, x0, nb_step, scheduler_params):
# print('sample_iadb')
x_all = []
x_alpha = x0
start_step = 0  # int(alpha_min * nb_step)
seq = list(range(start_step, nb_step))
use_reverse = True
if use_reverse:
seq = reversed(seq)
# print('nb_step:', seq, nb_step)
# for t in range(start_step, nb_step):
inference_time = []
for t in seq:
tt = torch.randint(low=t, high=t+1, size=(x0.shape[0], )).to(device)
# if use_reverse:
# alpha_start = ((t+1)/nb_step)
# alpha_end = ((t)/nb_step)
alpha_start = get_scheduler((tt + 1).float(), opt.scheduler_alpha)
alpha_end = get_scheduler(tt.float(), opt.scheduler_alpha)
# if opt.optimize_scheduler_param:
gamma_start = get_scheduler_gamma((tt + 1).float(), opt.scheduler_gamma, scheduler_params)
gamma_end = get_scheduler_gamma(tt.float(), opt.scheduler_gamma, scheduler_params)
start_time = time.time()
d = model(x_alpha, alpha_start, return_dict=False)[0]#['sample']
end_time = time.time()
inference_time.append(end_time - start_time)
if opt.noise_type in ['gaussianBN', 'gaussianRN']:
# x_alpha = x_alpha + (alpha_start - alpha_end).view(-1, 1, 1, 1) * d[:, :3, :, :] + (gamma_start - gamma_end).view(-1, 1, 1, 1) * alpha_end.view(-1, 1, 1, 1) * d[:, 3:, :, :]
if opt.out_channel == 3:
x_alpha = x_alpha + (alpha_start - alpha_end).view(-1, 1, 1, 1) * d
elif opt.out_channel == 6:
# print('t:', t, gamma_start - gamma_end)
x_alpha = x_alpha + (alpha_start - alpha_end).view(-1, 1, 1, 1) * d[:, :3, :, :] + (gamma_start - gamma_end).view(-1, 1, 1, 1) * d[:, 3:, :, :]
else:
raise NotImplementedError
elif opt.noise_type in ['gaussian', 'GBN']:
# TODO: for early stopping only
if False:
if opt.noise_type in ['GBN', 'gaussian'] and opt.early_stopping_step-1 == t:
x_alpha = d + x0
x_all.append(x_alpha)
break
else:
x_alpha = x_alpha + (alpha_start - alpha_end).view(-1, 1, 1, 1) * d
x_alpha = x_alpha + (alpha_start - alpha_end).view(-1, 1, 1, 1) * d
if False: # motivation figure
if t % 25 == 0 or t == 0:
x1_deblended = d + x0
if t != 0:
x1_deblended = (x1_deblended - x1_deblended.min()) / (x1_deblended.max() - x1_deblended.min())
else:
x1_deblended = torch.clamp((x_alpha + 1) / 2.0, 0, 1)
Image.fromarray((x1_deblended[1].permute(1,2,0).cpu().numpy() * 255).astype(np.uint8)).save('results/motivation_x1_deblended_{:}.png'.format(t))
# plt.figure(1)
# plt.subplot(121)
# plt.imshow(x1_deblended[1].permute(1,2,0).cpu().numpy())
# plt.subplot(122)
# plt.imshow(x0[1].permute(1,2,0).cpu().numpy())
# plt.show()
else:
raise NotImplementedError
if opt.train_or_test == 'test':
# if t == opt.early_stopping_step-1:
# x_all.append(x_alpha)
# break
if nb_step == 1000:
log_freq = 100
else:
log_freq = 25
if t % log_freq == 0 or t == nb_step-1:
x_all.append(x_alpha)
# print('inference time:', inference_time, np.mean(inference_time[1:]))
if opt.train_or_test == 'test':
return x_alpha, x_all, np.mean(inference_time[1:])
return x_alpha
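# Why the update above works: with the blend x_alpha = alpha*x0 + (1-alpha)*x1
# (x0 noise, x1 data, as in training below), stepping alpha from alpha_start
# down to alpha_end changes x_alpha by (alpha_start - alpha_end) * (x1 - x0).
# The network d is trained to predict (x1 - x0) (plus a separate
# noise-difference head when out_channel == 6), so accumulating
# (alpha_start - alpha_end) * d integrates the deterministic IADB trajectory
# from pure noise (alpha=1) to data (alpha=0).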
@torch.no_grad()
def sample_iadb_conditional(model, x0, x_c, nb_step, scheduler_params):
x_all = []
x_alpha = x0
start_step = 0  # int(alpha_min * nb_step)
seq = list(range(start_step, nb_step))
use_reverse = True
if use_reverse:
seq = reversed(seq)
for t in seq:
tt = torch.randint(low=t, high=t+1, size=(x0.shape[0], )).to(device)
alpha_start = get_scheduler((tt + 1).float(), opt.scheduler_alpha)
alpha_end = get_scheduler(tt.float(), opt.scheduler_alpha)
# if opt.optimize_scheduler_param:
gamma_start = get_scheduler_gamma((tt + 1).float(), opt.scheduler_gamma, scheduler_params)
gamma_end = get_scheduler_gamma(tt.float(), opt.scheduler_gamma, scheduler_params)
# else:
# gamma_start = get_scheduler((tt + 1).float(), opt.scheduler_gamma)
# gamma_end = get_scheduler(tt.float(), opt.scheduler_gamma)
# print('x_alpha:', x_alpha.shape, x_c.shape)
d = model(torch.cat([x_alpha, x_c], 1), alpha_start, return_dict=False)[0]#['sample']
# d = model(x_alpha, alpha_start, return_dict=False)[0]#['sample']
if opt.noise_type in ['gaussianBN', 'gaussianRN']:
# x_alpha = x_alpha + (alpha_start - alpha_end).view(-1, 1, 1, 1) * d[:, :3, :, :] + (gamma_start - gamma_end).view(-1, 1, 1, 1) * alpha_end.view(-1, 1, 1, 1) * d[:, 3:, :, :]
if opt.out_channel == 3:
x_alpha = x_alpha + (alpha_start - alpha_end).view(-1, 1, 1, 1) * d
elif opt.out_channel == 6:
# print('t:', t, gamma_start - gamma_end)
x_alpha = x_alpha + (alpha_start - alpha_end).view(-1, 1, 1, 1) * d[:, :3, :, :] + (gamma_start - gamma_end).view(-1, 1, 1, 1) * d[:, 3:, :, :]
else:
raise NotImplementedError
elif opt.noise_type in ['gaussian', 'GBN']:
x_alpha = x_alpha + (alpha_start - alpha_end).view(-1, 1, 1, 1) * d
else:
raise NotImplementedError
if opt.train_or_test == 'test':
# if t == opt.early_stopping_step-1:
# x_all.append(x_alpha)
# break
if nb_step == 1000:
log_freq = 100
else:
log_freq = 25
if t % log_freq == 0 or t == nb_step-1:
x_all.append(x_alpha)
if opt.train_or_test == 'test':
return x_alpha, x_all
return x_alpha
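# The conditional sampler differs only in the network input: the degraded
# condition x_c is concatenated channel-wise with x_alpha (6 input channels
# for superres), while the update rule on x_alpha stays identical to the
# unconditional sample_iadb above.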
DATA_FOLDER = './data/{:}'.format(opt.dataset)
transform = transforms.Compose([transforms.Resize(opt.res),transforms.CenterCrop(opt.res), transforms.RandomHorizontalFlip(0.5),transforms.ToTensor()])
test_transform = transforms.Compose([transforms.Resize(opt.res),transforms.CenterCrop(opt.res), transforms.ToTensor()])
start_time = time.time()
if opt.is_conditional:
if opt.train_or_test == 'train':
train_dataset = torchvision.datasets.ImageFolder(root=DATA_FOLDER+'_train', transform=transform)
else:
train_dataset = torchvision.datasets.ImageFolder(root=DATA_FOLDER+'_test', transform=transform) # dummy
test_dataset = torchvision.datasets.ImageFolder(root=DATA_FOLDER+'_test', transform=test_transform)
else:
train_dataset = torchvision.datasets.ImageFolder(root=DATA_FOLDER, transform=transform)
print('dataloader time, dataset size:', time.time() - start_time, len(train_dataset))
if platform.system() == 'Windows':
num_workers = 0
else:
if opt.res in [32, 64]:
num_workers = 4
elif opt.res in [128]:
num_workers = 8
elif opt.res in [256]:
num_workers = 16
else:
raise NotImplementedError
is_shuffle = True
if opt.train_or_test == 'test_amin':
is_shuffle = False
drop_last = True
dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=is_shuffle, num_workers=num_workers, drop_last=drop_last) # True
device_count = torch.cuda.device_count()
print('device_count:', device_count)
if opt.noise_type in ['gaussianBN', 'gaussianRN']:
pass
else:
opt.out_channel = 3
if opt.is_conditional:
outer_folder = 'results_gaussianBN_{:}'.format(opt.conditional_type)
else:
outer_folder = 'results_gaussianBN'
if opt.scheduler_gamma in ['linear']:
output_folder = outer_folder + '/{:}_{:}_{:}_outc{:}_seed{:}'.format(opt.dataset, opt.noise_type, opt.scheduler_gamma, opt.out_channel, opt.seed)
else:
if opt.optimize_scheduler_param:
output_folder = outer_folder + '/{:}_{:}_{:}_outc{:}_seed{:}'.format(opt.dataset, opt.noise_type, opt.scheduler_gamma, opt.out_channel, opt.seed)
else:
if opt.remap:
output_folder = outer_folder + '/{:}_{:}_{:}_{:}_{:}_{:}_outc{:}_remap_seed{:}'.format(opt.dataset, opt.noise_type, opt.scheduler_gamma, opt.scheduler_param, opt.scheduler_param_s, opt.scheduler_param_e, opt.out_channel, opt.seed)
else:
output_folder = outer_folder + '/{:}_{:}_{:}_{:}_{:}_{:}_outc{:}_seed{:}'.format(opt.dataset, opt.noise_type, opt.scheduler_gamma, opt.scheduler_param, opt.scheduler_param_s, opt.scheduler_param_e, opt.out_channel, opt.seed)
if not os.path.exists(output_folder):
os.makedirs(output_folder, exist_ok=True)
def main():
print('output_folder:', output_folder)
nb_iter = 0
train_or_test = opt.train_or_test
if opt.optimize_scheduler_param:
# if opt.scheduler_gamma in ['cosine']:
# scheduler_param_min = 1
# scheduler_param_max = 10
if opt.scheduler_gamma in ['sigmoid']:
scheduler_param_min = 0.01
scheduler_param_max = 10
scheduler_param_s_min = -3
scheduler_param_s_max = -0.01
scheduler_param_e_min = 0.01
scheduler_param_e_max = 3
elif opt.scheduler_gamma in ['linear']:
scheduler_param_min = 1
scheduler_param_max = 1
scheduler_param_s_min = 1
scheduler_param_s_max = 1
scheduler_param_e_min = 1
scheduler_param_e_max = 1
else:
raise NotImplementedError
else:
scheduler_param_min = opt.scheduler_param
scheduler_param_max = opt.scheduler_param
scheduler_param_s_min = opt.scheduler_param_s
scheduler_param_s_max = opt.scheduler_param_s
scheduler_param_e_min = opt.scheduler_param_e
scheduler_param_e_max = opt.scheduler_param_e
scheduler_params = torch.rand(3).float().to(device)
scheduler_params[0] = scheduler_param_min + (scheduler_param_max - scheduler_param_min) * scheduler_params[0]
scheduler_params[1] = scheduler_param_s_min + (scheduler_param_s_max - scheduler_param_s_min) * scheduler_params[1]
scheduler_params[2] = scheduler_param_e_min + (scheduler_param_e_max - scheduler_param_e_min) * scheduler_params[2]
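# scheduler_params packs the gamma-schedule parameters as [tau, start, end]
# (the training log below prints them under exactly those names). With
# --optimize_scheduler_param they are initialized uniformly inside the
# [min, max] boxes above and later updated by their own optimizer; otherwise
# min == max pins them to the CLI values.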
# print('scheduler_params:', scheduler_params)
if opt.optimize_scheduler_param:
# # experimentally set to better initialized values based on the image resolution
# if opt.dataset in ['cat_res64', 'celeba_res64']:
# scheduler_param[:] = 99 # as it converges to around 10
# elif opt.dataset in ['cat_res128', 'celeba_res128']:
# scheduler_param[:] = 0.2 # as it converges to around 0.2
# else:
# raise NotImplementedError
pass
# conditional image generation
is_conditional = opt.is_conditional
inp_channel = 3
if is_conditional:
if opt.conditional_type in ['superres']:
inp_channel = 6 # 6
# elif opt.conditional_type in ['coloring']:
# inp_chanel = 4
else:
raise NotImplementedError
model = get_model(inp_channel, opt.out_channel)
if train_or_test == 'test' and is_conditional:
print('===> Start conditional sampling / superres')
from piq import psnr, ssim, SSIMLoss
model.load_state_dict(torch.load(f'{output_folder}/model.ckpt'))
model = model.to(device)
model = torch.nn.DataParallel(model) # multi-gpus
model.eval()
total_num = 5000 # 30000, 5000
# num_batch = int(total_num // opt.batch_size)
cnt = 0
# if total_num % opt.batch_size == 0:
# num_batch = int(total_num // opt.batch_size)
# last_batch_size = opt.batch_size
# else:
# num_batch = int(total_num // opt.batch_size) + 1
# last_batch_size = total_num - (num_batch-1) * opt.batch_size
if opt.noise_type in ['gaussianBN']:
folder_name_noise = 'gwn2gbn'
elif opt.noise_type in ['gaussian']:
folder_name_noise = 'gwn'
elif opt.noise_type in ['gaussianRN']:
folder_name_noise = 'gwn2grn'
elif opt.noise_type in ['GBN']:
folder_name_noise = 'gbn'
else:
raise NotImplementedError
folder_name = '{:}_iadb_{:}_{:}_steps{:}'.format(opt.dataset, folder_name_noise, opt.conditional_type, opt.nb_steps)
for sub_folder in ['images', 'seqs', 'lowres', 'highres']:
if not os.path.exists(output_folder + '/{:}/{:}'.format(folder_name, sub_folder)):
os.makedirs(output_folder + '/{:}/{:}'.format(folder_name, sub_folder), exist_ok=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=num_workers, drop_last=False)
if opt.optimize_scheduler_param:
scheduler_params = np.loadtxt(f'{output_folder}/scheduler_params.txt')
else:
scheduler_params = np.array([opt.scheduler_param, opt.scheduler_param_s, opt.scheduler_param_e]).astype(np.float32)
scheduler_params = torch.from_numpy(scheduler_params).float().to(device)
avg_ssim = 0
avg_psnr = 0
avg_l2 = 0
avg_l1 = 0
with torch.no_grad():
for i, data in enumerate(tqdm(test_dataloader)): # test_dataloader
if (i + 1) > 389:
return
if (i + 1) not in [74, 104, 278, 389]: # test the ones in the paper
continue
x1 = (data[0].to(device)*2)-1
downscale = 4
x_c = torch.nn.functional.interpolate(x1, size=(opt.res//downscale, opt.res//downscale), mode='bilinear', align_corners=True)
x_c = torch.nn.functional.interpolate(x_c, size=(opt.res, opt.res), mode='bilinear', align_corners=True)
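# e.g. with res=64 and downscale=4 the ground truth is bilinearly reduced to
# 16x16 and upsampled back to 64x64, so x_c is a blurry 4x-superres condition
# with the same spatial size as x1.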
cur_batch_size = x1.shape[0]
x0 = torch.from_numpy(np.random.randn(cur_batch_size, 3, opt.res, opt.res)).float().to(device)
t = torch.randint(low=opt.nb_steps, high=opt.nb_steps+1, size=(x0.shape[0], )).to(device)
gamma_t = get_scheduler_gamma(t.float(), opt.scheduler_gamma, scheduler_params)
x0, _, _ = get_noise_v2(device, x0, cov_mat_L, gamma_t, t, noise_type=opt.noise_type, train_or_test='test', inplace=True)
# print('x0:', x0.shape, x_c.shape)
sample, sample_all = sample_iadb_conditional(model, x0, x_c, opt.nb_steps, scheduler_params) # fair
ssim_val = ssim(torch.clamp((sample + 1.0) / 2.0, 0.0, 1.0), (x1 + 1.0) / 2.0, data_range=1., reduction='none')
psnr_val = psnr(torch.clamp((sample + 1.0) / 2.0, 0.0, 1.0), (x1 + 1.0) / 2.0, data_range=1., reduction='none')
l2_val = torch.sum((sample - x1) ** 2)
l1_val = torch.sum(torch.abs(sample - x1))
# print('val:', ssim_val, psnr_val, l2_val)
avg_ssim += torch.sum(ssim_val).item() / total_num
avg_psnr += torch.sum(psnr_val).item() / total_num
avg_l2 += l2_val.item() / total_num
avg_l1 += l1_val.item() / total_num
for j in range(0, len(sample_all), 1):
# plt.subplot(1, len(sample_all), j+1)
sample_plot = sample_all[j][0]
if j == len(sample_all) - 1:
sample_plot = torch.clamp((sample_plot + 1) / 2.0, 0.0, 1.0)
else:
sample_plot = (sample_plot - sample_plot.min()) / (sample_plot.max() - sample_plot.min())
Image.fromarray((sample_plot.permute(1, 2, 0).detach().cpu().numpy()*255).astype(np.uint8)).save(output_folder + '/{:}/seqs/{:}_img{:0>5}_step{:}.png'.format(folder_name, folder_name_noise, cnt, int((j*100)/1000*opt.nb_steps)))
for j in range(cur_batch_size):
cnt += 1
sample_j = sample[j]
# print('sample_j:', sample_j.shape, sample_j.min(), sample_j.max())
sample_j = torch.clamp((sample_j + 1) / 2.0, 0.0, 1.0)
sample_j = sample_j.permute(1, 2, 0).detach().cpu().numpy()
Image.fromarray((sample_j*255).astype(np.uint8)).save(output_folder + '/{:}/images/image_{:}_{:0>5}.png'.format(folder_name, folder_name_noise, cnt))
x1_plot = x1[j].permute(1, 2, 0).detach().cpu().numpy()
x1_plot = (x1_plot + 1) / 2
xc_plot = x_c[j].permute(1, 2, 0).detach().cpu().numpy()
xc_plot = (xc_plot + 1) / 2
err = np.mean(np.abs(x1_plot - sample_j))
# print('err:', cnt, err)
if opt.noise_type in ['gaussian']: # no need to save these for all experiments
Image.fromarray((x1_plot*255).astype(np.uint8)).save(output_folder + '/{:}/highres/highres_{:}_{:0>5}.png'.format(folder_name, folder_name_noise, cnt))
Image.fromarray((xc_plot*255).astype(np.uint8)).save(output_folder + '/{:}/lowres/lowres_{:}_{:0>5}.png'.format(folder_name, folder_name_noise, cnt))
# break
print('conditional metrics: ssim: {:.4f}, psnr: {:.4f}, l2: {:.4f}, l1: {:.4f}'.format(avg_ssim, avg_psnr, avg_l2, avg_l1))
return
if train_or_test == 'test':
print('===> Start unconditional sampling')
if opt.noise_type in ['gaussianBN']:
folder_name_noise = 'gwn2gbn'
elif opt.noise_type in ['gaussian']:
folder_name_noise = 'gwn'
# folder_name_noise = 'gwn_earlystop{:}'.format(opt.early_stopping_step)
elif opt.noise_type in ['gaussianRN']:
folder_name_noise = 'gwn2grn'
elif opt.noise_type in ['GBN']:
folder_name_noise = 'gbn'
# folder_name_noise = 'gbn_earlystop{:}'.format(opt.early_stopping_step)
else:
raise NotImplementedError
folder_name = '{:}_iadb_{:}_steps{:}'.format(opt.dataset, folder_name_noise, opt.nb_steps)
for sub_folder in ['images', 'seqs', 'noise']:
if not os.path.exists(output_folder + '/{:}/{:}/'.format(folder_name, sub_folder)):
os.makedirs(output_folder + '/{:}/{:}/'.format(folder_name, sub_folder), exist_ok=True)
current_num_samples = len(glob.glob(output_folder + '/{:}/images/*.png'.format(folder_name)))
current_num_samples = 0 # override the resume count: always regenerate from the first batch
start_batch = int(current_num_samples // opt.batch_size)
model.load_state_dict(torch.load(f'{output_folder}/model.ckpt'))
model = model.to(device)
model = torch.nn.DataParallel(model) # multi-gpus
model.eval()
total_num = opt.test_samples # 30000, 5000
# num_batch = int(total_num // opt.batch_size)
cnt = current_num_samples
if total_num % opt.batch_size == 0:
num_batch = int(total_num // opt.batch_size)
last_batch_size = opt.batch_size
else:
num_batch = int(total_num // opt.batch_size) + 1
last_batch_size = total_num - (num_batch-1) * opt.batch_size
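# e.g. test_samples=10 with batch_size=64 gives num_batch=1 and
# last_batch_size=10, while test_samples=128 gives num_batch=2 with a full
# last batch of 64.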
print('current_samples:', current_num_samples)
print('start_batch:', start_batch)
print('num_batch:', num_batch)
if opt.optimize_scheduler_param:
scheduler_params = np.loadtxt(f'{output_folder}/scheduler_params.txt')
else:
scheduler_params = np.array([opt.scheduler_param, opt.scheduler_param_s, opt.scheduler_param_e]).astype(np.float32)
scheduler_params = torch.from_numpy(scheduler_params).float().to(device)
inference_times = []
noise_gen_times = []
for i in tqdm(range(start_batch, num_batch)):
# replicability, only one sample
if opt.dataset in ['cat_res64'] and i not in [4]:
continue
if opt.dataset in ['cat_res128'] and i not in [52]:
continue
if opt.dataset in ['celeba_res64'] and i not in [37]:
continue
if opt.dataset in ['celeba_res128'] and i not in [10]:
continue
if opt.dataset in ['church_res64'] and i not in [4, 23, 32, 36]:
continue
# print(opt.dataset, i, opt.batch_size)
cur_batch_size = opt.batch_size
if i == num_batch - 1:
cur_batch_size = last_batch_size
# x0 = torch.randn(cur_batch_size, 3, opt.res, opt.res).to(device) # weird
x0 = torch.from_numpy(np.random.randn(cur_batch_size, 3, opt.res, opt.res)).float().to(device)
if True:
x0 = np.load('./results_gaussianBN/{:}_gaussian_linear_outc3_seed0/{:}_iadb_gwn_steps250/noise/noise_batch{:}_idx{:05d}.npz'.format(opt.dataset, opt.dataset, opt.batch_size, i))['noise']
x0 = torch.from_numpy(x0).float().to(device)
x0 = x0[0:1] # replicability, only one sample
# print('x0:', x0.shape, x0.min(), x0.max())
t = torch.randint(low=opt.nb_steps, high=opt.nb_steps+1, size=(x0.shape[0], )).to(device)
gamma_t = get_scheduler_gamma(t.float(), opt.scheduler_gamma, scheduler_params)
start_time = time.time()
# x0, _, _ = get_noise_v2(device, x0, cov_mat_L, gamma_t, t, noise_type=opt.noise_type, train_or_test='test', inplace=True)
end_time = time.time()
noise_gen_times.append(end_time - start_time)
if False:
# print('x0:', x0.detach().cpu().numpy().shape)
np.savez_compressed(output_folder + '/{:}/noise/noise_batch{:}_idx{:0>5}.npz'.format(folder_name, cur_batch_size, i), noise=x0.detach().cpu().numpy())
# if i == 1:
# return
split_alpha = (opt.split_step) / 1000
ratio = opt.nb_steps / 1000
split_step = int(opt.split_step * ratio)
sample, sample_all, inference_time = sample_iadb(model, x0, opt.nb_steps, scheduler_params) # fair
inference_times.append(inference_time)
if True:
# plt.figure(1)
for j in range(0, len(sample_all), 1):
# plt.subplot(1, len(sample_all), j+1)
sample_plot = sample_all[j][0]
if j == len(sample_all) - 1:
sample_plot = torch.clamp((sample_plot + 1) / 2.0, 0.0, 1.0)
else:
sample_plot = (sample_plot - sample_plot.min()) / (sample_plot.max() - sample_plot.min())
Image.fromarray((sample_plot.permute(1, 2, 0).detach().cpu().numpy()*255).astype(np.uint8)).save(output_folder + '/{:}/seqs/{:}_img{:0>5}_step{:}.png'.format(folder_name, folder_name_noise, cnt, int((j*100)/1000*opt.nb_steps)))
# plt.imshow(sample_plot.permute(1,2,0).cpu().numpy())
# plt.axis('off')
# plt.savefig(output_folder + '/{:}/seqs/{:0>5}.png'.format(folder_name, i), bbox_inches='tight', pad_inches=0, dpi=300)
# plt.clf()
for j in range(cur_batch_size):
cnt += 1
if j > 0: # replicability, only one sample
continue
sample_j = sample[j]
sample_j = torch.clamp((sample_j + 1) / 2.0, 0.0, 1.0)
Image.fromarray((sample_j.permute(1, 2, 0).detach().cpu().numpy()*255).astype(np.uint8)).save(output_folder + '/{:}/images/{:0>5}.png'.format(folder_name, cnt))
print('np.mean(inference_times) per image with batch_size=1', np.mean(inference_times))
print('np.mean(noise_gen_times) per image with batch_size=1', np.mean(noise_gen_times[1:]))
return
print('===> Start training')
if opt.resume_training:
model.load_state_dict(torch.load(f'{output_folder}/model.ckpt'))
# model = model.to(device)
# model = torch.nn.DataParallel(model) # multi-gpus
else:
load_ckpt = False
if load_ckpt:
model.load_state_dict(torch.load(f'{output_folder}/model.ckpt'))
model = model.to(device)
model = torch.nn.DataParallel(model) # multi-gpus
if opt.optimizer_type in ['adam']:
optimizer = Adam(model.parameters(), lr=opt.lr)
elif opt.optimizer_type in ['adamw']:
optimizer = AdamW(model.parameters(), lr=opt.lr)
else:
raise NotImplementedError
optimizer_scheduler_param = AdamW([scheduler_params.requires_grad_()], lr=0.001) # 0.001, opt.lr
losses = []
scheduler_params_0 = []
scheduler_params_1 = []
scheduler_params_2 = []
iteration_count = 0
for current_epoch in tqdm(range(opt.epochs)):
model.train()
for i, data in enumerate(tqdm(dataloader)):
x1 = (data[0].to(device)*2)-1
bs = x1.shape[0]
# antithetic sampling following ddpm/ddim
upper_t = int(opt.nb_steps)
t = torch.randint(low=1, high=upper_t+1, size=(bs//2,)).to(device)
t = torch.cat([t, upper_t - t + 1], dim=0)[:bs]
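# e.g. with nb_steps=1000 and bs=4, a draw t=[137, 842] becomes
# t=[137, 842, 864, 159]: each timestep is paired with its reflection
# nb_steps - t + 1, balancing the batch around the midpoint for a
# lower-variance loss estimate, as in common ddpm/ddim implementations.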
alpha = t.float() / opt.nb_steps
# print('alpha1:', alpha)
alpha = get_scheduler(t.float(), opt.scheduler_alpha) # other scheduler: alpha = 1 - torch.sqrt(1 - alpha)
# print('alpha2:', alpha)
gamma_t = get_scheduler_gamma(t.float(), opt.scheduler_gamma, scheduler_params)
# print('gamma_t', gamma_t)
# x0: L_t @ noise
# noise_bn: L_b @ noise
# noise_wn: L_w @ noise
x0, noise_bn, noise_wn = get_noise_v2(device, x1, cov_mat_L, gamma_t, t, noise_type=opt.noise_type, train_or_test='train', inplace=False)
if opt.remap:
with torch.no_grad():
dist = torch.cdist(x0.view(bs, -1), x1.view(bs, -1))
mapping = torch.zeros(x0.shape[0], dtype=torch.long)
for i in range(x0.shape[0]):
mapping[i] = torch.argmin(dist[i])
# dist[:,mapping[i]] *= 100
dist[:,mapping[i]] = 10000
x1 = x1[mapping]
# print('mapping:', mapping)
# debug backward
if False:
gamma_t_1 = get_scheduler((t-1).float(), opt.scheduler_gamma)
alpha_t_1 = get_scheduler((t-1).float(), opt.scheduler_alpha)
noise = torch.randn_like(x1)
x0, noise_bn, noise_wn = get_noise_v2(device, noise, cov_mat_L, gamma_t, t, noise_type=opt.noise_type, train_or_test='train', inplace=True)
x0_1, _, _ = get_noise_v2(device, noise, cov_mat_L, gamma_t_1, t, noise_type=opt.noise_type, train_or_test='train', inplace=True)
# x_alpha = alpha.view(-1,1,1,1) * x0 + (1-alpha).view(-1,1,1,1) * x1
# x_alpha_t_1 = alpha_t_1.view(-1,1,1,1) * x0_1 + (1-alpha_t_1).view(-1,1,1,1) * x1
# x_alpha_t_1_recon = x_alpha + (alpha - alpha_t_1).view(-1, 1, 1, 1) * (x1 - x0) + (gamma_t - gamma_t_1).view(-1, 1, 1, 1) * (alpha_t_1).view(-1, 1, 1, 1) * (noise_bn - noise_wn)
# print('err:', torch.mean((x_alpha_t_1 - x_alpha_t_1_recon)**2))
# continue
return
iteration_count += 1
# x_alpha = alpha.view(-1,1,1,1) * x1 + (1-alpha).view(-1,1,1,1) * x0
x_alpha = alpha.view(-1,1,1,1) * x0 + (1-alpha).view(-1,1,1,1) * x1 # be careful: x1 is data, x0 is noise
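# Under this convention alpha=1 gives pure noise and alpha=0 pure data, which
# matches sample_iadb above: it starts from noise at alpha=1 and integrates
# alpha down to 0.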
if False:
print('alpha:', alpha, t)
plt.figure(1)
plt.subplot(121)
plt.imshow(x1[0].permute(1,2,0).detach().cpu().numpy())
plt.subplot(122)
plt.imshow(x0[0].permute(1,2,0).detach().cpu().numpy())
plt.show()
if is_conditional:
if opt.conditional_type in ['superres']:
# downsample and upsample to get bad initialized low-res image as input
downscale = 4
x_c = torch.nn.functional.interpolate(x1, size=(opt.res//downscale, opt.res//downscale), mode='bilinear', align_corners=True)
x_c = torch.nn.functional.interpolate(x_c, size=(opt.res, opt.res), mode='bilinear', align_corners=True)
# elif opt.conditional_type in ['coloring']:
# x_c = torch.nn.functional.interpolate(x1, size=(opt.res//2, opt.res//2), mode='bilinear', align_corners=True)
# x_c = torch.nn.functional.interpolate(x_c, size=(opt.res, opt.res), mode='bilinear', align_corners=True)
d = model(torch.cat([x_alpha, x_c], 1), alpha, return_dict=False)[0]
else:
d = model(x_alpha, alpha, return_dict=False)[0]
# d = model(x_alpha, t.float())
if opt.noise_type in ['gaussianBN', 'gaussianRN']:
alpha_t_minus_1 = get_scheduler((t - 1).float(), opt.scheduler_alpha)
if opt.out_channel == 3:
tar = x1 - x0 + alpha_t_minus_1.view(-1, 1, 1, 1) * (noise_bn - noise_wn)
loss = torch.sum((d - tar)**2)
elif opt.out_channel == 6:
tar1 = x1 - x0
tar2 = alpha_t_minus_1.view(-1, 1, 1, 1) * (noise_bn - noise_wn)
d1 = d[:, :3, ...]
d2 = d[:, 3:, ...]
# print(d2.shape, tar2.shape, alpha_t_minus_1.shape, gamma_t.shape, t, gamma_t)
gamma_t_minus_1 = get_scheduler_gamma((t-1).float(), opt.scheduler_gamma, scheduler_params)
delta_gamma_t = gamma_t - gamma_t_minus_1
delta_alpha_t = alpha - alpha_t_minus_1
loss1 = torch.sum((d1 - tar1)**2, dim=[1, 2, 3])
loss2 = torch.sum((d2 - tar2)**2, dim=[1, 2, 3])
loss1 = torch.sum(loss1 * delta_alpha_t / delta_alpha_t) # weight is simply 1
loss2 = torch.sum(loss2 * delta_gamma_t / delta_alpha_t) # weighted loss
loss = loss1 + loss2
else:
raise NotImplementedError
elif opt.noise_type in ['gaussian', 'GBN']:
loss = torch.sum((d - (x1-x0))**2)
else:
raise NotImplementedError
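# Hedged reading of the 6-channel loss above: the first head regresses the
# standard IADB direction (x1 - x0) with unit weight, while the second head
# regresses the correlated-minus-white noise difference, reweighted by
# delta_gamma_t / delta_alpha_t, i.e. by how fast gamma moves relative to
# alpha at the sampled timestep; this mirrors how the sampler combines the
# two heads with (alpha_start - alpha_end) and (gamma_start - gamma_end).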
optimizer.zero_grad()
optimizer_scheduler_param.zero_grad()
loss.backward()
# clip gradients only when a max norm is given (opt.grad_clip defaults to None)
if opt.grad_clip is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
optimizer.step()
optimizer_scheduler_param.step()
nb_iter += 1
losses.append(loss.item())
scheduler_params[0].data.clamp_(scheduler_param_min, scheduler_param_max)
scheduler_params[1].data.clamp_(scheduler_param_s_min, scheduler_param_s_max)
scheduler_params[2].data.clamp_(scheduler_param_e_min, scheduler_param_e_max)
scheduler_params_0.append(scheduler_params[0].item())
scheduler_params_1.append(scheduler_params[1].item())
scheduler_params_2.append(scheduler_params[2].item())
# print('loss:', loss.item(), opt.noise_type)
# break
# continue
print('np.array(losses):', np.mean(np.array(losses)))
# if opt.optimize_scheduler_param:
print('scheduler_params: tau{:.4f},{:.4f},{:.4f}; start{:.4f},{:.4f},{:.4f}; end{:.4f},{:.4f},{:.4f}'.format(scheduler_params_0[-1], scheduler_param_min, scheduler_param_max, scheduler_params_1[-1], scheduler_param_s_min, scheduler_param_s_max, scheduler_params_2[-1], scheduler_param_e_min, scheduler_param_e_max))
# moving saving things outside training loop
plt.figure(1)
plt.plot(losses)
# plt.plot(np.mean(np.array(losses))) # track mean of losses
plt.savefig(output_folder + '/losses.png')
plt.clf()
np.savetxt(output_folder + '/losses.txt', np.array(losses))
plt.figure(1)
plt.plot(scheduler_params_0)
plt.plot(scheduler_params_1)
plt.plot(scheduler_params_2)
plt.savefig(output_folder + '/scheduler_params.png')
plt.clf()
np.savetxt(f'{output_folder}/scheduler_params.txt', scheduler_params.detach().cpu().numpy())
save_model_name = 'model'
torch.save(model.module.state_dict(), f'{output_folder}/{save_model_name}.ckpt')
if __name__ == '__main__':
main()