https://github.com/samet-akcay/ganomaly
Tip revision: 78da4ea9a99f5b02ab60dd651a18def929176d77 authored by Samet Akcay on 06 February 2022, 21:11:17 UTC
Update README.md
Update README.md
Tip revision: 78da4ea
visualizer.py
"""Visualization utilities based on Facebook's visdom.

This file contains the Visualizer class, which displays plots and images,
prints and logs losses and performance, and saves images to disk.
"""
##
import os
import time
import numpy as np
import torchvision.utils as vutils
##
class Visualizer():
    """Visualizer wrapper based on Visdom.

    Plots losses and performance metrics and displays image batches on a
    Visdom server when ``opt.display`` is truthy, prints and appends
    messages to ``loss_log.txt``, and saves image grids to disk.
    """
    # pylint: disable=too-many-instance-attributes
    # Reasonable.

    ##
    def __init__(self, opt):
        """Set up the optional Visdom client, output directories and log file.

        Args:
            opt: Options object. Attributes read here are ``display_id``,
                ``name``, ``display``, ``outf`` and, only when ``display``
                is truthy, ``display_server`` and ``display_port``. Other
                methods additionally read ``niter`` and ``metric``.
        """
        self.display_id = opt.display_id
        self.win_size = 256
        self.name = opt.name
        self.opt = opt

        # The Visdom client stays None when display is disabled so that the
        # plotting/display methods can skip drawing instead of failing.
        self.vis = None
        if self.opt.display:
            # Imported lazily: visdom is only required when display is on.
            import visdom
            self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port)

        # --
        # Dictionaries for plotting data and results.
        self.plot_data = None
        self.plot_res = None

        # --
        # Path to train and test directories.
        self.img_dir = os.path.join(opt.outf, opt.name, 'train', 'images')
        self.tst_img_dir = os.path.join(opt.outf, opt.name, 'test', 'images')
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(self.img_dir, exist_ok=True)
        os.makedirs(self.tst_img_dir, exist_ok=True)

        # --
        # Log file.
        self.log_name = os.path.join(opt.outf, opt.name, 'loss_log.txt')
        with open(self.log_name, "a") as log_file:
            now = time.strftime("%c")
            log_file.write('================ Training Loss (%s) ================\n' % now)

    ##
    @staticmethod
    def normalize(inp):
        """Min-max normalize the input.

        Args:
            inp: Input supporting ``min()``, ``max()`` and arithmetic
                (e.g. a tensor or numpy array).

        Returns:
            ``(inp - min) / (max - min + 1e-5)``; the small constant keeps
            the division defined when all values are equal.
        """
        return (inp - inp.min()) / (inp.max() - inp.min() + 1e-5)

    ##
    def _update_line_plot(self, plot, x_value, values, title, ylabel, win):
        """Append one sample to a plot history and redraw it if Visdom is on.

        Args:
            plot (dict or None): Existing history with keys ``'X'``, ``'Y'``
                and ``'legend'``, or None to start a new one whose legend is
                taken from the keys of ``values``.
            x_value (float): X coordinate of the new sample.
            values (dict): Mapping from legend entry to its value.
            title (str): Plot title.
            ylabel (str): Label of the Y axis.
            win: Visdom window identifier.

        Returns:
            dict: The updated history.
        """
        if plot is None:
            plot = {'X': [], 'Y': [], 'legend': list(values.keys())}
        legend = plot['legend']
        plot['X'].append(x_value)
        plot['Y'].append([values[k] for k in legend])

        vis = getattr(self, 'vis', None)
        if vis is not None:
            vis.line(
                # One identical X column per legend entry.
                X=np.stack([np.array(plot['X'])] * len(legend), 1),
                Y=np.array(plot['Y']),
                opts={
                    'title': title,
                    'legend': legend,
                    'xlabel': 'Epoch',
                    'ylabel': ylabel
                },
                win=win
            )
        return plot

    ##
    def plot_current_errors(self, epoch, counter_ratio, errors):
        """Plot current errors.

        The sample is always recorded in ``self.plot_data``; it is drawn
        only when a Visdom client is available.

        Args:
            epoch (int): Current epoch
            counter_ratio (float): Ratio to plot the range between two epoch.
            errors (OrderedDict): Error for the current epoch.
        """
        self.plot_data = self._update_line_plot(
            getattr(self, 'plot_data', None),
            epoch + counter_ratio,
            errors,
            self.name + ' loss over time',
            'Loss',
            4
        )

    ##
    def plot_performance(self, epoch, counter_ratio, performance):
        """Plot performance.

        The sample is always recorded in ``self.plot_res``; it is drawn
        only when a Visdom client is available.

        Args:
            epoch (int): Current epoch
            counter_ratio (float): Ratio to plot the range between two epoch.
            performance (OrderedDict): Performance for the current epoch.
        """
        self.plot_res = self._update_line_plot(
            getattr(self, 'plot_res', None),
            epoch + counter_ratio,
            performance,
            # Leading space separates the experiment name from the title text.
            self.name + ' Performance Metrics',
            'Stats',
            5
        )

    ##
    def print_current_errors(self, epoch, errors):
        """Print current errors and append them to the log file.

        Args:
            epoch (int): Current epoch.
            errors (OrderedDict): Error for the current epoch.
        """
        message = ' Loss: [%d/%d] ' % (epoch, self.opt.niter)
        for key, val in errors.items():
            message += '%s: %.3f ' % (key, val)

        print(message)
        with open(self.log_name, "a") as log_file:
            log_file.write('%s\n' % message)

    ##
    def print_current_performance(self, performance, best):
        """Print current performance results and append them to the log file.

        Args:
            performance ([OrderedDict]): Performance of the model
            best ([float]): Best value so far of the metric named by
                ``self.opt.metric``.
        """
        message = ' '
        for key, val in performance.items():
            message += '%s: %.3f ' % (key, val)
        message += 'max ' + self.opt.metric + ': %.3f' % best

        print(message)
        with open(self.log_name, "a") as log_file:
            log_file.write('%s\n' % message)

    ##
    def display_current_images(self, reals, fakes, fixed):
        """Display current images in Visdom windows 1-3.

        Does nothing when no Visdom client is available.

        Args:
            reals ([FloatTensor]): Real Image
            fakes ([FloatTensor]): Fake Image
            fixed ([FloatTensor]): Fixed Fake Image
        """
        vis = getattr(self, 'vis', None)
        if vis is None:
            return

        reals = self.normalize(reals.cpu().numpy())
        fakes = self.normalize(fakes.cpu().numpy())
        fixed = self.normalize(fixed.cpu().numpy())

        vis.images(reals, win=1, opts={'title': 'Reals'})
        vis.images(fakes, win=2, opts={'title': 'Fakes'})
        vis.images(fixed, win=3, opts={'title': 'Fixed'})

    ##
    def save_current_images(self, epoch, reals, fakes, fixed):
        """Save images for epoch i into the training image directory.

        ``reals.png`` and ``fakes.png`` are overwritten on every call; the
        fixed fakes are written to a file numbered ``epoch + 1``.

        Args:
            epoch ([int]) : Current epoch
            reals ([FloatTensor]): Real Image
            fakes ([FloatTensor]): Fake Image
            fixed ([FloatTensor]): Fixed Fake Image
        """
        vutils.save_image(reals, '%s/reals.png' % self.img_dir, normalize=True)
        vutils.save_image(fakes, '%s/fakes.png' % self.img_dir, normalize=True)
        vutils.save_image(fixed, '%s/fixed_fakes_%03d.png' % (self.img_dir, epoch+1), normalize=True)