https://github.com/MarkMoHR/virtual_sketching
Tip revision: 958efe45a9120b9d467ba7701efba28c11e38f8f authored by Your Name on 21 August 2021, 08:13:15 UTC
Added bilibili links
Added bilibili links
Tip revision: 958efe4
utils.py
import os
import cv2
import json
import numpy as np
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
#############################################
# Tensorflow utils
#############################################
def reset_graph():
    """Close the current default TensorFlow session, if there is one, then reset the default graph."""
    current_session = tf.get_default_session()
    if current_session:
        current_session.close()
    tf.reset_default_graph()
def load_checkpoint(sess, checkpoint_path, ras_only=False, perceptual_only=False, gen_model_pretrain=False,
                    train_entire=False):
    """Restore (a subset of) the current graph's global variables into ``sess``.

    The variables to restore are selected by the first flag that is set, in this order:

    * ``ras_only``: only variables whose name contains ``'raster_unit'``;
    * ``perceptual_only``: only variables whose name contains ``'VGG16'``;
    * ``train_entire``: every variable except those whose name contains ``'discriminator'``,
      ``'raster_unit'``, ``'VGG16'``, ``'beta1'``, ``'beta2'``, ``'global_step'`` or ``'Entire'``;
    * ``gen_model_pretrain``: the same exclusions as ``train_entire``, except that
      ``global_step`` and ``Entire`` variables are restored as well;
    * otherwise: all global variables.

    :param sess: session to restore the variables into.
    :param checkpoint_path: with ``ras_only``, the checkpoint path itself. Otherwise, a
        directory whose recorded checkpoint is looked up via ``tf.train.get_checkpoint_state``.
    :return: the substring of the resolved checkpoint path after its last ``'-'`` (the step
        suffix, as a ``str``; the whole path if it contains no ``'-'``).
    :raises FileNotFoundError: if no checkpoint is recorded in ``checkpoint_path``.
    """
    def _vars_excluding(*patterns):
        # name -> variable map of every global variable whose name contains none of `patterns`.
        return {var.op.name: var for var in tf.global_variables()
                if not any(pattern in var.op.name for pattern in patterns)}

    # 'beta1'/'beta2' are presumably the Adam beta-power accumulators -- TODO confirm.
    base_exclusions = ('discriminator', 'raster_unit', 'VGG16', 'beta1', 'beta2')

    if ras_only:
        load_var = {var.op.name: var for var in tf.global_variables() if 'raster_unit' in var.op.name}
    elif perceptual_only:
        load_var = {var.op.name: var for var in tf.global_variables() if 'VGG16' in var.op.name}
    elif train_entire:
        load_var = _vars_excluding(*(base_exclusions + ('global_step', 'Entire')))
    elif gen_model_pretrain:
        # global_step is deliberately restored here, unlike in train_entire.
        load_var = _vars_excluding(*base_exclusions)
    else:
        load_var = tf.global_variables()
    restorer = tf.train.Saver(load_var)

    if ras_only:
        model_checkpoint_path = checkpoint_path
    else:
        ckpt = tf.train.get_checkpoint_state(checkpoint_path)
        # get_checkpoint_state returns None when no checkpoint state is recorded in the directory.
        if ckpt is None or not ckpt.model_checkpoint_path:
            raise FileNotFoundError('No checkpoint found in %s' % checkpoint_path)
        model_checkpoint_path = ckpt.model_checkpoint_path

    print('Loading model %s' % model_checkpoint_path)
    restorer.restore(sess, model_checkpoint_path)
    snapshot_step = model_checkpoint_path[model_checkpoint_path.rfind('-') + 1:]
    return snapshot_step
def create_summary(summary_writer, summ_map, step):
    """Record every entry of ``summ_map`` as a scalar summary at ``step``, then flush the writer.

    :param summary_writer: object exposing ``add_summary`` and ``flush``
        (presumably a ``tf.summary.FileWriter`` -- TODO confirm).
    :param summ_map: mapping from tag name to a value accepted by ``float``.
    :param step: step the values are recorded at.
    """
    for tag, value in summ_map.items():
        scalar_summary = tf.summary.Summary()
        scalar_summary.value.add(tag=tag, simple_value=float(value))
        summary_writer.add_summary(scalar_summary, step)
    summary_writer.flush()
def save_model(sess, saver, model_save_path, global_step):
    """Save ``sess`` through ``saver`` under the checkpoint prefix ``<model_save_path>/p2s``.

    ``global_step`` is forwarded to ``saver.save``, which tags the checkpoint with it. Both the
    prefix and the step are printed first.
    """
    ckpt_prefix = os.path.join(model_save_path, 'p2s')
    print('saving model %s.' % ckpt_prefix)
    print('global_step %i.' % global_step)
    saver.save(sess, ckpt_prefix, global_step=global_step)
#############################################
# Utils for basic image processing
#############################################
def normal(x, width):
    """Map a normalised coordinate ``x`` to an integer pixel index on a ``width``-pixel axis.

    ``x`` is scaled to ``[0, width - 1]`` (assuming ``x`` lies in [0, 1]) and rounded half-up
    by adding 0.5 before ``int()``. Because ``int()`` truncates toward zero, this rounding is
    only correct for non-negative values.
    """
    scaled = x * (width - 1) + 0.5
    return int(scaled)
def draw(f, width=128):
    """Rasterise one quadratic Bezier stroke into a ``width`` x ``width`` image.

    :param f: ten numbers ``(x0, y0, x1, y1, x2, y2, z0, z2, w0, w2)``: start point, control
        point, end point, start/end radius parameters and start/end intensity. The control
        point ``(x1, y1)`` is a fraction of the start->end span rather than an absolute
        position. Points are mapped to pixels with ``normal`` (presumably given in [0, 1]).
    :param width: output side length. The stroke is drawn on a canvas twice this size and
        downsampled with ``cv2.resize``.
    :return: float32 array of shape (width, width) equal to ``1 - canvas``. With unit
        intensity, stroke pixels are ~0.0 and background pixels are 1.0.
    """
    x0, y0, x1, y1, x2, y2, z0, z2, w0, w2 = f
    # Turn the relative control point into an absolute one.
    x1 = x0 + (x2 - x0) * x1
    y1 = y0 + (y2 - y0) * y1

    big = width * 2  # supersampled canvas size
    px0, px1, px2 = (normal(v, big) for v in (x0, x1, x2))
    py0, py1, py2 = (normal(v, big) for v in (y0, y1, y2))
    # Circle radii at the two ends ('//' floors the float product).
    r0 = int(1 + z0 * width // 2)
    r2 = int(1 + z2 * width // 2)

    canvas = np.zeros([big, big]).astype('float32')
    step = 1. / 100
    # 100 filled circles at t = 0.00, 0.01, ..., 0.99 along the curve.
    for i in range(100):
        t = i * step
        s = 1 - t
        cx = int(s * s * px0 + 2 * t * s * px1 + t * t * px2)
        cy = int(s * s * py0 + 2 * t * s * py1 + t * t * py2)
        radius = int(s * r0 + t * r2)
        intensity = s * w0 + t * w2
        # The centre is passed to OpenCV as (y, x).
        cv2.circle(canvas, (cy, cx), radius, intensity, -1)
    return 1 - cv2.resize(canvas, dsize=(width, width))
def rgb_trans(split_num, break_values):
    """Piecewise-linearly resample one colour channel.

    The first 9 entries of ``break_values`` define 8 segments. Each segment contributes
    ``split_num // 8`` evenly spaced samples, starting at its left breakpoint; the right
    breakpoint is excluded.

    :param split_num: requested number of samples; effectively rounded down to a multiple of 8.
    :param break_values: indexable sequence of at least 9 numeric breakpoints.
    :return: list of ``8 * (split_num // 8)`` ints, rounded with Python ``round`` (halves to even).
    :raises ZeroDivisionError: if ``split_num // 8 == 0``.
    :raises IndexError: if ``break_values`` has fewer than 9 entries.
    """
    per_segment = split_num // 8
    samples = []
    for seg in range(8):
        start = break_values[seg]
        stop = break_values[seg + 1]
        step = float(stop - start) / float(per_segment)
        samples.extend(int(round(start + step * k)) for k in range(per_segment))
    return samples
def get_colors(color_num):
    """Build a jet-like list of RGB colours, used to colour strokes by drawing order.

    Following the breakpoints below, the ramp goes dark blue -> blue -> cyan -> light green
    -> yellow -> orange -> red, and ends towards dark red. Each channel is resampled by
    ``rgb_trans``.

    :param color_num: minimum number of colours required.
    :return: list of ``(r, g, b)`` int tuples in [0, 255], of length
        ``(color_num // 8 + 1) * 8``, which is always strictly greater than ``color_num``.
    """
    n_samples = (color_num // 8 + 1) * 8
    red, green, blue = (
        rgb_trans(n_samples, breaks)
        for breaks in (
            [0, 0, 0, 0, 128, 255, 255, 255, 128],
            [0, 0, 128, 255, 255, 255, 128, 0, 0],
            [128, 255, 255, 255, 128, 0, 0, 0, 0],
        )
    )
    assert len(red) == len(green)
    assert len(blue) == len(green)
    return list(zip(red, green, blue))
#############################################
# Utils for testing
#############################################
def save_seq_data(save_root, save_filename, strokes_data, init_cursors, image_size, round_length, init_width):
    """Store the inputs needed to replay a stroke sequence in ``<save_root>/seq_data/<save_filename>.npz``.

    The ``seq_data`` directory is created if missing. Each argument after ``save_filename``
    is stored under its own parameter name as an array in the archive.
    """
    out_dir = os.path.join(save_root, 'seq_data')
    os.makedirs(out_dir, exist_ok=True)
    payload = dict(strokes_data=strokes_data, init_cursors=init_cursors,
                   image_size=image_size, round_length=round_length, init_width=init_width)
    np.savez(os.path.join(out_dir, save_filename + '.npz'), **payload)
def image_pasting_v3_testing(patch_image, cursor, image_size, window_size_f, pasting_func, sess):
    """Paste a local stroke patch into a full-size canvas by evaluating the pasting graph.

    :param patch_image: (raster_size, raster_size), [0.0-BG, 1.0-stroke]
    :param cursor: (2), in size [0.0, 1.0)
    :param image_size: side length of the output canvas; the cursor is multiplied by it.
    :param window_size_f: (), float32, [0.0, image_size)
    :param pasting_func: object holding the feedable tensors ``patch_canvas``,
        ``cursor_pos_a``, ``image_size_a``, ``window_size_a`` and the output ``pasted_image``.
    :param sess: session used to evaluate ``pasting_func.pasted_image``.
    :return: (image_size, image_size), [0.0-BG, 1.0-stroke]
    """
    feed = {
        pasting_func.patch_canvas: np.expand_dims(patch_image, axis=-1),  # add a channel axis
        pasting_func.cursor_pos_a: cursor * float(image_size),  # normalised -> pixel position
        pasting_func.image_size_a: image_size,
        pasting_func.window_size_a: window_size_f,
    }
    pasted = sess.run(pasting_func.pasted_image, feed_dict=feed)  # (image_size, image_size, 1)
    return pasted[:, :, 0]
def draw_strokes(data, save_root, save_filename, input_img, image_size, init_cursor, infer_lengths, init_width,
                 cursor_type, raster_size, min_window_size,
                 sess,
                 pasting_func=None,
                 save_seq=False, draw_order=False):
    """Render a predicted stroke sequence onto a full-size canvas and save it as PNG(s).

    Strokes are consumed round by round. Round ``r`` starts with the cursor at
    ``init_cursor[r]`` and uses the next ``infer_lengths[r]`` rows of ``data``. Each stroke
    is rasterised by ``draw`` into a ``raster_size`` patch and pasted into the canvas
    around the cursor by ``image_pasting_v3_testing``, using the current window size. The
    cursor then moves to the stroke's end point.

    :param data: 2-D array with one row per stroke. Column 0 is the pen state (only rows with
        pen state 0 are drawn on the sketch canvas). Columns 1-6 are
        ``x1, y1, x2, y2, width2, scaling2``.
    :param save_root: output directory (created if missing).
    :param save_filename: file name of the main sketch PNG; reused for the draw-order outputs.
    :param input_img: image shown in the "Input" panel of the draw-order comparison figure.
    :param image_size: side length of the square output canvas in pixels.
    :param init_cursor: per-round start cursor positions, normalised by ``image_size``.
    :param infer_lengths: number of strokes in each round.
    :param init_width: start width of the first stroke of each round.
    :param cursor_type: cursor update rule; only ``'next'`` is supported.
    :param raster_size: side length of the local stroke patch; also the initial window size.
    :param min_window_size: lower bound for the window size.
    :param sess: session forwarded to ``image_pasting_v3_testing``.
    :param pasting_func: pasting graph handle forwarded to ``image_pasting_v3_testing``.
    :param save_seq: if True, also save the canvas after every stroke to
        ``<save_root>/seq/<save_filename without extension>/<i>.png``.
    :param draw_order: if True, also save a stroke-order colouring to ``<save_root>/order``
        and a comparison figure to ``<save_root>/order-compare``.
    :raises Exception: if ``cursor_type`` is not ``'next'`` (raised while processing the first stroke).
    """
    canvas = np.zeros((image_size, image_size), dtype=np.float32)  # [0.0-BG, 1.0-stroke]
    # Colour canvases are stored inverted (0.0 = white) and flipped back when saved.
    canvas_color = np.zeros((image_size, image_size, 3), dtype=np.float32)
    canvas_color_with_moving = np.zeros((image_size, image_size, 3), dtype=np.float32)
    frames = []  # per-stroke snapshots of `canvas`; only collected when save_seq is set

    color_rgb_set = get_colors(len(data))  # (r, g, b) tuples in [0, 255], more than len(data)
    # Render each stroke patch at raster_size instead of draw()'s 128 default, matching the
    # (raster_size, raster_size) patch the pasting step expects.
    patch_width = int(raster_size)

    round_start = 0  # row in `data` of the first stroke of the current round
    for round_idx, round_length in enumerate(infer_lengths):
        cursor_pos = init_cursor[round_idx]  # (2), normalised by image_size
        prev_width = init_width
        prev_scaling = 1.0
        prev_window_size = raster_size

        for stroke_idx in range(round_start, round_start + round_length):
            # Window (in canvas pixels) this stroke is drawn in, clamped to [min_window_size, image_size].
            curr_window_size = prev_scaling * prev_window_size
            curr_window_size = np.maximum(curr_window_size, min_window_size)
            curr_window_size = np.minimum(curr_window_size, image_size)

            pen_state = data[stroke_idx, 0]
            stroke_params = data[stroke_idx, 1:]  # x1, y1, x2, y2, width2, scaling2
            x1y1, x2y2 = stroke_params[0:2], stroke_params[2:4]
            width2, scaling2 = stroke_params[4], stroke_params[5]

            # The stroke starts at the window centre (0 in [-1, 1]); map start/end to [0, 1] for draw().
            x0y0 = np.divide(np.add(np.zeros_like(x2y2), 1.0), 2.0)
            x2y2_unit = np.divide(np.add(x2y2, 1.0), 2.0)
            widths = np.stack([prev_width, width2], axis=0)  # (2)
            stroke_params_proc = np.concatenate([x0y0, x1y1, x2y2_unit, widths], axis=-1)  # (8)

            # Window of the next stroke. This stroke's end width is rescaled from the current
            # window into the next one and becomes the next stroke's start width.
            next_window_size = scaling2 * curr_window_size
            next_window_size = np.maximum(next_window_size, min_window_size)
            next_window_size = np.minimum(next_window_size, image_size)
            prev_width = width2 * curr_window_size / next_window_size
            prev_scaling = scaling2
            prev_window_size = curr_window_size

            f = stroke_params_proc.tolist() + [1.0, 1.0]  # unit intensity at both ends
            gt_stroke_img = draw(f, width=patch_width)  # (raster_size, raster_size), [0.0-stroke, 1.0-BG]
            gt_stroke_img_large = image_pasting_v3_testing(1.0 - gt_stroke_img, cursor_pos, image_size,
                                                           curr_window_size,
                                                           pasting_func, sess)  # [0.0-BG, 1.0-stroke]

            if pen_state == 0:
                canvas += gt_stroke_img_large  # may exceed 1.0 where strokes overlap

            if draw_order:
                # One colour per stroke, in drawing order.
                color_rgb = np.reshape(color_rgb_set[stroke_idx], (1, 1, 3)).astype(np.float32)
                stroke_mask = np.expand_dims(gt_stroke_img_large, axis=-1)  # (H, W, 1)
                color_stroke = stroke_mask * (1.0 - color_rgb / 255.0)
                canvas_color_with_moving = canvas_color_with_moving * (1.0 - stroke_mask) + color_stroke
                if pen_state == 0:
                    canvas_color = canvas_color * (1.0 - stroke_mask) + color_stroke

            # Move the cursor to the stroke's end point. The raw (x2, y2) are in the window's
            # [-1, 1] frame, so half the window size converts them to canvas pixels.
            new_cursor_offset = stroke_params[2:4] * (curr_window_size / 2.0)
            # important!!! swap the two components to match the axis order of cursor_pos
            new_cursor_offset = np.concatenate([new_cursor_offset[1:2], new_cursor_offset[0:1]], axis=-1)
            if cursor_type != 'next':
                raise Exception('Unknown cursor_type')
            cursor_pos_large = cursor_pos * float(image_size) + new_cursor_offset
            cursor_pos_large = np.minimum(np.maximum(cursor_pos_large, 0.0), float(image_size - 1))
            cursor_pos = cursor_pos_large / float(image_size)

            if save_seq:
                frames.append(canvas.copy())

        round_start += round_length

    canvas = np.clip(canvas, 0.0, 1.0)
    canvas = np.round((1.0 - canvas) * 255.0).astype(np.uint8)  # [0-stroke, 255-BG]

    os.makedirs(save_root, exist_ok=True)
    canvas_img = Image.fromarray(canvas, 'L')
    canvas_img.save(os.path.join(save_root, save_filename), 'PNG')

    if save_seq:
        # splitext instead of [:-4] so names whose extension is not 4 characters are not truncated.
        seq_save_root = os.path.join(save_root, 'seq', os.path.splitext(save_filename)[0])
        os.makedirs(seq_save_root, exist_ok=True)
        for len_i, frame in enumerate(frames):
            # Frames hold the unclipped running canvas. Clip like the final canvas so values
            # above 1.0 do not become negative and produce out-of-range uint8 values.
            frame = np.round((1.0 - np.clip(frame, 0.0, 1.0)) * 255.0).astype(np.uint8)
            frame_img = Image.fromarray(frame, 'L')
            frame_img.save(os.path.join(seq_save_root, str(len_i) + '.png'), 'PNG')

    if draw_order:
        order_save_root = os.path.join(save_root, 'order')
        order_comp_save_root = os.path.join(save_root, 'order-compare')
        os.makedirs(order_save_root, exist_ok=True)
        os.makedirs(order_comp_save_root, exist_ok=True)

        # Flip the inverted colour canvases back to a white background, as uint8 RGB.
        canvas_color = 255 - np.round(canvas_color * 255.0).astype(np.uint8)
        canvas_color_img = Image.fromarray(canvas_color, 'RGB')
        canvas_color_img.save(os.path.join(order_save_root, save_filename), 'PNG')
        canvas_color_with_moving = 255 - np.round(canvas_color_with_moving * 255.0).astype(np.uint8)

        # Comparison figure on a 2x3 grid; panel 3 is intentionally left empty.
        rows, cols = 2, 3
        fig = plt.figure(figsize=(5 * cols, 5 * rows))
        try:
            plt.subplot(rows, cols, 1)
            plt.title('Input', fontsize=12)
            plt.imshow(input_img)

            plt.subplot(rows, cols, 2)
            plt.title('Sketch', fontsize=12)
            plt.imshow(np.stack([canvas] * 3, axis=2))

            plt.subplot(rows, cols, 4)
            plt.title('Sketch Order', fontsize=12)
            plt.imshow(canvas_color)

            plt.subplot(rows, cols, 5)
            plt.title('Sketch Order with moving', fontsize=12)
            plt.imshow(canvas_color_with_moving)

            # Colour legend: one 5x10 px swatch per colour, first stroke's colour at the top.
            plt.subplot(rows, cols, 6)
            plt.title('Order', fontsize=12)
            plt.axis('off')
            img_h, img_w = 5, 10
            color_array = np.zeros([len(color_rgb_set) * img_h, img_w, 3], dtype=np.uint8)
            for i, color in enumerate(color_rgb_set):
                color_array[i * img_h:(i + 1) * img_h, :, :] = color
            plt.imshow(color_array)

            plt.savefig(os.path.join(order_comp_save_root, save_filename))
        finally:
            plt.close(fig)
def update_hyperparams(model_params, model_base_dir, model_name, infer_dataset):
    """Overwrite ``model_params`` with the training config saved next to a model.

    Reads ``<model_base_dir>/<model_name>/model_config.json`` and checks that it covers every
    hyper-parameter declared by ``model_params``. It then adapts a few entries for inference
    and parses the result into ``model_params``.

    :param model_params: hyper-parameter object exposing ``_hparam_types`` and ``parse_json``
        (presumably a ``tf.contrib.training.HParams`` -- TODO confirm).
    :param model_base_dir: directory containing the model directories.
    :param model_name: name of the model sub-directory holding ``model_config.json``.
    :param infer_dataset: value stored in the ``data_set`` entry.
    :return: ``model_params`` after ``parse_json`` has been applied.
    :raises KeyError: if a declared hyper-parameter (other than the ignored ones) is missing
        from the config, or if ``resize_method`` or a dropout flag entry is absent.
    :raises ValueError: if the config's ``resize_method`` is not ``'AREA'``.
    """
    config_path = os.path.join(model_base_dir, model_name, 'model_config.json')
    with tf.gfile.Open(config_path, 'r') as f:
        data = json.load(f)

    # Declared hyper-parameters that are allowed to be absent from the saved config.
    ignored_keys = {'image_size_small', 'image_size_large', 'z_size', 'raster_perc_loss_layer', 'raster_loss_wk',
                    'decreasing_sn', 'raster_loss_weight'}
    for name in model_params._hparam_types:
        if name not in data and name not in ignored_keys:
            # KeyError subclasses Exception, so existing `except Exception` handlers still apply.
            raise KeyError(name, 'not in model_config.json')

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    resize_method = data['resize_method']
    if resize_method != 'AREA':
        raise ValueError('resize_method must be AREA, got %r' % (resize_method,))

    data['data_set'] = infer_dataset

    # Compare against 1 so that numeric (0/1) flags become booleans.
    for fix in ('use_input_dropout', 'use_output_dropout', 'use_recurrent_dropout'):
        data[fix] = (data[fix] == 1)

    # Config entries removed before parsing; absent keys are ignored.
    pop_keys = ['gpus', 'image_size', 'resolution_type', 'loop_per_gpu', 'stroke_num_loss_weight_end',
                'perc_loss_fuse_type',
                'early_pen_length', 'early_pen_loss_type', 'early_pen_loss_weight',
                'increase_start_steps', 'perc_loss_layers', 'sn_loss_type', 'photo_prob_end_step',
                'sup_weight', 'gan_weight', 'base_raster_loss_base_type']
    for pop_key in pop_keys:
        data.pop(pop_key, None)

    model_params.parse_json(json.dumps(data))
    return model_params