# Copyright 2022 - Valeo Comfort and Driving Assistance
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# The code in this file is adapted from DINO: https://github.com/facebookresearch/dino

from pathlib import Path
import argparse

from torchvision import transforms
import torch
import torch.distributed
import torch.utils.data
from torchvision.transforms import InterpolationMode

from src.swav.logger import create_logger
import src.sfrik.dataset
import src.vicreg.resnet
from src.vicreg.utils import MetricLogger
import src.vicreg.distributed as dist


def get_arguments():
    """Build the command-line parser of the feature-extraction script.

    Returns:
        argparse.ArgumentParser: parser exposing the data, model, output,
        runtime and distributed options read by the ``__main__`` section.
    """
    parser = argparse.ArgumentParser(
        description="Given a fixed backbone, extract features from ImageNet or STL10."
    )

    # Data
    parser.add_argument("--dataset", type=str, choices=["ImageNet", "STL10"])
    parser.add_argument("--data-dir", type=Path, help="path to dataset")
    parser.add_argument("--subset", type=int, default=-1,
                        help="Take a fix number of images per class (example 260) "
                             "to construct the training set.")
    parser.add_argument("--val-dataset", choices=["train", "val"],
                        help="Choice of the test dataset. "
                             "Choose 'val' for extracting features of the usual ImageNet validation set. "
                             "Choose 'train' for extracting features from a subset of the ImageNet train set.")
    parser.add_argument("--val-subset", type=int,
                        help="Size of validation set when setting '--val-dataset train'. "
                             "Take a fix number of images per class.")

    # Model
    parser.add_argument("--arch", type=str)
    parser.add_argument("--pretrained", type=Path, help="path to pretrained model")

    # Feature extractor
    parser.add_argument("--batch-size", default=256, type=int, metavar="N",
                        help="mini-batch size")
    parser.add_argument("--exp-dir", type=Path, default="./exp", metavar="DIR",
                        help="features are saved in this directory")

    # Running
    parser.add_argument("--workers", default=8, type=int, metavar="N",
                        help="number of data loader workers")
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')

    # Distributed
    parser.add_argument('--jean-zay', action="store_true",
                        help="set True if running on Jean Zay to use idr_torch package for distributed training")
    parser.add_argument('--world-size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist-url', default='env://',
                        help='url used to set up distributed training')

    return parser


@torch.no_grad()
def extract_features(model, data_loader):
    """Run ``model`` over ``data_loader`` and gather all features on rank 0.

    Code from DINO: https://github.com/facebookresearch/dino

    Every process forwards its own batches, then the sample indexes and the
    features of all processes are exchanged with ``all_gather`` and rank 0
    copies them into a CPU matrix at the rows given by the indexes.

    Args:
        model: callable mapping a batch of samples to a feature tensor. The
            output is indexed with ``size(0)`` and ``size(1)``, so it is
            expected to be 2-D.
        data_loader: iterable yielding ``(samples, index)`` pairs; ``index``
            gives the row of each sample in the output matrix.
            NOTE(review): ``all_gather`` receives into buffers sized from the
            local batch, so every process presumably has to see batches of
            the same size at each step - confirm against the sampler used.

    Returns:
        On rank 0, a CPU tensor with ``len(data_loader.dataset)`` rows and
        ``feats.shape[-1]`` columns. ``None`` on every other rank (and on
        rank 0 if the loader is empty).
    """
    metric_logger = MetricLogger(delimiter="  ")
    features = None
    for samples, index in metric_logger.log_every(data_loader, 10):
        samples = samples.cuda(non_blocking=True)
        index = index.cuda(non_blocking=True)
        feats = model(samples).clone()

        # init storage feature matrix (lazily: the feature dimension is only
        # known once the first batch has been forwarded)
        if dist.get_rank() == 0 and features is None:
            features = torch.zeros(len(data_loader.dataset), feats.shape[-1])
            print(f"Storing features into tensor of shape {features.shape}")

        # get indexes from all processes
        y_all = torch.empty(dist.get_world_size(), index.size(0),
                            dtype=index.dtype, device=index.device)
        # unbind() returns views, so all_gather fills y_all through y_l
        y_l = list(y_all.unbind(0))
        y_all_reduce = torch.distributed.all_gather(y_l, index, async_op=True)
        y_all_reduce.wait()
        index_all = torch.cat(y_l)

        # share features between processes
        feats_all = torch.empty(
            dist.get_world_size(),
            feats.size(0),
            feats.size(1),
            dtype=feats.dtype,
            device=feats.device,
        )
        output_l = list(feats_all.unbind(0))
        output_all_reduce = torch.distributed.all_gather(output_l, feats, async_op=True)
        output_all_reduce.wait()

        # update storage feature matrix (only rank 0 owns it)
        if dist.get_rank() == 0:
            features.index_copy_(0, index_all.cpu(), torch.cat(output_l).cpu())
    return features


if __name__ == "__main__":
    parser = get_arguments()
    args = parser.parse_args()

    # Set up distributed mode
    torch.backends.cudnn.benchmark = True
    dist.init_distributed_mode(args)
    gpu = torch.device(args.device)

    # Save dir and logger
    # NOTE(review): only rank 0 creates the directory while every rank builds
    # a logger from a path inside it - confirm create_logger copes with that.
    if args.rank == 0:
        args.exp_dir.mkdir(parents=True, exist_ok=True)
    logger = create_logger(args.exp_dir / "train.log", rank=args.rank)
    logger.info("============ Initialized logger ============")
    logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
    logger.info("The experiment directory is %s\n" % args.exp_dir)

    # Backbone
    backbone, _ = src.vicreg.resnet.__dict__[args.arch](zero_init_residual=True)
    logger.info(f"Load pretrained weights at: {args.pretrained}")
    state_dict = torch.load(args.pretrained, map_location="cpu")
    missing_keys, unexpected_keys = backbone.load_state_dict(state_dict, strict=False)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if missing_keys or unexpected_keys:
        raise RuntimeError(
            f"Checkpoint {args.pretrained} does not match architecture {args.arch}: "
            f"missing keys {list(missing_keys)}, unexpected keys {list(unexpected_keys)}"
        )
    logger.info("Extracting features from scratch")

    backbone = backbone.cuda(gpu)
    backbone = torch.nn.parallel.DistributedDataParallel(backbone, device_ids=[gpu])
    backbone.eval()

    # Dataset
    if args.dataset == "ImageNet":
        transform = transforms.Compose([
            transforms.Resize(256, interpolation=InterpolationMode.BICUBIC),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        traindir = args.data_dir / "train"
        train_dataset = src.sfrik.dataset.ReturnIndexDatasetSubset(
            traindir, subset=args.subset, transform=transform
        )
        train_labels = torch.tensor([s[-1] for s in train_dataset.samples]).long()

        if args.val_dataset == "val":
            valdir = args.data_dir / "val"
            val_dataset = src.sfrik.dataset.ReturnIndexDataset(valdir, transform=transform)
        elif args.val_dataset == "train":
            valdir = args.data_dir / "train"
            # NOTE(review): validation images start at `subset + 1`; with the
            # default `--subset -1` that is 0 - confirm the intended offset
            # against ReturnIndexDatasetSubset.
            val_dataset = src.sfrik.dataset.ReturnIndexDatasetSubset(
                valdir, start=args.subset + 1, subset=args.val_subset, transform=transform
            )
        else:
            raise NotImplementedError
        val_labels = torch.tensor([s[-1] for s in val_dataset.samples]).long()

    elif args.dataset == "STL10":
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.441, 0.428, 0.387], std=[0.268, 0.261, 0.269]
            ),
        ])
        train_dataset = src.sfrik.dataset.ReturnIndexStlDataset(
            args.data_dir, split="train", transform=transform, download=True
        )
        val_dataset = src.sfrik.dataset.ReturnIndexStlDataset(
            args.data_dir, split="test", transform=transform, download=True
        )
        train_labels = torch.tensor(train_dataset.labels).long()
        val_labels = torch.tensor(val_dataset.labels).long()

    else:
        raise NotImplementedError

    print("Size of train dataset:", len(train_dataset))
    print("Size of test dataset:", len(val_dataset))

    # Data loader
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    # Shard the validation set as well: without a sampler every process would
    # forward the whole set and all_gather would exchange identical copies.
    # Samples are written by index, so the resulting matrix is the same.
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False)
    kwargs = dict(
        batch_size=args.batch_size // args.world_size,  # per-process batch size
        num_workers=args.workers,
        pin_memory=True,
        drop_last=False,
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset, sampler=train_sampler, **kwargs
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset, sampler=val_sampler, **kwargs
    )

    # Feature extraction
    logger.info("Extracting features for train set...")
    train_features = extract_features(backbone, train_loader)
    logger.info("Extracting features for val set...")
    val_features = extract_features(backbone, val_loader)

    # Save extracted features (extract_features only returns them on rank 0)
    if args.rank == 0:
        torch.save(train_features, args.exp_dir / "train_features.pth")
        torch.save(val_features, args.exp_dir / "val_features.pth")
        torch.save(train_labels, args.exp_dir / "train_labels.pth")
        torch.save(val_labels, args.exp_dir / "val_labels.pth")