https://github.com/MarkMoHR/virtual_sketching
Raw File
Tip revision: 958efe45a9120b9d467ba7701efba28c11e38f8f authored by Your Name on 21 August 2021, 08:13 UTC
Added bilibili links
Tip revision: 958efe4
utils.py
import os
import cv2
import json
import numpy as np
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt


#############################################
# Tensorflow utils
#############################################

def reset_graph():
    """Close the default TensorFlow session (if one is active) and start from an empty default graph."""
    active_sess = tf.get_default_session()
    if active_sess:
        # Release the session's resources before the graph it was built on is discarded.
        active_sess.close()
    tf.reset_default_graph()


def _global_vars_excluding(excluded_tokens):
    """Map op name -> variable for every global variable whose op name contains none of `excluded_tokens`."""
    return {var.op.name: var for var in tf.global_variables()
            if not any(token in var.op.name for token in excluded_tokens)}


def load_checkpoint(sess, checkpoint_path, ras_only=False, perceptual_only=False, gen_model_pretrain=False,
                    train_entire=False):
    """
    Restore a subset of the global variables from a checkpoint into `sess`.

    The flags are checked in this priority order; the first one that is True decides which variables
    are restored:

    * ras_only: only variables whose op name contains 'raster_unit'. In this mode `checkpoint_path`
      is used directly as the checkpoint file prefix.
    * perceptual_only: only variables whose op name contains 'VGG16'.
    * train_entire: everything except discriminator / raster_unit / VGG16 / beta1 / beta2 /
      global_step / Entire variables.
    * gen_model_pretrain: everything except discriminator / raster_unit / VGG16 / beta1 / beta2
      variables (global_step is restored).
    * none of the above: all global variables.

    :param sess: session the variables are restored into
    :param checkpoint_path: checkpoint file prefix when `ras_only`, otherwise a checkpoint directory
    :return: the text after the last '-' of the restored checkpoint path (the whole path if it has no '-')
    :raises FileNotFoundError: if `checkpoint_path` is a directory without a usable checkpoint state
    """
    base_excluded = ('discriminator', 'raster_unit', 'VGG16', 'beta1', 'beta2')

    if ras_only:
        load_var = {var.op.name: var for var in tf.global_variables() if 'raster_unit' in var.op.name}
    elif perceptual_only:
        load_var = {var.op.name: var for var in tf.global_variables() if 'VGG16' in var.op.name}
    elif train_entire:
        load_var = _global_vars_excluding(base_excluded + ('global_step', 'Entire'))
    elif gen_model_pretrain:
        load_var = _global_vars_excluding(base_excluded)
    else:
        load_var = tf.global_variables()

    restorer = tf.train.Saver(load_var)
    if ras_only:
        model_checkpoint_path = checkpoint_path
    else:
        ckpt = tf.train.get_checkpoint_state(checkpoint_path)
        # get_checkpoint_state returns None when the directory holds no checkpoint state; fail with
        # a clear message instead of an AttributeError on None.
        if ckpt is None or not ckpt.model_checkpoint_path:
            raise FileNotFoundError('No checkpoint found in %s' % checkpoint_path)
        model_checkpoint_path = ckpt.model_checkpoint_path
    print('Loading model %s' % model_checkpoint_path)
    restorer.restore(sess, model_checkpoint_path)

    # Checkpoints saved with a global step are named '<prefix>-<step>'.
    snapshot_step = model_checkpoint_path[model_checkpoint_path.rfind('-') + 1:]
    return snapshot_step


def create_summary(summary_writer, summ_map, step):
    """
    Write every entry of `summ_map` as a scalar summary at `step`, then flush the writer.

    :param summary_writer: writer providing `add_summary` and `flush`
    :param summ_map: mapping from summary tag to a value convertible with `float()`
    :param step: step the summaries are recorded at
    """
    for tag, raw_value in summ_map.items():
        scalar_summary = tf.summary.Summary()
        scalar_summary.value.add(tag=tag, simple_value=float(raw_value))
        summary_writer.add_summary(scalar_summary, step)
    summary_writer.flush()


def save_model(sess, saver, model_save_path, global_step):
    """
    Save the session's variables under `<model_save_path>/p2s` tagged with `global_step`.

    :param sess: session handed to `saver.save`
    :param saver: object providing `save(sess, path, global_step=...)`
    :param model_save_path: directory the checkpoint prefix 'p2s' is placed in
    :param global_step: integer step, printed and forwarded to the saver
    """
    ckpt_prefix = os.path.join(model_save_path, 'p2s')
    print('saving model %s.' % ckpt_prefix)
    print('global_step %i.' % global_step)
    saver.save(sess, ckpt_prefix, global_step=global_step)


#############################################
# Utils for basic image processing
#############################################


def normal(x, width):
    """
    Scale `x` by `width - 1` and convert it to an int after adding 0.5.

    For non-negative results this rounds half up, e.g. x=0.0 -> 0 and x=1.0 -> width - 1.
    """
    scaled = x * (width - 1)
    return int(scaled + 0.5)


def draw(f, width=128):
    """
    Rasterize one quadratic Bezier stroke as a sequence of filled circles.

    :param f: 10 values (x0, y0, x1, y1, x2, y2, z0, z2, w0, w2): start point, control point given
              relative to the start->end segment, end point, radius factors at both ends and the
              circle values at both ends
    :param width: side length of the returned square image; drawing happens at twice this size
    :return: (width, width) float32 image computed as `1 - canvas`, so drawn pixels are the low values
    """
    x0, y0, x1, y1, x2, y2, z0, z2, w0, w2 = f

    # The control point is expressed as a fraction of the way from the start to the end point.
    x1 = x0 + (x2 - x0) * x1
    y1 = y0 + (y2 - y0) * y1

    # Draw on a canvas twice as large and shrink it at the end.
    big = width * 2
    x0, x1, x2 = normal(x0, big), normal(x1, big), normal(x2, big)
    y0, y1, y2 = normal(y0, big), normal(y1, big), normal(y2, big)
    z0 = int(1 + z0 * width // 2)
    z2 = int(1 + z2 * width // 2)

    canvas = np.zeros((big, big), dtype=np.float32)
    step = 1. / 100
    for i in range(100):
        t = i * step
        s = 1 - t
        px = int(s * s * x0 + 2 * t * s * x1 + t * t * x2)
        py = int(s * s * y0 + 2 * t * s * y1 + t * t * y2)
        radius = int(s * z0 + t * z2)
        value = s * w0 + t * w2
        # thickness -1 asks OpenCV for a filled circle; note the center is passed as (y, x).
        cv2.circle(canvas, (py, px), radius, value, -1)
    return 1 - cv2.resize(canvas, dsize=(width, width))


def rgb_trans(split_num, break_values):
    """
    Linearly interpolate integer channel values between consecutive break values.

    The first 9 break values define 8 segments; each segment contributes `split_num // 8` samples,
    starting at the segment's first break value and stopping before its last one.

    :param split_num: total number of samples requested (expected to be a multiple of 8)
    :param break_values: sequence of at least 9 channel values
    :return: list of `8 * (split_num // 8)` ints
    """
    samples_per_segment = split_num // 8
    segment_starts = break_values[:-1]
    segment_ends = break_values[1:]

    samples = []
    for segment in range(8):
        start = segment_starts[segment]
        end = segment_ends[segment]
        step = float(end - start) / float(samples_per_segment)
        samples.extend(int(round(start + step * k)) for k in range(samples_per_segment))
    return samples


def get_colors(color_num):
    """
    Build a color ramp with more than `color_num` entries.

    The ramp length is `color_num` rounded up to the next multiple of 8 (a full extra block of 8 is
    added when `color_num` is already a multiple of 8).

    :param color_num: number of colors needed
    :return: list of (r, g, b) tuples of ints
    """
    split_num = (color_num // 8 + 1) * 8

    # Per-channel break values; rgb_trans interpolates between consecutive entries.
    channel_breaks = (
        [0, 0, 0, 0, 128, 255, 255, 255, 128],  # red
        [0, 0, 128, 255, 255, 255, 128, 0, 0],  # green
        [128, 255, 255, 255, 128, 0, 0, 0, 0],  # blue
    )
    reds, greens, blues = [rgb_trans(split_num, breaks) for breaks in channel_breaks]

    assert len(reds) == len(greens)
    assert len(blues) == len(greens)

    return list(zip(reds, greens, blues))


#############################################
# Utils for testing
#############################################

def save_seq_data(save_root, save_filename, strokes_data, init_cursors, image_size, round_length, init_width):
    """
    Store a stroke sequence and its drawing settings in `<save_root>/seq_data/<save_filename>.npz`.

    The archive holds the arrays under the keys 'strokes_data', 'init_cursors', 'image_size',
    'round_length' and 'init_width'. The output directory is created if it does not exist.
    """
    out_dir = os.path.join(save_root, 'seq_data')
    os.makedirs(out_dir, exist_ok=True)

    payload = dict(strokes_data=strokes_data, init_cursors=init_cursors,
                   image_size=image_size, round_length=round_length, init_width=init_width)
    np.savez(os.path.join(out_dir, save_filename + '.npz'), **payload)


def image_pasting_v3_testing(patch_image, cursor, image_size, window_size_f, pasting_func, sess):
    """
    Paste a stroke patch onto a full-size canvas by running the pasting graph.

    :param patch_image: (raster_size, raster_size), [0.0-BG, 1.0-stroke]
    :param cursor: (2), in size [0.0, 1.0)
    :param image_size: side length of the full canvas; also used to scale `cursor` to pixels
    :param window_size_f: (), float32, [0.0, image_size)
    :param pasting_func: object exposing the placeholders `patch_canvas`, `cursor_pos_a`,
                         `image_size_a`, `window_size_a` and the output `pasted_image`
    :param sess: session used to run the pasting graph
    :return: (image_size, image_size), [0.0-BG, 1.0-stroke]
    """
    cursor_in_pixels = cursor * float(image_size)
    output = pasting_func.pasted_image
    feeds = {
        pasting_func.patch_canvas: np.expand_dims(patch_image, axis=-1),  # add a channel axis
        pasting_func.cursor_pos_a: cursor_in_pixels,
        pasting_func.image_size_a: image_size,
        pasting_func.window_size_a: window_size_f,
    }
    # (image_size, image_size, 1), [0.0-BG, 1.0-stroke]
    pasted = sess.run(output, feed_dict=feeds)
    # Drop the channel axis again.
    return pasted[:, :, 0]


def _canvas_to_uint8(canvas):
    """Convert a float canvas [0.0-BG, 1.0-stroke] into a uint8 image [0-stroke, 255-BG]."""
    # Strokes are accumulated additively, so overlapping strokes can push values above 1.0.
    # Clip first: otherwise (1.0 - canvas) becomes negative and the uint8 cast yields wrong pixels.
    clipped = np.clip(canvas, 0.0, 1.0)
    return np.round((1.0 - clipped) * 255.0).astype(np.uint8)


def draw_strokes(data, save_root, save_filename, input_img, image_size, init_cursor, infer_lengths, init_width,
                 cursor_type, raster_size, min_window_size,
                 sess,
                 pasting_func=None,
                 save_seq=False, draw_order=False):
    """
    Render a stroke sequence onto a full-size canvas and save it as a PNG under `save_root`.

    :param data: per-stroke rows indexed as [pen_state, x1, y1, x2, y2, width, scaling];
                 a stroke is drawn on the canvas only when its pen_state is 0
    :param save_root: output directory (created if missing)
    :param save_filename: file name of the saved sketch
    :param input_img: image shown in the 'Input' panel of the comparison figure (only used if `draw_order`)
    :param image_size: side length of the square output canvas
    :param init_cursor: starting cursor of each drawing round, indexed by round; it is multiplied by
                        `image_size`, so presumably normalized to [0.0, 1.0)
    :param infer_lengths: number of strokes in each drawing round
    :param init_width: stroke width at the start of every round
    :param cursor_type: only 'next' is supported
    :param raster_size: window size at the start of every round
    :param min_window_size: lower bound of the window size (the upper bound is `image_size`)
    :param sess: session forwarded to `image_pasting_v3_testing`
    :param pasting_func: pasting graph forwarded to `image_pasting_v3_testing`
    :param save_seq: if True, also save the canvas after every stroke to `<save_root>/seq/<name>/<i>.png`
    :param draw_order: if True, also save a color-coded drawing-order image to `<save_root>/order/` and
                       a comparison figure to `<save_root>/order-compare/`
    :return: None
    :raises Exception: if `cursor_type` is not 'next' (raised while processing the first stroke)
    """
    canvas = np.zeros((image_size, image_size), dtype=np.float32)  # [0.0-BG, 1.0-stroke]
    canvas_color = np.zeros((image_size, image_size, 3), dtype=np.float32)
    canvas_color_with_moving = np.zeros((image_size, image_size, 3), dtype=np.float32)
    frames = []  # canvas snapshots after each stroke; only collected when save_seq is set

    stroke_count = len(data)
    color_rgb_set = get_colors(stroke_count)  # list of (3,) in [0, 255]
    color_idx = 0

    round_start = 0  # index in `data` of the first stroke of the current round
    for round_idx, round_length in enumerate(infer_lengths):
        cursor_pos = init_cursor[round_idx]  # (2)

        prev_width = init_width
        prev_scaling = 1.0
        prev_window_size = raster_size  # (1)

        for round_inner_i in range(round_length):
            stroke_idx = round_start + round_inner_i

            curr_window_size = prev_scaling * prev_window_size
            curr_window_size = np.maximum(curr_window_size, min_window_size)
            curr_window_size = np.minimum(curr_window_size, image_size)

            pen_state = data[stroke_idx, 0]
            stroke_params = data[stroke_idx, 1:]  # x1, y1, x2, y2, width, scaling

            x1y1, x2y2, width2, scaling2 = stroke_params[0:2], stroke_params[2:4], stroke_params[4], stroke_params[5]
            # Every stroke starts at the window center: 0.0 in [-1.0, 1.0] maps to 0.5 in [0.0, 1.0].
            x0y0 = np.zeros_like(x2y2)  # (2), [-1.0, 1.0]
            x0y0 = np.divide(np.add(x0y0, 1.0), 2.0)  # (2), [0.0, 1.0]
            x2y2 = np.divide(np.add(x2y2, 1.0), 2.0)  # (2), [0.0, 1.0]
            widths = np.stack([prev_width, width2], axis=0)  # (2)
            stroke_params_proc = np.concatenate([x0y0, x1y1, x2y2, widths], axis=-1)  # (8)

            next_window_size = scaling2 * curr_window_size
            next_window_size = np.maximum(next_window_size, min_window_size)
            next_window_size = np.minimum(next_window_size, image_size)

            # Re-express the end width of this stroke relative to the next stroke's window.
            prev_width = width2 * curr_window_size / next_window_size
            prev_scaling = scaling2
            prev_window_size = curr_window_size

            f = stroke_params_proc.tolist()  # (8)
            f += [1.0, 1.0]  # (10)
            gt_stroke_img = draw(f)  # patch rendered at draw()'s default width, [0.0-stroke, 1.0-BG]
            gt_stroke_img_large = image_pasting_v3_testing(1.0 - gt_stroke_img, cursor_pos, image_size,
                                                            curr_window_size,
                                                            pasting_func, sess)  # [0.0-BG, 1.0-stroke]

            if pen_state == 0:
                canvas += gt_stroke_img_large  # [0.0-BG, 1.0-stroke]

            if draw_order:
                color_rgb = color_rgb_set[color_idx]  # (3) in [0, 255]
                color_idx += 1

                color_rgb = np.reshape(color_rgb, (1, 1, 3)).astype(np.float32)
                stroke_mask = np.expand_dims(gt_stroke_img_large, axis=-1)  # (H, W, 1)
                color_stroke = stroke_mask * (1.0 - color_rgb / 255.0)
                canvas_color_with_moving = canvas_color_with_moving * (1.0 - stroke_mask) + color_stroke  # (H, W, 3)

                if pen_state == 0:
                    canvas_color = canvas_color * (1.0 - stroke_mask) + color_stroke  # (H, W, 3)

            # update cursor_pos based on hps.cursor_type
            new_cursor_offsets = stroke_params[2:4] * (curr_window_size / 2.0)  # (2), patch-level

            # important!!! swap the two components of the offset before adding it to the cursor
            new_cursor_offset_next = np.concatenate([new_cursor_offsets[1:2], new_cursor_offsets[0:1]], axis=-1)

            cursor_pos_large = cursor_pos * float(image_size)

            stroke_position_next = cursor_pos_large + new_cursor_offset_next  # (2), large-level

            if cursor_type == 'next':
                cursor_pos_large = stroke_position_next  # (2), large-level
            else:
                raise Exception('Unknown cursor_type')

            cursor_pos_large = np.minimum(np.maximum(cursor_pos_large, 0.0), float(image_size - 1))  # (2), large-level
            cursor_pos = cursor_pos_large / float(image_size)

            if save_seq:
                frames.append(canvas.copy())

        round_start += round_length

    canvas = _canvas_to_uint8(canvas)  # [0-stroke, 255-BG]

    os.makedirs(save_root, exist_ok=True)
    save_path = os.path.join(save_root, save_filename)
    canvas_img = Image.fromarray(canvas, 'L')
    canvas_img.save(save_path, 'PNG')

    if save_seq:
        seq_save_root = os.path.join(save_root, 'seq', os.path.splitext(save_filename)[0])
        os.makedirs(seq_save_root, exist_ok=True)
        for len_i, frame in enumerate(frames):
            # Frames get the same clipping as the final canvas.
            frame_img = Image.fromarray(_canvas_to_uint8(frame), 'L')
            frame_img.save(os.path.join(seq_save_root, str(len_i) + '.png'), 'PNG')

    if draw_order:
        order_save_root = os.path.join(save_root, 'order')
        order_comp_save_root = os.path.join(save_root, 'order-compare')
        os.makedirs(order_save_root, exist_ok=True)
        os.makedirs(order_comp_save_root, exist_ok=True)

        canvas_color = 255 - np.round(canvas_color * 255.0).astype(np.uint8)
        canvas_color_img = Image.fromarray(canvas_color, 'RGB')
        save_path = os.path.join(order_save_root, save_filename)
        canvas_color_img.save(save_path, 'PNG')

        canvas_color_with_moving = 255 - np.round(canvas_color_with_moving * 255.0).astype(np.uint8)

        # comparisons: a 2 x 3 grid whose third cell stays empty
        rows = 2
        cols = 3
        plt.figure(figsize=(5 * cols, 5 * rows))

        plt.subplot(rows, cols, 1)
        plt.title('Input', fontsize=12)
        input_rgb = input_img
        plt.imshow(input_rgb)

        plt.subplot(rows, cols, 2)
        plt.title('Sketch', fontsize=12)
        canvas_rgb = np.stack([canvas for _ in range(3)], axis=2)
        plt.imshow(canvas_rgb)

        plt.subplot(rows, cols, 4)
        plt.title('Sketch Order', fontsize=12)
        plt.imshow(canvas_color)

        plt.subplot(rows, cols, 5)
        plt.title('Sketch Order with moving', fontsize=12)
        plt.imshow(canvas_color_with_moving)

        plt.subplot(rows, cols, 6)
        plt.title('Order', fontsize=12)
        plt.axis('off')

        # Color legend: one horizontal band per palette entry, in palette order from top to bottom.
        img_h = 5
        img_w = 10
        color_array = np.zeros([len(color_rgb_set) * img_h, img_w, 3], dtype=np.uint8)
        for i in range(len(color_rgb_set)):
            color_array[i * img_h: i * img_h + img_h, :, :] = color_rgb_set[i]

        plt.imshow(color_array)

        comp_save_path = os.path.join(order_comp_save_root, save_filename)
        plt.savefig(comp_save_path)
        plt.close()


def update_hyperparams(model_params, model_base_dir, model_name, infer_dataset):
    """
    Overwrite `model_params` with the values stored in `<model_base_dir>/<model_name>/model_config.json`.

    Steps, in order: verify that every hyper-parameter of `model_params` is present in the config
    (a few names are exempt), require the 'AREA' resize method, point 'data_set' at `infer_dataset`,
    turn the three dropout flags into booleans, drop config keys listed in `pop_keys`, and parse the
    remaining config into `model_params`.

    :param model_params: hyper-parameter object exposing `_hparam_types` and `parse_json`
    :param model_base_dir: directory containing one sub-directory per model
    :param model_name: name of the model's sub-directory
    :param infer_dataset: value stored as 'data_set'
    :return: the updated `model_params`
    :raises Exception: if a non-exempt hyper-parameter is missing from the config
    """
    config_path = os.path.join(model_base_dir, model_name, 'model_config.json')
    with tf.gfile.Open(config_path, 'r') as f:
        data = json.load(f)

    # Hyper-parameters that are allowed to be absent from the stored config.
    ignored_keys = ['image_size_small', 'image_size_large', 'z_size', 'raster_perc_loss_layer', 'raster_loss_wk',
                    'decreasing_sn', 'raster_loss_weight']
    for name in model_params._hparam_types:
        if name in data or name in ignored_keys:
            continue
        raise Exception(name, 'not in model_config.json')

    assert data['resize_method'] == 'AREA'
    data['data_set'] = infer_dataset

    # The dropout switches are compared against 1 to obtain real booleans.
    for flag_name in ('use_input_dropout', 'use_output_dropout', 'use_recurrent_dropout'):
        data[flag_name] = (data[flag_name] == 1)

    # Config entries that are removed before parsing.
    pop_keys = ['gpus', 'image_size', 'resolution_type', 'loop_per_gpu', 'stroke_num_loss_weight_end',
                'perc_loss_fuse_type',
                'early_pen_length', 'early_pen_loss_type', 'early_pen_loss_weight',
                'increase_start_steps', 'perc_loss_layers', 'sn_loss_type', 'photo_prob_end_step',
                'sup_weight', 'gan_weight', 'base_raster_loss_base_type']
    for stale_key in pop_keys:
        data.pop(stale_key, None)

    model_params.parse_json(json.dumps(data))

    return model_params
back to top