# Reference: https://hal.archives-ouvertes.fr/hal-03737572
# File: extract_features.py
# Copyright 2022 - Valeo Comfort and Driving Assistance
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# The code in this file is adapted from DINO: https://github.com/facebookresearch/dino


from pathlib import Path
import argparse

from torchvision import transforms
import torch
import torch.distributed
import torch.utils.data
from torchvision.transforms import InterpolationMode

from src.swav.logger import create_logger
import src.vicreg.resnet
from src.vicreg.utils import MetricLogger
import src.vicreg.distributed as dist


def get_arguments():
    """Build the command-line parser for feature extraction.

    Options are grouped into data, model, feature-extractor, runtime and
    distributed settings.

    Returns:
        argparse.ArgumentParser: the configured parser; call ``parse_args()`` on it.
    """
    parser = argparse.ArgumentParser(
        description="Given a fixed backbone, extract features from ImageNet or STL10."
    )

    # Data
    # NOTE: adjacent string literals are concatenated verbatim, so each fragment
    # below ends with a space to keep the rendered help text readable.
    parser.add_argument("--dataset", type=str, choices=["ImageNet", "STL10"])
    parser.add_argument("--data-dir", type=Path, help="path to dataset")
    parser.add_argument("--subset", type=int, default=-1,
                        help="Take a fixed number of images per class (example 260) "
                             "to construct the training set.")
    parser.add_argument("--val-dataset", choices=["train", "val"],
                        help="Choice of the test dataset. "
                             "Choose 'val' for extracting features of the usual ImageNet validation set. "
                             "Choose 'train' for extracting features from a subset of the ImageNet train set.")
    parser.add_argument("--val-subset", type=int,
                        help="Size of validation set when setting '--val-dataset train'. "
                             "Take a fixed number of images per class.")

    # Model
    parser.add_argument("--arch", type=str)
    parser.add_argument("--pretrained", type=Path, help="path to pretrained model")

    # Feature extractor
    parser.add_argument("--batch-size", default=256, type=int, metavar="N", help="mini-batch size")
    parser.add_argument("--exp-dir", type=Path, default="./exp", metavar="DIR",
                        help="features are saved in this directory")

    # Running
    parser.add_argument("--workers", default=8, type=int, metavar="N", help="number of data loader workers")
    parser.add_argument('--device', default='cuda', help='device to use for training / testing')

    # Distributed
    parser.add_argument('--jean-zay', action="store_true",
                        help="set True if running on Jean Zay to use idr_torch package for distributed training")
    parser.add_argument('--world-size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist-url', default='env://',
                        help='url used to set up distributed training')

    return parser


@torch.no_grad()
def extract_features(model, data_loader):
    """Compute features for a whole dataset and collect them on rank 0.

    Adapted from DINO: https://github.com/facebookresearch/dino

    Every process runs ``model`` on its own batches. The features and the dataset
    indexes of each batch are then all-gathered so that rank 0 can write every
    feature row at its dataset index. Because rows are placed by index, the order
    in which samples are visited does not matter. Duplicate indexes (e.g. padding
    added by ``DistributedSampler``) simply overwrite the same row.

    Args:
        model: network applied to each batch of samples. Its output is used as a
            2-D ``(batch, feature_dim)`` tensor.
        data_loader: loader yielding ``(samples, index)`` pairs, where ``index``
            gives each sample's position in ``data_loader.dataset``. All ranks
            must yield the same number of batches with matching batch sizes,
            because ``all_gather`` requires equally-sized tensors.

    Returns:
        On rank 0, a CPU tensor of shape ``(len(data_loader.dataset), feature_dim)``
        with the dtype of the model output, or ``None`` if the loader yields no
        batches. On every other rank, ``None``.
    """
    metric_logger = MetricLogger(delimiter="  ")
    world_size = dist.get_world_size()
    is_main_process = dist.get_rank() == 0
    features = None
    for samples, index in metric_logger.log_every(data_loader, 10):
        samples = samples.cuda(non_blocking=True)
        index = index.cuda(non_blocking=True)
        feats = model(samples).clone()

        # Allocate the storage matrix lazily on rank 0, once the feature size is known.
        # Matching the model's dtype keeps index_copy_ below valid for non-fp32 outputs.
        if is_main_process and features is None:
            features = torch.zeros(len(data_loader.dataset), feats.shape[-1], dtype=feats.dtype)
            print(f"Storing features into tensor of shape {features.shape}")

        # Gather the dataset indexes of this step's batch from all processes.
        index_list = list(
            torch.empty(world_size, index.size(0), dtype=index.dtype, device=index.device).unbind(0)
        )
        torch.distributed.all_gather(index_list, index)

        # Gather the corresponding features; the rank order matches index_list.
        feats_list = list(
            torch.empty(
                world_size,
                feats.size(0),
                feats.size(1),
                dtype=feats.dtype,
                device=feats.device,
            ).unbind(0)
        )
        torch.distributed.all_gather(feats_list, feats)

        # Scatter the gathered rows into the storage matrix by dataset index.
        if is_main_process:
            features.index_copy_(0, torch.cat(index_list).cpu(), torch.cat(feats_list).cpu())
    return features


if __name__ == "__main__":
    # The dataset classes below live in `src.sfrik.dataset`. Importing
    # `src.vicreg.resnet` binds the `src` package but does not load the
    # `src.sfrik` subpackage, so import it explicitly.
    import src.sfrik.dataset

    parser = get_arguments()
    args = parser.parse_args()

    # Set up distributed mode
    torch.backends.cudnn.benchmark = True
    dist.init_distributed_mode(args)  # presumably sets args.rank, read below — confirm in src.vicreg.distributed
    gpu = torch.device(args.device)

    # Save dir and logger.
    # create_logger is called on every rank with a path inside exp_dir. Creating the
    # directory on all ranks (exist_ok=True tolerates concurrent creation) avoids
    # depending on rank 0 getting there first.
    args.exp_dir.mkdir(parents=True, exist_ok=True)
    logger = create_logger(args.exp_dir / "train.log", rank=args.rank)
    logger.info("============ Initialized logger ============")
    logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
    logger.info("The experiment directory is %s\n" % args.exp_dir)

    # Backbone
    backbone, _ = src.vicreg.resnet.__dict__[args.arch](zero_init_residual=True)
    logger.info(f"Load pretrained weights at: {args.pretrained}")
    state_dict = torch.load(args.pretrained, map_location="cpu")
    # Validate explicitly rather than with `assert`, which is stripped under `python -O`.
    missing_keys, unexpected_keys = backbone.load_state_dict(state_dict, strict=False)
    if missing_keys or unexpected_keys:
        raise RuntimeError(
            f"Pretrained weights do not match the backbone: "
            f"missing keys={missing_keys}, unexpected keys={unexpected_keys}"
        )
    logger.info("Extracting features from scratch")
    backbone = backbone.cuda(gpu)
    backbone = torch.nn.parallel.DistributedDataParallel(backbone, device_ids=[gpu])
    backbone.eval()

    # Dataset
    if args.dataset == "ImageNet":
        transform = transforms.Compose([
            transforms.Resize(256, interpolation=InterpolationMode.BICUBIC),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        traindir = args.data_dir / "train"
        train_dataset = src.sfrik.dataset.ReturnIndexDatasetSubset(traindir, subset=args.subset, transform=transform)
        train_labels = torch.tensor([s[-1] for s in train_dataset.samples]).long()

        if args.val_dataset == "val":
            valdir = args.data_dir / "val"
            val_dataset = src.sfrik.dataset.ReturnIndexDataset(valdir, transform=transform)
        elif args.val_dataset == "train":
            valdir = args.data_dir / "train"
            # NOTE(review): if `start` is an inclusive per-class offset, `subset + 1`
            # skips one image per class between train and val — confirm intent.
            val_dataset = src.sfrik.dataset.ReturnIndexDatasetSubset(valdir, start=args.subset + 1, subset=args.val_subset,
                                                                     transform=transform)
        else:
            raise NotImplementedError
        val_labels = torch.tensor([s[-1] for s in val_dataset.samples]).long()
    elif args.dataset == "STL10":
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.441, 0.428, 0.387], std=[0.268, 0.261, 0.269]
            ),
        ])
        train_dataset = src.sfrik.dataset.ReturnIndexStlDataset(args.data_dir, split="train", transform=transform,
                                                                download=True)
        val_dataset = src.sfrik.dataset.ReturnIndexStlDataset(args.data_dir, split="test", transform=transform,
                                                              download=True)
        train_labels = torch.tensor(train_dataset.labels).long()
        val_labels = torch.tensor(val_dataset.labels).long()
    else:
        raise NotImplementedError
    logger.info(f"Size of train dataset: {len(train_dataset)}")
    logger.info(f"Size of test dataset: {len(val_dataset)}")

    # Data loaders.
    # Both splits are sharded across processes. extract_features places each
    # gathered row by its dataset index, so neither the visiting order nor the
    # padding duplicates added by DistributedSampler affect the result. Without a
    # sampler, every rank would redundantly process the whole validation set.
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=False)
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False)
    kwargs = dict(
        # Guard against a zero per-process batch when batch_size < world_size.
        batch_size=max(1, args.batch_size // args.world_size),
        num_workers=args.workers,
        pin_memory=True,
        drop_last=False
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset, sampler=train_sampler, **kwargs
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset, sampler=val_sampler, **kwargs
    )

    # Feature extraction
    logger.info("Extracting features for train set...")
    train_features = extract_features(backbone, train_loader)
    logger.info("Extracting features for val set...")
    val_features = extract_features(backbone, val_loader)

    # Save extracted features (only rank 0 holds the gathered tensors).
    if args.rank == 0:
        torch.save(train_features, args.exp_dir / "train_features.pth")
        torch.save(val_features, args.exp_dir / "val_features.pth")
        torch.save(train_labels, args.exp_dir / "train_labels.pth")
        torch.save(val_labels, args.exp_dir / "val_labels.pth")
back to top