# Source: https://github.com/mikgroup/sigpy
# Raw File
# Tip revision: bfd564bc9942f06dab6288923a2f9e7516237bdd authored by Jonathan Martin on 16 October 2023, 22:09:11 UTC
# bump version 0.1.25 -> 0.1.26
# Tip revision: bfd564b
# plot.py
# -*- coding: utf-8 -*-
"""This module contains plotting functions based on matplotlib
for image, line, and scatter plots.

A feature of these plotting functions is that
they can be controlled using only hotkeys
so the user does not need to move away from the keyboard.

Given an array ``x``, an example usage is:

    >>> ImagePlot(x)
    >>> LinePlot(x)
    >>> ScatterPlot(x)

"""
import datetime
import os
import subprocess
import uuid

import numpy as np

import sigpy as sp

# Public names exported by ``from <this module> import *``.
__all__ = ["ImagePlot", "LinePlot", "ScatterPlot"]


# Hotkey reference shown in ImagePlot when 'h' is pressed. It is drawn with
# ``Axes.text``, so ``$\bf{...}$`` spans are rendered as bold mathtext.
image_plot_help_str = r"""
$\bf{Hotkeys:}$
    $\bf{h:}$ show/hide hotkey menu.
    $\bf{x/y/z:}$ set current axis as x/y/z.
    $\bf{t:}$ swap between x and y.
    $\bf{c:}$ select current axis as color.
    $\bf{left/right:}$ change current axis.
    $\bf{up/down:}$ change slice along current axis (flip if displayed).
    $\bf{a:}$ toggle hide all labels, titles and axes.
    $\bf{f:}$ toggle full screen.
    $\bf{m/p/r/i/l:}$  magnitude/phase/real/imaginary/log mode.
    $\bf{[/]:}$ change brightness.
    $\bf{\{/\}:}$ change contrast.
    $\bf{s:}$ save as png.
    $\bf{g/v:}$ save as gif/video along current axis.
    $\bf{q:}$ refresh.
    $\bf{0-9:}$ enter slice number.
    $\bf{enter:}$ set current axis as slice number.
"""


class ImagePlot(object):
    """Plot array as image.

    Press 'h' for a menu for hotkeys.

    Args:
        im (array): image numpy/cupy array.
        x (int): x axis.
        y (int): y axis.
        z (None or int): z axis.
        c (None or int): color axis.
        hide_axes (bool): toggle hiding axes, labels and title.
        mode (str): specify magnitude, phase, real, imaginary,
            and log mode. {'m', 'p', 'r', 'i', 'l'}.
        colormap (None or str): matplotlib colormap. Defaults to gray;
            when given, a colorbar is also drawn.
        vmin (None or float): lower display limit. Defaults to the
            minimum of the displayed image.
        vmax (None or float): upper display limit. Defaults to the
            maximum of the displayed image.
        title (str): title.
        interpolation (str): plot interpolation.
        save_basename (str): saved png, gif, and video base name.
        fps (int): frame per seconds for gif and video.

    Raises:
        TypeError: if ``im`` has fewer than two dimensions.

    """

    def __init__(
        self,
        im,
        x=-1,
        y=-2,
        z=None,
        c=None,
        hide_axes=False,
        mode=None,
        colormap=None,
        vmin=None,
        vmax=None,
        title="",
        interpolation="nearest",
        save_basename="Figure",
        fps=10,
    ):
        if im.ndim < 2:
            raise TypeError(
                "Image dimension must at least be two, got {im_ndim}".format(
                    im_ndim=im.ndim
                )
            )
        import matplotlib.pyplot as plt

        self.axim = None
        self.im = im
        self.fig = plt.figure()
        self.ax = self.fig.add_subplot(111)
        self.shape = self.im.shape
        self.ndim = self.im.ndim
        self.slices = [s // 2 for s in self.shape]
        self.flips = [1] * self.ndim
        self.x = x % self.ndim
        self.y = y % self.ndim
        self.z = z % self.ndim if z is not None else None
        self.c = c % self.ndim if c is not None else None
        # Start the "current axis" cursor on the first non-image axis.
        self.d = max(self.ndim - 3, 0)
        self.hide_axes = hide_axes
        self.show_help = False
        self.title = title
        self.interpolation = interpolation
        self.mode = mode
        self.colormap = colormap
        self.entering_slice = False
        self.entered_slice = 0
        self.vmin = vmin
        self.vmax = vmax
        self.save_basename = save_basename
        self.fps = fps
        self.help_text = None

        # Replace matplotlib's default key bindings with our own handler.
        self.fig.canvas.mpl_disconnect(
            self.fig.canvas.manager.key_press_handler_id
        )
        self.fig.canvas.mpl_connect("key_press_event", self.key_press)
        self._redraw()
        plt.show()

    def _redraw(self, axes=True, image=True):
        """Refresh the requested parts of the figure, then redraw it."""
        if axes:
            self.update_axes()
        if image:
            self.update_image()
        self.fig.canvas.draw()

    def _timestamped_filename(self, extension):
        """Return ``save_basename`` plus the current local time and extension."""
        return self.save_basename + datetime.datetime.now().strftime(
            " %Y-%m-%d at %I.%M.%S %p." + extension
        )

    @staticmethod
    def _remove_files(paths):
        """Delete the given files, skipping any that were never created."""
        for path in paths:
            if os.path.exists(path):
                os.remove(path)

    def _render_frames(self, frame_paths, **savefig_kwargs):
        """Step through the current axis, saving one PNG per slice.

        Frame ``i`` shows slice ``i`` and is written to ``frame_paths[i]``.
        """
        for i, path in enumerate(frame_paths):
            self.slices[self.d] = i
            self._redraw()
            self.fig.savefig(
                path, format="png", pad_inches=0, **savefig_kwargs
            )

    def _save_gif(self):
        """Save all slices along the current axis as a gif via ffmpeg.

        A palette is generated from the frames first and then used for
        encoding. Temporary files are removed even if a step fails.
        """
        filename = self._timestamped_filename("gif")
        temp_basename = uuid.uuid4()
        frame_paths = [
            "{} {:05d}.png".format(temp_basename, i)
            for i in range(self.shape[self.d])
        ]
        palette_path = "{} palette.png".format(temp_basename)

        # One crop box for all frames so they share the same pixel size.
        bbox = self.fig.get_tightbbox(self.fig.canvas.get_renderer())
        frame_input = [
            "ffmpeg",
            "-f",
            "image2",
            "-s",
            "{}x{}".format(
                int(bbox.width * self.fig.dpi),
                int(bbox.height * self.fig.dpi),
            ),
            "-framerate",
            str(self.fps),
            "-i",
            "{} %05d.png".format(temp_basename),
        ]
        try:
            self._render_frames(frame_paths, bbox_inches=bbox)
            subprocess.run(frame_input + ["-vf", "palettegen", palette_path])
            subprocess.run(
                frame_input
                + ["-i", palette_path, "-lavfi", "paletteuse", filename]
            )
        finally:
            self._remove_files(frame_paths + [palette_path])

    def _save_video(self):
        """Save all slices along the current axis as an mp4 via ffmpeg.

        Temporary frame files are removed even if a step fails.
        """
        filename = self._timestamped_filename("mp4")
        temp_basename = uuid.uuid4()
        frame_paths = [
            "{} {:05d}.png".format(temp_basename, i)
            for i in range(self.shape[self.d])
        ]
        try:
            self._render_frames(
                frame_paths, transparent=True, bbox_inches="tight"
            )
            subprocess.run(
                [
                    "ffmpeg",
                    "-r",
                    str(self.fps),
                    "-i",
                    "{} %05d.png".format(temp_basename),
                    # yuv420p needs even frame dimensions: round down to
                    # even and trim a 10 px border.
                    "-vf",
                    "crop=floor(iw/2)*2-10:floor(ih/2)*2-10",
                    "-pix_fmt",
                    "yuv420p",
                    "-crf",
                    "1",
                    "-vcodec",
                    "libx264",
                    "-preset",
                    "veryslow",
                    filename,
                ]
            )
        finally:
            self._remove_files(frame_paths)

    def key_press(self, event):
        """Handle a matplotlib key press; see ``image_plot_help_str``."""
        displayed = [self.x, self.y, self.z, self.c]

        if event.key in ("up", "down"):
            step = 1 if event.key == "up" else -1
            if self.d not in displayed:
                self.slices[self.d] = (self.slices[self.d] + step) % self.shape[
                    self.d
                ]
            else:
                # On a displayed axis, up/down flips its direction instead.
                self.flips[self.d] *= -1
            self._redraw()

        elif event.key in ("left", "right"):
            step = -1 if event.key == "left" else 1
            self.d = (self.d + step) % self.ndim
            self._redraw(image=False)

        elif event.key == "x" and self.d not in [self.x, self.z, self.c]:
            if self.d == self.y:
                self.x, self.y = self.y, self.x
            else:
                self.x = self.d
            self._redraw()

        elif event.key == "y" and self.d not in [self.y, self.z, self.c]:
            if self.d == self.x:
                self.x, self.y = self.y, self.x
            else:
                self.y = self.d
            self._redraw()

        elif event.key == "z" and self.d not in [self.x, self.y, self.c]:
            self.z = None if self.d == self.z else self.d
            self._redraw()

        elif (
            event.key == "c"
            and self.d not in [self.x, self.y, self.z]
            and self.shape[self.d] == 3
        ):
            self.c = None if self.d == self.c else self.d
            self._redraw()

        elif event.key == "t":
            self.x, self.y = self.y, self.x
            self._redraw()

        elif event.key == "a":
            self.hide_axes = not self.hide_axes
            self._redraw(image=False)

        elif event.key == "f":
            self.fig.canvas.manager.full_screen_toggle()

        elif event.key == "q":
            # Reset the display window to the image range.
            self.vmin = None
            self.vmax = None
            self._redraw()

        elif event.key in ("]", "["):
            # Shift the display window by 10% of its width: "]" lowers it
            # (brighter), "[" raises it (darker).
            shift = (self.vmax - self.vmin) * 0.1
            if event.key == "]":
                shift = -shift
            self.vmin += shift
            self.vmax += shift
            self._redraw(axes=False)

        elif event.key in ("}", "{"):
            # Scale the display window about its center: "}" widens it by
            # 10% (less contrast), "{" narrows it by 10% (more contrast).
            scale = 1.1 if event.key == "}" else 0.9
            width = self.vmax - self.vmin
            center = (self.vmax + self.vmin) / 2
            self.vmin = center - width * scale / 2
            self.vmax = center + width * scale / 2
            self._redraw(axes=False)

        elif event.key in ("m", "p", "r", "i", "l"):
            # Each mode has its own value range, so recompute the limits.
            self.vmin = None
            self.vmax = None
            self.mode = event.key
            self._redraw()

        elif event.key == "s":
            self.fig.savefig(
                self._timestamped_filename("png"),
                transparent=True,
                format="png",
                bbox_inches="tight",
                pad_inches=0,
            )

        elif event.key == "g":
            self._save_gif()

        elif event.key == "v":
            self._save_video()

        elif (
            event.key in tuple("0123456789") + ("backspace",)
            and self.d not in displayed
        ):
            if self.entering_slice:
                if event.key == "backspace":
                    if self.entered_slice < 10:
                        # Deleting the last digit cancels the entry.
                        self.entering_slice = False
                    else:
                        self.entered_slice //= 10
                else:
                    self.entered_slice = self.entered_slice * 10 + int(
                        event.key
                    )
            elif event.key != "backspace":
                self.entering_slice = True
                self.entered_slice = int(event.key)
            self._redraw(image=False)

        elif event.key == "enter" and self.entering_slice:
            self.entering_slice = False
            # Out-of-range entries are discarded without moving the slice.
            in_range = self.entered_slice < self.shape[self.d]
            if in_range:
                self.slices[self.d] = self.entered_slice
            self._redraw(image=in_range)

        elif event.key == "h":
            self.show_help = not self.show_help
            self._redraw(axes=False)

        else:
            return

    def update_image(self):
        """Extract the displayed slice of ``im`` and draw it.

        Displayed axes (x, y and optionally z and c) are kept whole, and
        reversed when flipped. Every other axis is indexed at its current
        slice. The result is arranged as [z, y, x, c] and passed to
        ``array_to_image``. Display limits default to the image range, and
        the help box is kept centered on the image.
        """
        displayed = [self.x, self.y, self.z, self.c]
        idx = tuple(
            slice(None, None, self.flips[i]) if i in displayed else self.slices[i]
            for i in range(self.ndim)
        )
        imv = sp.to_device(self.im[idx])

        # Transpose to have [z, y, x, c].
        imv_dims = [self.y, self.x]
        if self.z is not None:
            imv_dims = [self.z] + imv_dims
        if self.c is not None:
            imv_dims = imv_dims + [self.c]

        # Slicing keeps the displayed axes in ascending order, so the rank
        # of each entry of imv_dims is its axis position in imv.
        imv = np.transpose(imv, np.argsort(np.argsort(imv_dims)))
        imv = array_to_image(imv, color=self.c is not None)

        if self.mode is None:
            self.mode = "r" if np.isrealobj(imv) else "m"

        if self.mode == "m":
            imv = np.abs(imv)
        elif self.mode == "p":
            imv = np.angle(imv)
        elif self.mode == "r":
            imv = np.real(imv)
        elif self.mode == "i":
            imv = np.imag(imv)
        elif self.mode == "l":
            imv = np.abs(imv)
            # Zeros map to -31 instead of -inf.
            imv = np.log(imv, out=np.ones_like(imv) * -31, where=imv != 0)

        if self.vmin is None:
            self.vmin = imv.min()

        if self.vmax is None:
            self.vmax = imv.max()

        extent = [0, imv.shape[1], 0, imv.shape[0]]
        if self.axim is None:
            colormap = "gray" if self.colormap is None else self.colormap
            self.axim = self.ax.imshow(
                imv,
                vmin=self.vmin,
                vmax=self.vmax,
                cmap=colormap,
                origin="lower",
                interpolation=self.interpolation,
                aspect=1.0,
                extent=extent,
            )

            if self.colormap is not None:
                self.fig.colorbar(self.axim)

        else:
            self.axim.set_data(imv)
            self.axim.set_extent(extent)
            self.axim.set_clim(self.vmin, self.vmax)

        # Data coordinates follow the extent: x spans the columns
        # (shape[1]) and y spans the rows (shape[0]).
        center = (imv.shape[1] / 2, imv.shape[0] / 2)
        if self.help_text is None:
            bbox_props = dict(
                boxstyle="round", pad=1, fc="white", alpha=0.95, lw=0
            )
            self.help_text = self.ax.text(
                center[0],
                center[1],
                image_plot_help_str,
                ha="center",
                va="center",
                linespacing=1.5,
                ma="left",
                size=8,
                bbox=bbox_props,
            )
        else:
            # Re-center in case the image shape changed (axis swap, z, ...).
            self.help_text.set_position(center)

        self.help_text.set_visible(self.show_help)

    def _caption(self):
        """Return the axis-status caption, e.g. ``[[2] y  x ]``.

        The current axis is wrapped in brackets. Displayed axes show their
        letter, prefixed with ``-`` when flipped. Other axes show their
        slice index, or the partially typed number plus ``_`` while a
        slice number is being entered.
        """
        displayed = [self.x, self.y, self.z, self.c]
        parts = []
        for i in range(self.ndim):
            label = "-" if self.flips[i] == -1 and i in displayed else ""
            if i == self.x:
                label += "x"
            elif i == self.y:
                label += "y"
            elif i == self.z:
                label += "z"
            elif i == self.c:
                label += "c"
            elif i == self.d and self.entering_slice:
                label += str(self.entered_slice) + "_"
            else:
                label += str(self.slices[i])
            parts.append("[" + label + "]" if i == self.d else " " + label + " ")
        return "[" + "".join(parts) + "]"

    def update_axes(self):
        """Refresh the caption, title and axis visibility."""
        visible = not self.hide_axes
        if visible:
            self.ax.set_title(self._caption())
            self.fig.suptitle(self.title)
        else:
            self.ax.set_title("")
            self.fig.suptitle("")
        self.ax.xaxis.set_visible(visible)
        self.ax.yaxis.set_visible(visible)
        self.ax.title.set_visible(visible)


def mosaic_shape(batch):
    """Return a near-square ``(rows, cols)`` grid that holds ``batch`` tiles.

    ``rows`` is ``floor(sqrt(batch))`` and ``cols`` is the smallest count
    with ``rows * cols >= batch``. If trading one row for one extra column
    gives an exact fit, that layout is used instead (e.g. 10 -> (2, 5)
    rather than (3, 4)).

    Args:
        batch (int): number of tiles.

    Returns:
        tuple of int: ``(rows, cols)``. A non-positive ``batch`` yields
        ``(1, 1)``.
    """
    if batch < 1:
        # int(0 ** 0.5) == 0 would otherwise cause a division by zero.
        return (1, 1)

    rows = int(batch**0.5)
    cols = -(-batch // rows)  # ceil(batch / rows)

    if (rows - 1) * (cols + 1) == batch:
        rows, cols = rows - 1, cols + 1

    return (rows, cols)


def array_to_image(arr, color=False):
    """Tile a stack of images into a single image mosaic.

    Every axis before the image axes is flattened into a batch of tiles.
    The tiles are laid out on the grid chosen by ``mosaic_shape``, and
    unused grid cells are zero-filled.

    Args:
        arr (array): shape [..., y, x, c] if color, else [..., y, x].
        color (bool): whether the last axis holds (3) color channels.
            Color input is divided by its largest magnitude unless it is
            all zeros.

    Returns:
        array: shape [rows * y, cols * x, 3] if color, else
        [rows * y, cols * x]. Input with no batch axes is returned as is.

    """
    if color:
        is_blank = arr.max() == 0 and arr.min() == 0
        if not is_blank:
            arr = arr / np.abs(arr).max()

    if arr.ndim == 2 or (color and arr.ndim == 3):
        return arr

    n_image_axes = 3 if color else 2
    tile_shape = arr.shape[-n_image_axes:]
    n_tiles = sp.prod(arr.shape[:-n_image_axes])
    grid = mosaic_shape(n_tiles)
    n_cells = sp.prod(grid)

    tiles = arr.reshape((n_tiles,) + tile_shape)
    if n_cells != n_tiles:
        padded = np.zeros((n_cells,) + tile_shape, dtype=arr.dtype)
        padded[:n_tiles, ...] = tiles
        tiles = padded

    # [rows, cols, y, x(, c)] -> [rows, y, cols, x(, c)] so that a reshape
    # stitches neighbouring tiles together.
    tiles = tiles.reshape(grid + tile_shape)
    axes_order = (0, 2, 1, 3, 4) if color else (0, 2, 1, 3)
    tiles = np.transpose(tiles, axes_order)

    mosaic_shape_out = (tile_shape[0] * grid[0], tile_shape[1] * grid[1])
    if color:
        mosaic_shape_out = mosaic_shape_out + (3,)
    return tiles.reshape(mosaic_shape_out)


class LinePlot(object):
    """Plot array as lines.

    Args:
        arr (array): numpy/cupy array to plot.
        x (int): axis drawn along the horizontal axis.
        hide_axes (bool): toggle hiding axes, labels and title.
        mode (str): magnitude, phase, real, imaginary or log mode.
            {'m', 'p', 'r', 'i', 'l'}.
        title (str): title.
        save_basename (str): saved png, gif, and video base name.
        fps (int): frame per seconds for gif and video.

    Hotkeys:
        x: select current dimension as x
        left/right: move to previous/next dimension
        up/down: flip axis when current dimension is x,
            otherwise increment/decrement slice at current dimension
        a: toggle hide all labels, titles and axes
        f: toggle full screen
        m: magnitude mode
        p: phase mode
        r: real mode
        i: imaginary mode
        l: log mode
        s: save as png.
        g: save as gif by traversing current dimension.
        v: save as video by traversing current dimension.
    """

    def __init__(
        self,
        arr,
        x=-1,
        hide_axes=False,
        mode="m",
        title="",
        save_basename="Figure",
        fps=10,
    ):
        import matplotlib.pyplot as plt

        self.arr = arr
        self.axarr = None

        self.fig = plt.figure()
        self.ax = self.fig.add_subplot(111)
        self.shape = self.arr.shape
        self.ndim = self.arr.ndim
        self.slices = [s // 2 for s in self.shape]
        self.flips = [1] * self.ndim
        self.x = x % self.ndim
        self.d = max(self.ndim - 3, 0)
        self.hide_axes = hide_axes
        self.title = title
        self.mode = mode
        self.save_basename = save_basename
        self.fps = fps
        # y-limits; computed from the data on the next update_line call.
        self.bottom = None
        self.top = None

        # Replace matplotlib's default key bindings with our own handler.
        self.fig.canvas.mpl_disconnect(
            self.fig.canvas.manager.key_press_handler_id
        )
        self.fig.canvas.mpl_connect("key_press_event", self.key_press)
        self._redraw()
        plt.show()

    def _redraw(self, axes=True, line=True):
        """Refresh the requested parts of the figure, then redraw it."""
        if axes:
            self.update_axes()
        if line:
            self.update_line()
        self.fig.canvas.draw()

    def _timestamped_filename(self, extension):
        """Return ``save_basename`` plus the current local time and extension."""
        return self.save_basename + datetime.datetime.now().strftime(
            " %Y-%m-%d at %I.%M.%S %p." + extension
        )

    @staticmethod
    def _remove_files(paths):
        """Delete the given files, skipping any that were never created."""
        for path in paths:
            if os.path.exists(path):
                os.remove(path)

    def _render_frames(self, frame_paths, **savefig_kwargs):
        """Step through the current axis, saving one PNG per slice.

        Frame ``i`` shows slice ``i`` and is written to ``frame_paths[i]``.
        """
        for i, path in enumerate(frame_paths):
            self.slices[self.d] = i
            self._redraw()
            self.fig.savefig(
                path, format="png", pad_inches=0, **savefig_kwargs
            )

    def _save_gif(self):
        """Save all slices along the current axis as a gif via ffmpeg.

        A palette is generated from the frames first and then used for
        encoding. Temporary files are removed even if a step fails.
        """
        filename = self._timestamped_filename("gif")
        temp_basename = uuid.uuid4()
        frame_paths = [
            "{} {:05d}.png".format(temp_basename, i)
            for i in range(self.shape[self.d])
        ]
        palette_path = "{} palette.png".format(temp_basename)

        # One crop box for all frames so they share the same pixel size.
        bbox = self.fig.get_tightbbox(self.fig.canvas.get_renderer())
        frame_input = [
            "ffmpeg",
            "-f",
            "image2",
            "-s",
            "{}x{}".format(
                int(bbox.width * self.fig.dpi),
                int(bbox.height * self.fig.dpi),
            ),
            "-framerate",
            str(self.fps),
            "-i",
            "{} %05d.png".format(temp_basename),
        ]
        try:
            self._render_frames(frame_paths, bbox_inches=bbox)
            subprocess.run(frame_input + ["-vf", "palettegen", palette_path])
            subprocess.run(
                frame_input
                + ["-i", palette_path, "-lavfi", "paletteuse", filename]
            )
        finally:
            self._remove_files(frame_paths + [palette_path])

    def _save_video(self):
        """Save all slices along the current axis as a png-coded .mov.

        Temporary frame files are removed even if a step fails.
        """
        filename = self._timestamped_filename("mov")
        temp_basename = uuid.uuid4()
        frame_paths = [
            "{} {:05d}.png".format(temp_basename, i)
            for i in range(self.shape[self.d])
        ]
        bbox = self.fig.get_tightbbox(self.fig.canvas.get_renderer())
        try:
            self._render_frames(frame_paths, bbox_inches=bbox)
            subprocess.run(
                [
                    "ffmpeg",
                    "-framerate",
                    str(self.fps),
                    "-i",
                    "{} %05d.png".format(temp_basename),
                    "-vcodec",
                    "png",
                    filename,
                ]
            )
        finally:
            self._remove_files(frame_paths)

    def key_press(self, event):
        """Handle a matplotlib key press; see the class docstring."""
        if event.key in ("up", "down"):
            step = 1 if event.key == "up" else -1
            if self.d != self.x:
                self.slices[self.d] = (self.slices[self.d] + step) % self.shape[
                    self.d
                ]
            else:
                # On the x axis, up/down flips its direction instead.
                self.flips[self.d] *= -1
            self._redraw()

        elif event.key in ("left", "right"):
            step = -1 if event.key == "left" else 1
            self.d = (self.d + step) % self.ndim
            self._redraw(line=False)

        elif event.key == "x" and self.d != self.x:
            self.x = self.d
            self._redraw()

        elif event.key == "a":
            self.hide_axes = not self.hide_axes
            self._redraw(line=False)

        elif event.key == "f":
            self.fig.canvas.manager.full_screen_toggle()

        elif event.key in ("m", "p", "r", "i", "l"):
            self.mode = event.key
            # Each mode has its own value range, so recompute the y-limits.
            self.bottom = None
            self.top = None
            self._redraw()

        elif event.key == "s":
            self.fig.savefig(
                self._timestamped_filename("png"),
                transparent=True,
                format="png",
                bbox_inches="tight",
                pad_inches=0,
            )

        elif event.key == "g":
            self._save_gif()

        elif event.key == "v":
            self._save_video()

    def update_line(self):
        """Extract the 1D profile along ``x`` at the current slices and draw it.

        The y-limits are fixed on the first draw after construction or a
        mode change, so that profiles of different slices share a scale.
        """
        # Move x last so the remaining axes can be indexed by their slices.
        order = [i for i in range(self.ndim) if i != self.x] + [self.x]
        idx = tuple(self.slices[i] for i in order[:-1]) + (
            slice(None, None, self.flips[self.x]),
        )
        # Move to host memory (as ImagePlot does) so cupy input can be
        # plotted by matplotlib.
        arrv = sp.to_device(self.arr.transpose(order)[idx])

        if self.mode == "m":
            arrv = np.abs(arrv)
        elif self.mode == "p":
            arrv = np.angle(arrv)
        elif self.mode == "r":
            arrv = np.real(arrv)
        elif self.mode == "i":
            arrv = np.imag(arrv)
        elif self.mode == "l":
            eps = 1e-31
            arrv = np.log(np.abs(arrv) + eps)

        if self.bottom is None:
            self.bottom = arrv.min()

        if self.top is None:
            self.top = arrv.max()

        if self.axarr is None:
            self.axarr = self.ax.plot(arrv)[0]

        else:
            self.axarr.set_xdata(np.arange(len(arrv)))
            self.axarr.set_ydata(arrv)
            self.ax.set_ylim(self.bottom, self.top)
            # set_xdata does not rescale the axes. Refit x so the whole line
            # stays visible after switching to an axis of different length.
            self.ax.relim()
            self.ax.autoscale_view(scaley=False)

    def _caption(self):
        """Return the slice caption, e.g. ``Slice: [[1] x ]``.

        The current axis is wrapped in brackets. The x axis shows ``x``
        (``-x`` when flipped) and every other axis shows its slice index.
        """
        parts = []
        for i in range(self.ndim):
            if i == self.x:
                label = ("-" if self.flips[i] == -1 else "") + "x"
            else:
                label = str(self.slices[i])
            parts.append("[" + label + "]" if i == self.d else " " + label + " ")
        return "Slice: [" + "".join(parts) + "]"

    def update_axes(self):
        """Refresh the caption, title and axis visibility."""
        visible = not self.hide_axes
        if visible:
            self.ax.set_title(self._caption())
            self.ax.axis("on")
            self.fig.suptitle(self.title)
        else:
            self.ax.set_title("")
            self.fig.suptitle("")
        self.ax.xaxis.set_visible(visible)
        self.ax.yaxis.set_visible(visible)
        self.ax.title.set_visible(visible)


class ScatterPlot(object):
    """Plot array as scatter.

    Each point is drawn at its (x, y) position from ``coord`` and colored in
    grayscale by the matching entry of ``data`` in the current slice.

    Args:
        coord (array): coordinates whose last dimension has size 2.
        data (None or array): values used to color the points. Its trailing
            dimensions must match ``coord.shape[:-1]``; any leading
            dimensions are browsed slice by slice. Defaults to all ones.
        z (None or int): z dimension (shown whole instead of sliced).
        hide_axes (bool): toggle hiding axes, labels and title.
        mode (str): specify magnitude, phase, real, imaginary,
            and log mode. {'m', 'p', 'r', 'i', 'l'}.
        title (str): title.
        save_basename (str): saved png, gif, and video base name.
        fps (int): frame rate of saved gif and video.

    Hotkeys:
        left/right: change current dimension.
        up/down: increment/decrement slice at current dimension,
            or flip it when the current dimension is the z dimension.
        a: toggle hide all labels, titles and axes.
        f: toggle full screen.
        m/p/r/i/l: magnitude/phase/real/imaginary/log mode.
        s: save as png.
        g/v: save as gif/video along current dimension (calls ``ffmpeg``).
        0-9/backspace: enter slice number.
        enter: set current dimension to the entered slice number.
    """

    # Keys used to type a slice number for the current dimension.
    _SLICE_KEYS = tuple("0123456789") + ("backspace",)

    def __init__(
        self,
        coord,
        data=None,
        z=None,
        hide_axes=False,
        mode="m",
        title="",
        save_basename="Figure",
        fps=10,
    ):
        import matplotlib.pyplot as plt

        self.coord = coord
        assert coord.shape[-1] == 2
        if data is None:
            self.data = np.ones(coord.shape[:-1])
        else:
            self.data = data

        self.fig = plt.figure()
        self.ax = self.fig.add_subplot(111)
        self.ax.set_facecolor("k")
        self.ax.axis("equal")

        # The trailing dimensions of data must line up with the point grid.
        for c, d in zip(coord.shape[:-1], self.data.shape[-coord.ndim + 1 :]):
            assert c == d

        # Browsable dimensions are the leading dimensions of data.
        self.ndim = self.data.ndim - self.coord.ndim + 1
        self.shape = self.data.shape[: self.ndim]

        self.slices = [s // 2 for s in self.shape]
        self.flips = [1] * self.ndim
        self.z = z % self.ndim if z is not None else None
        self.d = 0
        self.hide_axes = hide_axes
        self.title = title
        self.mode = mode
        self.axsc = None
        self.entering_slice = False
        self.entered_slice = 0
        self.save_basename = save_basename
        self.fps = fps
        self.vmin = None
        self.vmax = None

        # Replace matplotlib's default key bindings with our own hotkeys.
        self.fig.canvas.mpl_disconnect(
            self.fig.canvas.manager.key_press_handler_id
        )
        self.fig.canvas.mpl_connect("key_press_event", self.key_press)
        self.update_axes()
        self.update_data()
        self.fig.canvas.draw()
        plt.show()

    @staticmethod
    def _timestamped_filename(basename, ext, now=None):
        """Return ``basename`` plus a 12-hour-clock timestamp and ``.ext``.

        Example: ``"Figure 2023-10-16 at 10.09.11 PM.png"``.
        ``now`` defaults to the current local time.
        """
        if now is None:
            now = datetime.datetime.now()
        return basename + now.strftime(" %Y-%m-%d at %I.%M.%S %p.") + ext

    @staticmethod
    def _apply_mode(arr, mode):
        """Map ``arr`` to real values according to the display ``mode``.

        Unknown modes return ``arr`` unchanged.
        """
        if mode == "m":
            return np.abs(arr)
        if mode == "p":
            return np.angle(arr)
        if mode == "r":
            return np.real(arr)
        if mode == "i":
            return np.imag(arr)
        if mode == "l":
            eps = 1e-31  # avoid log(0)
            return np.log(np.abs(arr) + eps)
        return arr

    @staticmethod
    def _remove_temp_files(temp_basename, num_frames, extra=()):
        """Delete the temporary frame pngs and ``extra`` paths.

        Missing files are ignored, so this is safe to call after a
        partially failed export.
        """
        paths = list(extra) + [
            "{} {:05d}.png".format(temp_basename, i) for i in range(num_frames)
        ]
        for path in paths:
            try:
                os.remove(path)
            except FileNotFoundError:
                pass

    def _save_frames(self, temp_basename):
        """Save one png per slice along the current dimension.

        Returns the bounding box shared by all frames. Leaves
        ``self.slices[self.d]`` at the last slice.
        """
        bbox = self.fig.get_tightbbox(self.fig.canvas.get_renderer())
        for i in range(self.shape[self.d]):
            self.slices[self.d] = i

            self.update_axes()
            self.update_data()
            self.fig.canvas.draw()
            self.fig.savefig(
                "{} {:05d}.png".format(temp_basename, i),
                format="png",
                bbox_inches=bbox,
                pad_inches=0,
            )
        return bbox

    def _frame_size(self, bbox):
        """Return the ``WIDTHxHEIGHT`` pixel size string of ``bbox``."""
        return "{}x{}".format(
            int(bbox.width * self.fig.dpi),
            int(bbox.height * self.fig.dpi),
        )

    def _save_gif(self):
        """Save the slices along the current dimension as a gif via ffmpeg."""
        filename = self._timestamped_filename(self.save_basename, "gif")
        temp_basename = uuid.uuid4()
        frames = "{} %05d.png".format(temp_basename)
        palette = "{} palette.png".format(temp_basename)
        try:
            size = self._frame_size(self._save_frames(temp_basename))
            # Two passes: build an optimized palette, then encode with it.
            subprocess.run(
                [
                    "ffmpeg",
                    "-f",
                    "image2",
                    "-s",
                    size,
                    "-framerate",
                    str(self.fps),
                    "-i",
                    frames,
                    "-vf",
                    "palettegen",
                    palette,
                ]
            )
            subprocess.run(
                [
                    "ffmpeg",
                    "-f",
                    "image2",
                    "-s",
                    size,
                    "-framerate",
                    str(self.fps),
                    "-i",
                    frames,
                    "-i",
                    palette,
                    "-lavfi",
                    "paletteuse",
                    filename,
                ]
            )
        finally:
            self._remove_temp_files(
                temp_basename, self.shape[self.d], extra=[palette]
            )

    def _save_video(self):
        """Save the slices along the current dimension as a .mov via ffmpeg."""
        filename = self._timestamped_filename(self.save_basename, "mov")
        temp_basename = uuid.uuid4()
        try:
            self._save_frames(temp_basename)
            subprocess.run(
                [
                    "ffmpeg",
                    "-framerate",
                    str(self.fps),
                    "-i",
                    "{} %05d.png".format(temp_basename),
                    "-vcodec",
                    "png",
                    filename,
                ]
            )
        finally:
            self._remove_temp_files(temp_basename, self.shape[self.d])

    def key_press(self, event):
        """Handle a matplotlib ``key_press_event`` (see class hotkeys)."""
        key = event.key
        if key in ("up", "down"):
            if self.d != self.z:
                step = 1 if key == "up" else -1
                self.slices[self.d] = (self.slices[self.d] + step) % self.shape[
                    self.d
                ]
            else:
                # The z dimension is shown whole, so up/down flips it instead.
                self.flips[self.d] *= -1

            self.update_axes()
            self.update_data()
            self.fig.canvas.draw()

        elif key in ("left", "right"):
            step = -1 if key == "left" else 1
            self.d = (self.d + step) % self.ndim

            self.update_axes()
            self.fig.canvas.draw()

        elif key == "a":
            self.hide_axes = not self.hide_axes

            self.update_axes()
            self.fig.canvas.draw()

        elif key == "f":
            self.fig.canvas.manager.full_screen_toggle()

        elif key in ("m", "p", "r", "i", "l"):
            self.mode = key
            # Recompute the color limits for the new mode.
            self.vmin = None
            self.vmax = None

            self.update_axes()
            self.update_data()
            self.fig.canvas.draw()

        elif key == "s":
            self.fig.savefig(
                self._timestamped_filename(self.save_basename, "png"),
                transparent=True,
                format="png",
                bbox_inches="tight",
                pad_inches=0,
            )

        elif key == "g":
            self._save_gif()

        elif key == "v":
            self._save_video()

        elif key in self._SLICE_KEYS and self.d != self.z:
            if self.entering_slice:
                if key == "backspace":
                    if self.entered_slice < 10:
                        self.entering_slice = False
                    else:
                        self.entered_slice //= 10
                else:
                    self.entered_slice = self.entered_slice * 10 + int(key)
            else:
                self.entering_slice = True
                self.entered_slice = int(key)

            self.update_axes()
            self.fig.canvas.draw()

        elif key == "enter" and self.entering_slice:
            self.entering_slice = False
            # Out-of-range entries are discarded silently.
            if self.entered_slice < self.shape[self.d]:
                self.slices[self.d] = self.entered_slice

                self.update_data()

            self.update_axes()
            self.fig.canvas.draw()

    def update_data(self):
        """Redraw the scatter colors for the current slice and mode."""
        # The z dimension (if any) is kept whole, possibly flipped;
        # every other dimension is indexed at its current slice.
        idx = tuple(
            slice(None, None, self.flips[i]) if i == self.z else self.slices[i]
            for i in range(self.ndim)
        )
        # NOTE(review): with z set, datav keeps the z dimension, so it has
        # more values than there are coordinates; z support looks incomplete.
        datav = sp.to_device(self.data[idx] if idx else self.data)
        coordv = sp.to_device(self.coord)

        datav = self._apply_mode(datav, self.mode).ravel()
        if self.vmin is None:
            # A constant image would give a degenerate color range.
            if datav.min() == datav.max():
                self.vmin = 0
            else:
                self.vmin = datav.min()

        if self.vmax is None:
            self.vmax = datav.max()

        if self.axsc is None:
            self.axsc = self.ax.scatter(
                coordv[..., 0].ravel(),
                coordv[..., 1].ravel(),
                c=datav,
                s=1,
                linewidths=0,
                cmap="gray",
                vmin=self.vmin,
                vmax=self.vmax,
            )

        else:
            # Keep (x, y) pairs intact: same C-order flattening as above.
            self.axsc.set_offsets(coordv.reshape([-1, 2]))
            # Map scalar data through the colormap, not as raw colors.
            self.axsc.set_array(datav)
            self.axsc.set_clim(self.vmin, self.vmax)

    def update_axes(self):
        """Refresh the caption, suptitle and axis visibility.

        The caption lists one entry per dimension: the current dimension is
        bracketed, the z dimension shows ``z`` (``-z`` when flipped), a slice
        number being typed shows with a trailing ``_``, and other dimensions
        show their slice index.
        """
        if not self.hide_axes:
            caption = "["
            for i in range(self.ndim):
                if i == self.d:
                    caption += "["
                else:
                    caption += " "

                if self.flips[i] == -1 and i == self.z:
                    caption += "-"

                if i == self.z:
                    caption += "z"
                elif i == self.d and self.entering_slice:
                    caption += str(self.entered_slice) + "_"
                else:
                    caption += str(self.slices[i])

                if i == self.d:
                    caption += "]"
                else:
                    caption += " "
            caption += "]"

            self.ax.set_title(caption)
            self.fig.suptitle(self.title)
            self.ax.xaxis.set_visible(True)
            self.ax.yaxis.set_visible(True)
            self.ax.title.set_visible(True)
        else:
            self.ax.set_title("")
            self.fig.suptitle("")
            self.ax.xaxis.set_visible(False)
            self.ax.yaxis.set_visible(False)
            self.ax.title.set_visible(False)
back to top