https://github.com/MarkMoHR/virtual_sketching
Raw File
Tip revision: 958efe45a9120b9d467ba7701efba28c11e38f8f authored by Your Name on 21 August 2021, 08:13:15 UTC
Added bilibili links
Tip revision: 958efe4
gif_making.py
import os
import sys
import argparse
import numpy as np
from PIL import Image
import tensorflow as tf

sys.path.append('./')
from utils import draw, image_pasting_v3_testing
from model_common_test import DiffPastingV3

# Default to the first GPU, but respect a device selection already made via the environment
# (a hard assignment here would silently override e.g. `CUDA_VISIBLE_DEVICES=1 python ...`).
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')


def add_scaling_visualization(canvas_images, cursor, window_size, image_size):
    """
    Draw the cursor marker and the current window box (in red) on a batch of canvases.

    :param canvas_images: (N, H, W, 3) float canvases, [0.0-BG, 1.0-stroke]
    :param cursor: (2) cursor position normalized by `image_size`; cursor[0] is x (column),
                   cursor[1] is y (row)
    :param window_size: side length in pixels of the window box centred on the cursor
    :param image_size: side length in pixels of the canvas
    :return: (N, H, W, 3) uint8 images, inverted so the background is white (255) and strokes are dark
    """
    cursor_pos = cursor * float(image_size)
    cursor_x, cursor_y = int(round(cursor_pos[0])), int(round(cursor_pos[1]))  # in large size

    vis_color = [255, 0, 0]
    cursor_width = 3
    box_width = 2

    canvas_imgs = 255 - np.round(canvas_images * 255.0).astype(np.uint8)

    # Cursor marker. Clamp the start indices to 0: a negative slice start would wrap around to the
    # opposite edge and produce an empty slice, so a cursor near the top/left border would vanish.
    marker_top = max(0, cursor_y - cursor_width)
    marker_left = max(0, cursor_x - cursor_width)
    canvas_imgs[:, marker_top: cursor_y + cursor_width, marker_left: cursor_x + cursor_width, :] = vis_color

    # Window box, clipped to the canvas. Sides that coincide with the canvas border are not drawn.
    half_window = window_size // 2
    up = max(0, cursor_y - half_window)
    down = min(image_size, cursor_y + half_window)
    left = max(0, cursor_x - half_window)
    right = min(image_size, cursor_x + half_window)

    if up > 0:
        canvas_imgs[:, up: up + box_width, left: right, :] = vis_color
    if down < image_size:
        canvas_imgs[:, down - box_width: down, left: right, :] = vis_color
    if left > 0:
        canvas_imgs[:, up: down, left: left + box_width, :] = vis_color
    if right < image_size:
        canvas_imgs[:, up: down, right - box_width: right, :] = vis_color
    return canvas_imgs


def _canvas_to_frame(canvas, cursor_pos, window_size, image_size, add_box):
    """
    Convert the accumulated canvas into one RGB GIF frame.

    :param canvas: (H, W) float canvas, [0.0-BG, 1.0-stroke]; values are clipped to [0, 1]
    :param cursor_pos: (2) normalized cursor position, forwarded to add_scaling_visualization
    :param window_size: window side length in pixels, forwarded to add_scaling_visualization
    :param image_size: canvas side length in pixels
    :param add_box: whether to overlay the cursor marker and window box
    :return: PIL Image in 'RGB' mode with a white background and dark strokes
    """
    canvas_rgb = np.stack([np.clip(canvas, 0.0, 1.0)] * 3, axis=-1)  # (H, W, 3)

    if add_box:
        vis_outputs = add_scaling_visualization(np.expand_dims(canvas_rgb, axis=0),
                                                cursor_pos, window_size, image_size)
        canvas_vis = vis_outputs[0]
    else:
        # Apply the same inversion/quantization as add_scaling_visualization. Image.fromarray in
        # 'RGB' mode needs uint8 data; a raw float array would be misread as garbage bytes.
        canvas_vis = 255 - np.round(canvas_rgb * 255.0).astype(np.uint8)

    return Image.fromarray(canvas_vis, 'RGB')


def make_gif(sess, pasting_func, data, init_cursor, image_size, infer_lengths, init_width,
             save_base,
             cursor_type='next', min_window_size=32, raster_size=128, add_box=True):
    """
    Replay a stroke sequence one stroke at a time and save the animation as `<save_base>/dynamic.gif`.

    Strokes are grouped into rounds. Each round starts from its own initial cursor, with the window
    size reset to `raster_size` and the width reset to `init_width`. Each stroke is rasterized in the
    local window and pasted onto the full canvas, and one frame is recorded per stroke.

    :param sess: TF session, forwarded to image_pasting_v3_testing
    :param pasting_func: pasting graph, forwarded to image_pasting_v3_testing
    :param data: (N_strokes, >=7) stroke array; per row only these columns are read:
                 [0] pen state (the stroke is added to the canvas only when 0),
                 [1:3] x1, y1 (used as-is), [3:5] x2, y2 (presumably in [-1, 1]; mapped to [0, 1] below),
                 [5] end width, [6] scaling factor applied to the window for the next stroke
    :param init_cursor: (N_rounds, 2) array, or a single (2,) cursor; normalized start cursor per round
    :param image_size: side length in pixels of the full canvas
    :param infer_lengths: number of strokes in each round
    :param init_width: stroke width at the start of every round
    :param save_base: directory in which 'dynamic.gif' is written
    :param cursor_type: cursor update rule; only 'next' is supported
    :param min_window_size: lower bound on the window size in pixels
    :param raster_size: initial window size in pixels
    :param add_box: overlay the cursor marker and window box on each frame
    :raises Exception: if `cursor_type` is not 'next'
    :return: None
    """
    # Validate up front rather than after the first stroke has already been rendered.
    if cursor_type != 'next':
        raise Exception('Unknown cursor_type')

    canvas = np.zeros((image_size, image_size), dtype=np.float32)  # [0.0-BG, 1.0-stroke]
    gif_frames = []

    if init_cursor.ndim == 1:
        init_cursor = [init_cursor]

    stroke_offset = 0  # row of `data` holding the first stroke of the current round
    for round_idx, round_length in enumerate(infer_lengths):
        print('Making progress', round_idx + 1, '/', len(infer_lengths))

        cursor_pos = init_cursor[round_idx]  # (2)

        prev_width = init_width
        prev_scaling = 1.0
        prev_window_size = float(raster_size)

        for round_inner_i in range(round_length):
            stroke_idx = stroke_offset + round_inner_i

            # Window for this stroke: the previous window scaled, clamped to [min_window_size, image_size].
            curr_window_size_raw = prev_scaling * prev_window_size
            curr_window_size_raw = np.maximum(curr_window_size_raw, min_window_size)
            curr_window_size_raw = np.minimum(curr_window_size_raw, image_size)
            curr_window_size = int(round(curr_window_size_raw))  # ()

            pen_state = data[stroke_idx, 0]
            stroke_params = data[stroke_idx, 1:]

            x1y1 = stroke_params[0:2]
            width2 = stroke_params[4]
            # The start point is (0, 0) in [-1, 1], i.e. (0.5, 0.5) in [0, 1].
            x0y0 = np.divide(np.add(np.zeros_like(stroke_params[2:4]), 1.0), 2.0)  # (2), [0.0, 1.0]
            x2y2 = np.divide(np.add(stroke_params[2:4], 1.0), 2.0)  # (2), [0.0, 1.0]
            widths = np.stack([prev_width, width2], axis=0)  # (2)
            stroke_params_proc = np.concatenate([x0y0, x1y1, x2y2, widths], axis=-1)  # (8)

            # Carry state to the next stroke; the end width is rescaled into the next window's units.
            next_width = stroke_params[4]
            next_scaling = stroke_params[5]
            next_window_size = next_scaling * curr_window_size_raw
            next_window_size = np.maximum(next_window_size, min_window_size)
            next_window_size = np.minimum(next_window_size, image_size)

            prev_width = next_width * curr_window_size_raw / next_window_size
            prev_scaling = next_scaling
            prev_window_size = curr_window_size_raw

            f = stroke_params_proc.tolist() + [1.0, 1.0]  # (10)
            gt_stroke_img = draw(f)  # (H, W), [0.0-stroke, 1.0-BG]

            gt_stroke_img_large = image_pasting_v3_testing(1.0 - gt_stroke_img, cursor_pos,
                                                           image_size,
                                                           curr_window_size_raw,
                                                           pasting_func, sess)  # [0.0-BG, 1.0-stroke]

            if pen_state == 0:
                canvas += gt_stroke_img_large  # [0.0-BG, 1.0-stroke]

            gif_frames.append(_canvas_to_frame(canvas, cursor_pos, curr_window_size, image_size, add_box))

            # cursor_type == 'next': move the cursor by the stroke end point, scaled to pixels.
            # stroke_params[2:4] is the raw end point (before the [0, 1] mapping above).
            new_cursor_offset = stroke_params[2:4] * (float(curr_window_size_raw) / 2.0)  # (2), patch-level

            # important!!! the two offset components are swapped before being added to the cursor.
            new_cursor_offset = np.concatenate([new_cursor_offset[1:2], new_cursor_offset[0:1]], axis=-1)

            cursor_pos_large = cursor_pos * float(image_size) + new_cursor_offset  # (2), large-level
            cursor_pos_large = np.minimum(np.maximum(cursor_pos_large, 0.0), float(image_size - 1))  # (2), large-level
            cursor_pos = cursor_pos_large / float(image_size)

        stroke_offset += int(round_length)

    if not gif_frames:
        print('No strokes to animate; GIF not written.')
        return

    print('Saving to GIF ...')
    save_path = os.path.join(save_base, 'dynamic.gif')
    first_frame, *rest_frames = gif_frames
    # `append_images` holds only the frames after the first; passing the whole list would repeat frame 0.
    # NOTE(review): Pillow's `duration` is in milliseconds, so 0.01 becomes a 0 delay in the GIF
    # (viewers then substitute their own default) -- confirm the intended playback speed.
    first_frame.save(save_path, save_all=True, append_images=rest_frames, loop=0, duration=0.01)


def gif_making(npz_path):
    """
    Load a saved stroke sequence from `npz_path` and render it to `<dir>/<name>/dynamic.gif`.

    The output directory is named after the .npz file (without its extension) and is created
    next to that file.

    :param npz_path: path to an .npz file with the keys 'strokes_data', 'init_cursors',
                     'image_size', 'round_length' and 'init_width'
    :raises ValueError: if `npz_path` is empty
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if npz_path == '':
        raise ValueError('npz_path must be a non-empty path to a .npz file')

    min_window_size = 32
    raster_size = 128

    # os.path handles both path separators and files in the filesystem root ('/x.npz').
    file_base = os.path.dirname(npz_path) or './'
    file_name = os.path.splitext(os.path.basename(npz_path))[0]

    gif_base = os.path.join(file_base, file_name)
    os.makedirs(gif_base, exist_ok=True)

    # differentiable pasting graph
    paste_v3_func = DiffPastingV3(raster_size)

    tfconfig = tf.ConfigProto()
    tfconfig.gpu_options.allow_growth = True
    sess = tf.InteractiveSession(config=tfconfig)
    try:
        sess.run(tf.global_variables_initializer())

        # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects, which can execute
        # code -- only load .npz files from trusted sources.
        with np.load(npz_path, encoding='latin1', allow_pickle=True) as data:
            strokes_data = data['strokes_data']
            init_cursors = data['init_cursors']
            image_size = data['image_size']
            round_length = data['round_length']
            init_width = data['init_width']

        # A single round may be stored as a 0-d array.
        round_lengths = [round_length] if round_length.ndim == 0 else round_length

        make_gif(sess, paste_v3_func,
                 strokes_data, init_cursors, image_size, round_lengths, init_width,
                 gif_base,
                 min_window_size=min_window_size, raster_size=raster_size)
    finally:
        sess.close()


if __name__ == '__main__':
    # `required=True` makes argparse report a missing --file with a usage message
    # instead of failing later on an empty path.
    parser = argparse.ArgumentParser()
    parser.add_argument('--file', '-f', type=str, required=True, help="define a npz path")
    args = parser.parse_args()

    gif_making(args.file)
back to top