# Source: https://github.com/facebookresearch/c3dpo_nrsfm
# Raw File
# Tip revision: aa558fd0cc10a704706a6c9704b221f7a42f5f80 authored by dnovotny on 19 October 2019, 16:09:30 UTC
# removed todo
# Tip revision: aa558fd
# keypoints_dataset.py
"""
Copyright (c) Facebook, Inc. and its affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

import os
import json
import copy

import numpy as np
import torch
import pickle

from torch.utils import data
from tools.utils import NumpySeedFix, auto_init_args


class KeypointsDataset(data.Dataset):
    """
    This is a generalized class suitable
    for storing object keypoint annotations

    The input file (`jsonfile`, either .json or .pkl) needs to be a list of
    dictionaries (one dictionary per pose annotation), or a dictionary
    with a "dataset" key whose "data" entry is such a list.
    Each annotation is of the form:

    {
        # REQUIRED FIELDS #
        "kp_loc" : 2 x N list of keypoints
        "kp_vis" : 1 x N list of 1/0 boolean indicators
        # OPTIONAL FIELDS #
        "image_path": image path; joined onto image_root in __getitem__
        "kp_loc_3d": 3 x N list of 3D GT keypoint locations in camera coords
        "class_mask": class indicator list; annotations with identical
                      masks are grouped into one class
        "kp_defined": converted to a float tensor when present
        "p3d_info": dictionary with a "bbox" entry (x, y, w, h); visible
                    keypoints outside of the padded box become invisible
    }

    """

    def __init__(self,
                 jsonfile=None,
                 train=True,
                 limit_to=0,
                 rand_sample=0,
                 image_root=None,
                 refresh_db=False,
                 ):
        """
        Load the annotation file and build the optional per-class index.

        Args:
            jsonfile: path to the .json / .pkl annotation file
            train: only used in the log message printed while loading
            limit_to: if > 0, keep a random subset of at most
                this many annotations
            rand_sample: if > 0, the dataset reports this length and
                every __getitem__ call returns a randomly drawn annotation
            image_root: directory that "image_path" entries are joined onto
            refresh_db: not read anywhere inside this class
        """

        # NOTE(review): auto_init_args is defined in tools.utils; the methods
        # below read self.jsonfile, self.limit_to, ... so it presumably stores
        # the constructor arguments as attributes. Keep it as the very first
        # statement in case it inspects the caller's locals -- confirm.
        auto_init_args(self)

        self.load_db_file()

        # an empty annotation list has no first entry to probe for classes
        has_classes = len(self.db) > 0 and 'class_mask' in self.db[0]
        if has_classes:
            self.class_db = self.get_class_db()
        else:
            self.class_db = None

    def load_db_file(self):
        """
        Read self.jsonfile into self.db and apply the `limit_to` restriction.

        Raises:
            ValueError: if the file extension is neither 'json' nor 'pkl'
        """

        print("loading data from %s" % self.jsonfile)

        ext = self.jsonfile.split('.')[-1]
        if ext == 'json':
            with open(self.jsonfile, 'r') as data_file:
                db = json.load(data_file)
        elif ext == 'pkl':
            # SECURITY: unpickling can execute arbitrary code;
            # only load .pkl files coming from a trusted source
            with open(self.jsonfile, 'rb') as data_file:
                db = pickle.load(data_file)
        else:
            raise ValueError('bad extension %s' % ext)

        # the gdrive-downloaded jsons have a slightly different format:
        # (type check first so that a plain list of annotations
        #  is not scanned entry-by-entry for the string 'dataset')
        if isinstance(db, dict) and 'dataset' in db:
            db = db['data']

        print("data train=%d , n frames = %d" % (self.train, len(db)))

        self.db = db

        self.restrict_images()

    def get_class_db(self):
        """
        Group the annotation indices by their "class_mask".

        Returns:
            A list with one list of annotation indices per unique
            class mask, ordered like the rows returned by
            np.unique(masks, axis=0). Empty if self.db is empty.
        """
        print('parsing class db ...')

        if len(self.db) == 0:
            # np.stack would raise on an empty sequence
            return []

        masks = np.stack([np.array(e['class_mask']) for e in self.db])
        unq_masks = np.unique(masks, axis=0)

        # pre-create the buckets so that the output order follows unq_masks
        class_db = {tuple(m.tolist()): [] for m in unq_masks}
        for ei, e in enumerate(self.db):
            class_db[tuple(e['class_mask'])].append(ei)
        class_db = list(class_db.values())

        for eis in class_db:  # sanity check
            # every member of a group has to carry the same mask
            cls_array = np.stack([self.db[ei]['class_mask'] for ei in eis])
            assert ((cls_array - cls_array[0:1, :])**2).sum() <= 1e-6

        return class_db

    def restrict_images(self):
        """
        Keep a random subset of at most self.limit_to annotations
        (no-op if self.limit_to <= 0).
        """

        if self.limit_to > 0:
            tgtnum = min(self.limit_to, len(self.db))
            # NOTE(review): NumpySeedFix comes from tools.utils; presumably
            # it fixes the numpy seed so the subset is reproducible -- confirm
            with NumpySeedFix():
                prm = np.random.permutation(
                    len(self.db))[0:tgtnum]
            print("limitting dataset to %d samples" % tgtnum)
            self.db = [self.db[i] for i in prm]

    def __len__(self):
        """
        Return self.rand_sample if it is positive,
        otherwise the number of annotations.
        """
        if self.rand_sample > 0:
            return self.rand_sample
        return len(self.db)

    def __getitem__(self, index):
        """
        Return a copy of one annotation with tensor-converted fields.

        Args:
            index: annotation index; ignored (replaced with a random
                index) if self.rand_sample > 0

        Returns:
            A dictionary copied from self.db[index] where:
              - 'kp_loc', 'kp_vis', 'kp_loc_3d', 'class_mask', 'kp_defined'
                are torch float tensors (when present)
              - 'image_path' is prefixed with self.image_root, or set
                to '<NONE>' if either of the two is missing
              - 'kp_vis' is zeroed for keypoints outside of the padded
                'p3d_info' bounding box (when present)
              - 'kp_loc' is multiplied with the final 'kp_vis'
        """

        if self.rand_sample > 0:
            if self.class_db is not None:
                # in case we have classes, sample first rand class
                # and then image index
                cls_index = np.random.randint(len(self.class_db))
                index = np.random.choice(self.class_db[cls_index])
            else:
                index = np.random.randint(len(self.db))

        # deep copy so that the stored annotation is never modified
        entry = copy.deepcopy(self.db[index])

        # convert to torch Tensors where possible
        for fld in ('kp_loc', 'kp_vis', 'kp_loc_3d',
                    'class_mask', 'kp_defined'):
            if fld in entry:
                entry[fld] = torch.FloatTensor(entry[fld])

        if self.image_root is not None and 'image_path' in entry:
            entry['image_path'] = os.path.join(
                self.image_root, entry['image_path'])
        else:
            entry['image_path'] = '<NONE>'

        if 'p3d_info' in entry:  # filter the kp out of bbox
            bbox = torch.FloatTensor(entry['p3d_info']['bbox'])
            bbox_vis, bbox_err = bbox_kp_visibility(
                bbox, entry['kp_loc'], entry['kp_vis'])
            entry['kp_vis'] = entry['kp_vis'] * bbox_vis.float()

        # mask out invisible
        entry['kp_loc'] = entry['kp_loc'] * entry['kp_vis'][None]

        return entry


def bbox_kp_visibility(bbox, keypoints, vis, ctx=0.1):
    """
    Find the visible keypoints that lie inside a padded bounding box.

    Args:
        bbox: 4 values (x, y, w, h); the box spans [x, x+w] x [y, y+h]
        keypoints: tensor whose row 0 holds the x and
            row 1 the y coordinates of the keypoints
        vis: per-keypoint visibility; a keypoint counts as
            visible where vis == 1
        ctx: padding added on every side of the box, expressed as
            a fraction of the box width / height (default 0.1)

    Returns:
        in_box: boolean mask, True for keypoints that are visible
            and inside of the padded box
        err: 2 values; the largest horizontal and the largest vertical
            distance by which a keypoint (weighted by `vis`) falls
            outside of the padded box, 0 if none does
    """
    bx, by, bw, bh = bbox
    x = keypoints[0]
    y = keypoints[1]

    # borders of the padded box
    x_min = bx-ctx*bw
    x_max = bx+bw*(1+ctx)
    y_min = by-ctx*bh
    y_max = by+bh*(1+ctx)

    in_box = (x >= x_min) & (x <= x_max) & (y >= y_min) & (y <= y_max)
    in_box = in_box & (vis == 1)

    # signed overshoot w.r.t. each of the 4 borders; positive = outside
    err = torch.stack([x_min-x,
                       x-x_max,
                       y_min-y,
                       y-y_max])
    err = torch.relu(err) * vis[None].float()
    # worst overshoot over all keypoints, separately for x and y
    err = torch.stack((torch.max(err[0], err[1]),
                       torch.max(err[2], err[3]))).max(dim=1)[0]

    return in_box, err
# back to top