# Copyright 2022 - Valeo Comfort and Driving Assistance
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# The code in this file is adapted from VICReg: https://github.com/facebookresearch/vicreg
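#
# Example launch command (illustrative sketch only: the launcher, architecture and
# most hyperparameter values below are placeholders; the SFRIK loss coefficients
# are the examples given in the argument help strings):
#
#   torchrun --nproc_per_node=4 main_ssl_methods.py \
#       --dataset ImageNet --data-dir /path/to/imagenet \
#       --augmentations VICReg --size 224 \
#       --ssl-method SFRIK --arch resnet50 --mlp 8192-8192-8192 \
#       --epochs 100 --batch-size 1024 --base-lr 0.3 \
#       --sfrik-sim-coeff 4000.0 --sfrik-mmd-coeff 1.0 --sfrik-weights 1.0-40.0-40.0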


from pathlib import Path
import argparse
import shutil
import time

import torch
from torch import nn
import torchvision.datasets

from src.swav.logger import create_logger, create_stats
from src.vicreg.optim import adjust_learning_rate, LARS
from src.vicreg.utils import AverageMeter
from src.vicreg.distributed import init_distributed_mode
from src.sfrik.dataset import ImageNetSubset
import src.sfrik.augmentations as aug
import src.sfrik.ssl as ssl


def get_arguments():
    parser = argparse.ArgumentParser(
        description="Self-supervised learning pretraining script for: SFRIK, SimCLR, AUH and VICReg",
    )

    # Data
    parser.add_argument("--dataset", type=str, choices=["ImageNet", "STL10"])
    parser.add_argument("--data-dir", type=Path, default="/path/to/imagenet_or_stl10", required=True,
                        help='Path to the ImageNet or STL10 dataset')
    parser.add_argument("--subset", default=-1, type=int, help="Take a fix number of images per class (example 260)")
    parser.add_argument("--augmentations", choices=["SimCLR", "VICReg", "AUH", "OBoW", "SwAV"],
                        help="choose image augmentations implemented in previous SSL methods")

    # Two-view image augmentation parameter (when setting '--augmentations' to 'SimCLR', 'VICReg' or 'AUH')
    parser.add_argument("--size", type=int, help="Size of resized cropped image")

    # OBoW multicrop image augmentation parameters (when setting '--augmentations OBoW')
    parser.add_argument("--num_img_crops", type=int, help="Number of image crops (OBoW aug, e.g. 2)")
    parser.add_argument("--image_crop_size", type=int, help="Size of these crops (OBoW aug, e.g. 160)")
    parser.add_argument("--num_img_patches", type=int, help="Number of image patches (OBoW aug, e.g. 5)")
    parser.add_argument("--img_patch_size", type=int, help="Size of image patch (OBoW aug, e.g. 96)")

    # SwAV multicrop image augmentation parameters (when setting '--augmentations SwAV')
    parser.add_argument("--nmb_crops", type=int, nargs="+", help="list of number of crops (SwAV aug, e.g. [2, 6])")
    parser.add_argument("--size_crops", type=int, nargs="+", help="crops resolutions (SwAV aug, e.g. [224, 96])")
    parser.add_argument("--min_scale_crops", type=float, nargs="+", help="min range (SwAV aug, e.g. [0.14, 0.05])")
    parser.add_argument("--max_scale_crops", type=float, nargs="+", help="max range (SwAV aug, e.g. [1., 0.14])")
    parser.add_argument("--general_indices", type=int, nargs="+", help="idx of global views (SwAV aug, e.g. [0, 1])")

    # Additional multicrop image augmentation parameter (when setting '--augmentations OBoW' or 'SwAV')
    parser.add_argument("--reg_small", type=str, choices=["True", "False"],
                        help="apply MMD loss on: large views + small views if True; large views only if False.")

    # Checkpoints
    parser.add_argument("--exp-dir", type=Path, default="./exp",
                        help='Path to the experiment folder, where all checkpoints will be stored')
    parser.add_argument("--stats-dir", type=Path, default="./exp",
                        help='Path to the results folder, where all the logs and stats will be stored.')
    parser.add_argument("--checkpoint-freq", type=int, default=100,
                        help='Save the model every [checkpoint-freq] epochs, in exp-dir/checkpoints directory.')

    # Model
    parser.add_argument("--arch", type=str, help='Architecture of the backbone encoder network')
    parser.add_argument("--mlp", help='Size and number of layers of the MLP expander head')

    # Optim
    parser.add_argument("--epochs", type=int, help='Number of epochs')
    parser.add_argument("--batch-size", type=int,
                        help='Effective batch size (per worker batch size is [batch-size] / world-size)')
    parser.add_argument("--base-lr", type=float,
                        help='Base learning rate, effective learning after warmup is [base-lr] * [batch-size] / 256')
    parser.add_argument("--wd", type=float, default=1e-6,
                        help='Weight decay')

    # SSL method
    parser.add_argument("--ssl-method", type=str, choices=["SFRIK", "SimCLR", "AUH", "VICReg"])

    # Loss for VICReg
    parser.add_argument("--vicreg-sim-coeff", type=float,
                        help='Invariance regularization loss coefficient (e.g. 25.0)')
    parser.add_argument("--vicreg-std-coeff", type=float,
                        help='Variance regularization loss coefficient (e.g. 25.0)')
    parser.add_argument("--vicreg-cov-coeff", type=float,
                        help='Covariance regularization loss coefficient (e.g. 1.0)')

    # Loss for AUH
    parser.add_argument("--auh-sim-coeff", type=float,
                        help='Invariance regularization loss coefficient (e.g. 3000.0)')
    parser.add_argument("--auh-unif-coeff", type=float,
                        help='Uniformity regularization loss coefficient (e.g. 1.0)')
    parser.add_argument("--auh-scale", type=float,
                        help='Scale of the RBF kernel for computing pairwise energy (e.g. 2.5)')

    # Loss for SimCLR
    parser.add_argument("--simclr-temp", type=float,
                        help='Temperature parameter for NCE loss (e.g. 0.15)')

    # Loss for SFRIK
    parser.add_argument("--sfrik-sim-coeff", type=float,
                        help='Invariance regularization loss coefficient (e.g. 4000.0)')
    parser.add_argument("--sfrik-mmd-coeff", type=float,
                        help='MMD loss regularization loss coefficient (e.g. 1.0)')
    parser.add_argument("--sfrik-weights", type=str,
                        help="Weights for kernel in the form {w_1}-{w_2}-{w_3} (e.g. 1.0-40.0-40.0)")

    # Running
    parser.add_argument("--num-workers", type=int, default=10)
    parser.add_argument('--device', default='cuda', help='device to use for training / testing')

    # Distributed
    parser.add_argument('--jean-zay', action="store_true",
                        help="set this flag when running on Jean Zay to use the idr_torch package for distributed training")
    parser.add_argument('--world-size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist-url', default='env://',
                        help='url used to set up distributed training')

    return parser


def exclude_bias_and_norm(p):
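    # 1-D parameters are exactly the biases and normalization weights/offsets; the
    # LARS optimizer below uses this filter to exempt them from weight decay and
    # layer-wise learning-rate adaptation.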
    return p.ndim == 1


if __name__ == "__main__":
    parser = get_arguments()
    args = parser.parse_args()

    torch.backends.cudnn.benchmark = True
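    # init_distributed_mode() is expected to set args.rank and args.world_size,
    # which are used below for rank-0-only I/O and for splitting the batch size.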
    init_distributed_mode(args)
    gpu = torch.device(args.device)

    # Save dir and logger
    if args.rank == 0:
        args.exp_dir.mkdir(parents=True, exist_ok=True)
        args.stats_dir.mkdir(parents=True, exist_ok=True)
        checkpoints_dir = args.exp_dir / "checkpoints"
        checkpoints_dir.mkdir(parents=True, exist_ok=True)
    logger = create_logger(args.stats_dir / "train.log", rank=args.rank)
    logger.info("============ Initialized logger ============")
    logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
    logger.info("The experiment checkpoints will be stored in %s\n" % args.exp_dir)
    logger.info("The experiment results (stats + logs) will be stored in %s\n" % args.stats_dir)
    train_stats = create_stats(args.stats_dir / "pretraining_stats", args, ["epoch", "loss"])

    # Augmentations
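    # Transform classes are resolved by name from src.sfrik.augmentations, e.g.
    # "ImageNetVICRegTrainTransform" for --dataset ImageNet --augmentations VICReg.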
    # Two views image augmentation
    if args.augmentations in ["VICReg", "SimCLR", "AUH"]:
        transforms = eval(f"aug.{args.dataset}{args.augmentations}TrainTransform")(args.size)
        args.multicrop = None
    # Multiview augmentation like in OBoW
    elif args.augmentations == "OBoW":
        transforms = eval(f"aug.{args.dataset}OBoWTrainTransform")(
            args.num_img_crops, args.image_crop_size, args.num_img_patches, args.img_patch_size
        )
        args.multicrop = "OBoW"
    # Multiview augmentation like in SwAV
    elif args.augmentations == "SwAV":
        transforms = eval(f"aug.{args.dataset}SwAVTrainTransform")(
            args.size_crops, args.nmb_crops, args.min_scale_crops, args.max_scale_crops
        )
        args.multicrop = "SwAV"
    else:
        raise NotImplementedError

    # Dataset
    if args.dataset == "ImageNet":
        assert args.augmentations in ["VICReg", "SimCLR", "OBoW", "SwAV"]
        dataset = ImageNetSubset(args.data_dir / "train", transforms, subset=args.subset)
    elif args.dataset == "STL10":
        assert args.augmentations in ["AUH", "OBoW", "SwAV"]
        dataset = torchvision.datasets.STL10(args.data_dir, split="unlabeled", transform=transforms, download=True)
    else:
        raise NotImplementedError
    logger.info(f"Size dataset: {len(dataset)}")

    # Dataloader
    sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=True)
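    # The effective batch size is split evenly across the distributed workers.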
    assert args.batch_size % args.world_size == 0
    per_device_batch_size = args.batch_size // args.world_size
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=per_device_batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
        sampler=sampler,
    )

    # Model
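    # The SSL method (backbone + expander head + loss) is resolved by name from
    # src.sfrik.ssl, then wrapped for synchronized batch norm and distributed
    # data parallelism.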
    model = eval(f"ssl.{args.ssl_method}")(args).cuda(gpu)
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
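    # lr=0 is a placeholder: adjust_learning_rate() overwrites it at every training
    # step, with an effective rate of [base-lr] * [batch-size] / 256 after warmup.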
    optimizer = LARS(
        model.parameters(),
        lr=0,
        weight_decay=args.wd,
        weight_decay_filter=exclude_bias_and_norm,
        lars_adaptation_filter=exclude_bias_and_norm,
    )

    # Checkpoint
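    # Resume from the latest checkpoint if one exists; the stored "epoch" is the
    # index of the next epoch to run (it is saved as epoch + 1 below).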
    if (args.exp_dir / "last_ckp.pth").is_file():
        logger.info(f"Resuming from checkpoint found at: {args.exp_dir / 'last_ckp.pth'}")
        ckpt = torch.load(args.exp_dir / "last_ckp.pth", map_location="cpu")
        start_epoch = ckpt["epoch"]
        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
    else:
        logger.info("Found no checkpoint: start from scratch")
        start_epoch = 0

    # Self-supervised training
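    # Mixed-precision training: forward passes run under autocast and the scaler
    # rescales the loss to avoid gradient underflow in float16.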
    scaler = torch.cuda.amp.GradScaler()
    for epoch in range(start_epoch, args.epochs):
        logger.info("============ Starting epoch %i ... ============" % epoch)
        sampler.set_epoch(epoch)
        batch_time = AverageMeter("batch_time")
        data_time = AverageMeter("data_time")
        loss_meter = AverageMeter("loss")
        end = time.time()
        for step, mini_batch in enumerate(loader, start=epoch * len(loader)):
            torch.cuda.reset_peak_memory_stats()
            data_time.update(time.time() - end)
            lr = adjust_learning_rate(args, optimizer, loader, step)
            optimizer.zero_grad()

            if args.multicrop is None:
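                # Two-view case: x and y are the two augmented views of the same batch of images.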
                (x, y) = mini_batch[0]
                x = x.cuda(gpu, non_blocking=True)
                y = y.cuda(gpu, non_blocking=True)
                bs = x.shape[0]
                with torch.cuda.amp.autocast():
                    loss = model.forward(x, y)
            elif args.multicrop == "OBoW":
                (img_orig, img_crops) = mini_batch[0]
                bs = img_orig.shape[0]
                img_orig = img_orig.cuda(gpu, non_blocking=True)
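                # Each entry of img_crops has shape [batch, n_crops, C, H, W]: fold the
                # crop dimension into the batch dimension before moving it to the GPU.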
                img_crops = [torch.cat([img[:, i, :, :, :] for i in range(img.shape[1])], dim=0) for img in img_crops]
                img_crops = [img.cuda(gpu, non_blocking=True) for img in img_crops]
                with torch.cuda.amp.autocast():
                    loss = model.forward(img_orig, img_crops)
            elif args.multicrop == "SwAV":
                inputs = mini_batch[0]
                bs = inputs[0].shape[0]
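                # Group consecutive crops that share the same resolution (as in the SwAV
                # codebase) so each group can be concatenated and encoded as one batch:
                # unique_consecutive() counts runs of equal crop sizes and cumsum() turns
                # those run lengths into group end indices.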
                idx_crops = torch.cumsum(torch.unique_consecutive(
                    torch.tensor([inp.shape[-1] for inp in inputs]),
                    return_counts=True,
                )[1], 0)
                start_idx = 0
                img_views = []
                for end_idx in idx_crops:
                    img_views.append(torch.cat(inputs[start_idx: end_idx]).cuda(gpu, non_blocking=True))
                    start_idx = end_idx
                with torch.cuda.amp.autocast():
                    loss = model.forward(img_views, bs)
            else:
                raise NotImplementedError

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            loss_meter.update(loss.item(), bs)
            batch_time.update(time.time() - end)
            if args.rank == 0 and step % 30 == 0:
                logger.info(
                    "Epoch: [{0}][{1}]\t"
                    "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                    "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                    "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
                    "Lr: {lr:.4f} \t"
                    "Max memory: {memory} MB".format(
                        epoch,
                        step,
                        batch_time=batch_time,
                        data_time=data_time,
                        loss=loss_meter,
                        lr=lr,
                        memory=int(torch.cuda.max_memory_allocated() / 1e6)
                    )
                )
            end = time.time()

        train_stats.update((epoch, loss_meter.avg))
        if args.rank == 0:
            state = dict(
                epoch=epoch + 1,
                model=model.state_dict(),
                optimizer=optimizer.state_dict(),
            )
            torch.save(state, args.exp_dir / "last_ckp.pth")
            if epoch % args.checkpoint_freq == 0 or epoch == args.epochs - 1:
                shutil.copyfile(
                    args.exp_dir / "last_ckp.pth",
                    checkpoints_dir / f"ckp-{epoch}.pth",
                )
    if args.rank == 0:
        torch.save(model.module.backbone.state_dict(), args.exp_dir / "pretrained_backbone.pth")