# Source: https://github.com/facebookresearch/pythia (raw file view)
# Tip revision: 5ad666f417684f8f1280029ac46e19dc8b5ddb0a, authored by Xinlei Chen
# on 08 March 2019, 04:39:38 UTC
# Merge pull request #39 from HuaizhengZhang/patch-1
# Tip revision: 5ad666f
# File: visualize_bbox.py
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import argparse
import cv2
import os
import csv
import base64
import sys

import matplotlib
# Select the non-interactive Agg backend before pyplot is imported so that
# figures can be rendered and saved to files without a display.
matplotlib.use('Agg')

import matplotlib.pyplot as plt
plt.switch_backend('agg')

# Rows carry base64-encoded fields (e.g. 'boxes') that can exceed the csv
# module's default field size limit, so raise the limit as far as possible.
try:
    csv.field_size_limit(sys.maxsize)
except OverflowError:
    # field_size_limit takes a C long, which is 32-bit on some platforms
    # (notably Windows); sys.maxsize does not fit there.
    csv.field_size_limit(2 ** 31 - 1)

# Column order of the tab-separated file parsed by the __main__ block.
FIELDNAMES = ['image_id', 'image_w', 'image_h',
              'num_boxes', 'boxes', 'features', 'object']


def parse_args():
    """Parse the command-line options of the bounding-box visualizer.

    Every option is a required string:
    ``--img_dir``, ``--image_name_file``, ``--out_dir`` and ``--csv_file``.

    Returns:
        argparse.Namespace holding the four parsed options.
    """
    arg_parser = argparse.ArgumentParser()
    required_options = (
        "--img_dir",
        "--image_name_file",
        "--out_dir",
        "--csv_file",
    )
    for option in required_options:
        arg_parser.add_argument(option, type=str, required=True)
    return arg_parser.parse_args()


def vis_detections(im, class_name, bboxes, dets, thresh=0.5,
                   show_labels=False):
    """Draw the detections whose score is at least ``thresh`` on a new figure.

    The figure is left as pyplot's current figure: it is neither saved nor
    closed here, so the caller is responsible for saving/closing it. Nothing
    is drawn (and no figure is created) when no detection clears ``thresh``.

    Args:
        im: image array indexed as ``im[:, :, c]``; its channels are reversed
            before display (presumably OpenCV BGR -> RGB — confirm at caller).
        class_name: label used in the optional per-box caption and title.
        bboxes: unused; kept only for backward compatibility of the signature.
        dets: 2-D array whose columns 0-3 are (x1, y1, x2, y2) and whose last
            column is the detection score.
        thresh: minimum score for a detection to be drawn.
        show_labels: when True, caption each box with ``class_name`` and its
            score and add a figure title. Defaults to False, which reproduces
            the original boxes-only output.
    """
    keep = np.where(dets[:, -1] >= thresh)[0]
    if len(keep) == 0:
        return

    im = im[:, :, (2, 1, 0)]
    _, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(im, aspect='equal')
    for i in keep:
        bbox = dets[i, :4]
        ax.add_patch(
            plt.Rectangle((bbox[0], bbox[1]),
                          bbox[2] - bbox[0],
                          bbox[3] - bbox[1], fill=False,
                          edgecolor='red', linewidth=3.5)
        )
        if show_labels:
            # The last column of dets is the score (same column thresholded above).
            score = dets[i, -1]
            ax.text(bbox[0], bbox[1] - 2,
                    '{:s} {:.3f}'.format(class_name, score),
                    bbox=dict(facecolor='blue', alpha=0.5),
                    fontsize=14, color='white')

    if show_labels:
        ax.set_title(('{} detections with '
                      'p({} | box) >= {:.1f}').format(class_name, class_name,
                                                      thresh),
                     fontsize=14)
    plt.axis('off')
    plt.tight_layout()
    plt.draw()


def plot_bboxes(im, im_file, bboxes, out_dir):
    """Draw every box in ``bboxes`` over ``im`` and save the figure to disk.

    The output is written to ``out_dir`` as ``<stem>_demo<ext>``, where
    ``<stem>``/``<ext>`` come from ``os.path.splitext(im_file)`` and ``<ext>``
    defaults to ``.jpg`` when ``im_file`` has no extension (callers may pass a
    bare image name such as ``COCO_val2014_000000000042``). The figure is
    always closed afterwards so repeated calls do not accumulate open figures.

    Args:
        im: image array as returned by ``cv2.imread`` (BGR channel order);
            channels are reversed to RGB for display.
        im_file: image name used to derive the output file name.
        bboxes: 2-D array with one row per box; columns 0-3 are
            (x1, y1, x2, y2).
        out_dir: existing directory the rendered image is written into.
    """
    im = im[:, :, (2, 1, 0)]
    fig, ax = plt.subplots(figsize=(12, 12))
    try:
        ax.imshow(im, aspect='equal')
        for x1, y1, x2, y2 in bboxes[:, :4]:
            ax.add_patch(
                plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False,
                              edgecolor='red', linewidth=2.5)
            )
        plt.axis('off')
        plt.tight_layout()
        plt.draw()

        # splitext only touches the real extension, and an extension-less
        # name still gets the intended "_demo.jpg" suffix instead of being
        # saved verbatim with savefig's default format.
        stem, ext = os.path.splitext(im_file)
        out_file = os.path.join(out_dir, stem + "_demo" + (ext or ".jpg"))
        fig.savefig(out_file)
    finally:
        # Release the figure; otherwise every call leaks one open figure.
        plt.close(fig)


if __name__ == '__main__':
    args = parse_args()

    # image_id (string, as stored in the TSV) -> array of boxes, one per row.
    image_bboxes = {}
    with open(args.csv_file, "r") as tsv_in_file:
        reader = csv.DictReader(tsv_in_file, delimiter='\t',
                                fieldnames=FIELDNAMES)
        for item in reader:
            num_boxes = int(item['num_boxes'])
            # 'boxes' holds base64-encoded raw float32 data, reshaped to
            # num_boxes rows.
            image_bboxes[item['image_id']] = np.frombuffer(
                base64.b64decode(item['boxes']),
                dtype=np.float32).reshape((num_boxes, -1))

    out_dir = args.out_dir
    img_dir = args.img_dir
    # savefig does not create missing directories; an empty out_dir means
    # the current directory and needs no creation.
    if out_dir and not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    with open(args.image_name_file, 'r') as f:
        # Ignore blank lines (e.g. a trailing newline) instead of crashing
        # on them when parsing the image id below.
        imgs = [line.strip() for line in f if line.strip()]

    for image_name in imgs:
        image_path = os.path.join(img_dir, image_name + ".jpg")
        im = cv2.imread(image_path)
        if im is None:
            # cv2.imread returns None instead of raising on unreadable files.
            sys.stderr.write(
                "Skipping {}: could not read image\n".format(image_path))
            continue
        # Assumes COCO-style names such as "COCO_val2014_000000000042" whose
        # third '_'-separated field is the zero-padded image id — confirm for
        # other datasets.
        image_id = str(int(image_name.split('_')[2]))
        bboxes = image_bboxes.get(image_id)
        if bboxes is None:
            sys.stderr.write(
                "Skipping {}: image id {} not found in {}\n".format(
                    image_name, image_id, args.csv_file))
            continue
        plot_bboxes(im, image_name, bboxes, out_dir)
# (GitHub page footer: "back to top")