# Source: https://github.com/facebookresearch/pythia
# Tip revision: 5ad666f417684f8f1280029ac46e19dc8b5ddb0a
# authored by Xinlei Chen on 08 March 2019, 04:39:38 UTC
# Merge pull request #39 from HuaizhengZhang/patch-1
# File: extract_resnet152_feat.py
from __future__ import absolute_import, division, print_function, \
    unicode_literals
import argparse
from glob import glob
import numpy as np
import torch
import torchvision.models as models
import torch.nn as nn
from torch.autograd import Variable
from PIL import Image
import torchvision.transforms as transforms
import os


# Input images are resized to 448x448 and normalized with the standard
# ImageNet channel statistics that torchvision's pretrained models expect.
TARGET_IMAGE_SIZE = [448, 448]
CHANNEL_MEAN = [0.485, 0.456, 0.406]
CHANNEL_STD = [0.229, 0.224, 0.225]
data_transforms = transforms.Compose([transforms.Resize(TARGET_IMAGE_SIZE),
                                      transforms.ToTensor(),
                                      transforms.Normalize(CHANNEL_MEAN,
                                                           CHANNEL_STD)])

# Run on GPU when available; the model and inputs are moved accordingly below.
use_cuda = torch.cuda.is_available()

# NOTE feat path "https://download.pytorch.org/models/resnet152-b121ed2d.pth"
# Pretrained ImageNet ResNet-152, loaded at import time (first run downloads
# the weights). eval() switches batch-norm/dropout to inference mode.
RESNET152_MODEL = models.resnet152(pretrained=True)
RESNET152_MODEL.eval()

if use_cuda:
    RESNET152_MODEL = RESNET152_MODEL.cuda()


class ResNet152FeatModule(nn.Module):
    """Truncated pretrained ResNet-152 that emits the last conv feature map.

    The final two children of the torchvision model (global average pool
    and the fully-connected classifier) are dropped, so the forward pass
    stops at the last convolutional block.
    """

    def __init__(self):
        super(ResNet152FeatModule, self).__init__()
        # Keep everything except the avgpool + fc head.
        trunk = list(RESNET152_MODEL.children())[:-2]
        self.feature_module = nn.Sequential(*trunk)

    def forward(self, x):
        """Return the convolutional feature map for a batch of images."""
        return self.feature_module(x)


# Module-level singleton feature extractor used by extract_image_feat.
_resnet_module = ResNet152FeatModule()
if use_cuda:
    _resnet_module = _resnet_module.cuda()


def extract_image_feat(img_file):
    """Extract the final ResNet-152 conv feature map for a single image.

    Args:
        img_file: path to an image file readable by PIL.

    Returns:
        A float tensor of shape (1, C, H, W) — the feature map of the
        truncated ResNet-152 (on GPU when available).
    """
    img = Image.open(img_file).convert('RGB')
    img_transform = data_transforms(img)
    # make sure grey scale image is processed correctly
    # (convert('RGB') should already give 3 channels; kept as a safeguard)
    if img_transform.shape[0] == 1:
        img_transform = img_transform.expand(3, -1, -1)
    img_batch = img_transform.unsqueeze(0)
    if use_cuda:
        img_batch = img_batch.cuda()

    # Inference only: disable autograd so no graph/activations are retained.
    # (torch.autograd.Variable is deprecated; plain tensors suffice.)
    with torch.no_grad():
        img_feat = _resnet_module(img_batch)
    return img_feat


def get_image_id(image_name):
    """Parse the numeric image id from a COCO-style file name.

    e.g. "COCO_train2014_000000123456.jpg" -> 123456.
    """
    stem = image_name.split(".")[0]
    last_token = stem.split("_")[-1]
    return int(last_token)


def extract_dataset_pool5(image_dir, save_dir, total_group,
                          group_id, ext_filter):
    """Extract and save pool5 features for every image in ``image_dir``.

    Work is sharded across workers: this worker only processes images whose
    id satisfies ``image_id % total_group == group_id``, so several
    processes can run concurrently over the same directory. A ``.lock``
    directory next to each output file marks in-progress extractions.

    Args:
        image_dir: directory containing the input images.
        save_dir: directory where ``<image stem>.npy`` files are written.
        total_group: total number of parallel workers.
        group_id: index of this worker in ``[0, total_group)``.
        ext_filter: image file extension without the dot, e.g. "jpg".
    """
    image_list = glob(image_dir + '/*.' + ext_filter)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    for n_im, impath in enumerate(image_list):
        if (n_im + 1) % 100 == 0:
            print('processing %d / %d' % (n_im + 1, len(image_list)))
        image_name = os.path.basename(impath)
        image_id = get_image_id(image_name)
        if image_id % total_group != group_id:
            continue

        # Replace only the trailing extension; a plain str.replace would
        # also rewrite the extension substring anywhere in the file name.
        feat_name = os.path.splitext(image_name)[0] + '.npy'
        save_path = os.path.join(save_dir, feat_name)
        tmp_lock = save_path + ".lock"

        # Skip images another worker already finished (.npy exists, no lock).
        if os.path.exists(save_path) and not os.path.exists(tmp_lock):
            continue
        if not os.path.exists(tmp_lock):
            try:
                os.makedirs(tmp_lock)
            except OSError:
                # Another worker created the lock between the existence
                # check and makedirs; keep the original best-effort
                # behavior and proceed anyway.
                pass

        try:
            # NCHW -> NHWC before saving.
            pool5_val = extract_image_feat(impath).permute(0, 2, 3, 1)
        except Exception:
            # Best-effort: report and move on, but never swallow
            # SystemExit/KeyboardInterrupt (as the old bare `except:` did).
            print("error for " + image_name)
            continue

        feat = pool5_val.data.cpu().numpy()
        np.save(save_path, feat)
        os.rmdir(tmp_lock)


if __name__ == '__main__':
    # Command-line entry point; --total_group/--group_id shard the work
    # across parallel workers over the same image directory.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--total_group', type=int, default=1)
    arg_parser.add_argument('--group_id', type=int, default=0)
    arg_parser.add_argument("--data_dir", type=str, required=True)
    arg_parser.add_argument("--out_dir", type=str, required=True)
    arg_parser.add_argument("--image_ext", type=str, default="jpg")

    cli_args = arg_parser.parse_args()

    extract_dataset_pool5(cli_args.data_dir, cli_args.out_dir,
                          cli_args.total_group, cli_args.group_id,
                          cli_args.image_ext)
# back to top