# Copyright 2022 - Valeo Comfort and Driving Assistance
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# The code in this file is adapted from VICReg: https://github.com/facebookresearch/vicreg
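
# Example invocation for linear probing (illustrative only -- the paths below are placeholders,
# adapt them to your setup):
#   python evaluate.py --data-dir /path/to/imagenet --pretrained /path/to/resnet50.pth \
#       --exp-dir ./exp/linear_probing --weights freeze --val-dataset val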


from pathlib import Path
import argparse
import json
import os
import random
import signal
import sys
import time

from torch import nn, optim
from torchvision import datasets, transforms
import torch

from src.vicreg.utils import AverageMeter, handle_sigusr1, handle_sigterm, accuracy
import src.vicreg.resnet
from src.sfrik.dataset import ImageNetSubset


def get_arguments():
    parser = argparse.ArgumentParser(
        description="Evaluate a pretrained model on ImageNet by linear probing or semi-supervised learning."
    )

    # Data
    parser.add_argument("--data-dir", type=Path, help="path to dataset")
    parser.add_argument("--subset", type=int, default=-1, help="subset of ImageNet")
    parser.add_argument("--train-percent", default=100, type=int, choices=(100, 10, 1),
                        help="size of training set in percent")
    parser.add_argument("--val-dataset", choices=["train", "val"],
                        help="Choice of the test dataset."
                             "Choose 'val' for choosing the usual ImageNet validation set."
                             "Choose 'train' for choosing a subset of the ImageNet train set.")
    parser.add_argument("--val-subset", type=int, help="Size of validation set when setting '--val-dataset train'."
                                                       "Take a fix number of images per class.")
    parser.add_argument("--only-inference", action='store_true', help="Perform only one inference on test dataset for "
                                                                      "evaluation.")

    # Checkpoint
    parser.add_argument("--pretrained", type=Path, help="path to pretrained model")
    parser.add_argument("--ckp-path", type=Path, help="path to linear head checkpoint")
    parser.add_argument("--exp-dir", type=Path, default="./exp", metavar="DIR",
                        help="path to checkpoint directory")
    parser.add_argument("--print-freq", default=100, type=int, metavar="N", help="print frequency")

    # Model
    parser.add_argument("--arch", type=str, default="resnet50")

    # Optim
    parser.add_argument("--epochs", default=100, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument("--batch-size", default=256, type=int, metavar="N", help="mini-batch size")
    parser.add_argument("--lr-backbone", default=0.0, type=float, metavar="LR", help="backbone base learning rate")
    parser.add_argument("--lr-head", default=0.3, type=float, metavar="LR", help="classifier base learning rate")
    parser.add_argument("--weight-decay", default=1e-6, type=float, metavar="W", help="weight decay")
    parser.add_argument("--weights", default="freeze", type=str, choices=("finetune", "freeze"),
                        help="finetune or freeze resnet weights")

    # Running
    parser.add_argument("--workers", default=8, type=int, metavar="N", help="number of data loader workers")

    return parser


def main():
    parser = get_arguments()
    args = parser.parse_args()
    args.ngpus_per_node = torch.cuda.device_count()
    if "SLURM_JOB_ID" in os.environ:
        signal.signal(signal.SIGUSR1, handle_sigusr1)
        signal.signal(signal.SIGTERM, handle_sigterm)
    # Single-node distributed training
    args.rank = 0
    args.dist_url = f"tcp://localhost:{random.randrange(49152, 65535)}"
    args.world_size = args.ngpus_per_node
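    # Spawn one worker process per GPU on this node; each worker receives its local GPU index.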
    torch.multiprocessing.spawn(main_worker, (args,), args.ngpus_per_node)


def main_worker(gpu, args):
    # Save logs
    if args.rank == 0:
        args.exp_dir.mkdir(parents=True, exist_ok=True)
        stats_file = open(args.exp_dir / "stats.txt", "a", buffering=1)
        print(" ".join(sys.argv))

    # Set up distributed training
    args.rank += gpu
    torch.distributed.init_process_group(
        backend="nccl",
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
    )
    torch.cuda.set_device(gpu)
    torch.backends.cudnn.benchmark = True

    # Backbone
    backbone, embedding = src.vicreg.resnet.__dict__[args.arch](zero_init_residual=True)
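    # Load only the pretrained backbone weights; the linear head below is trained from scratch.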
    state_dict = torch.load(args.pretrained, map_location="cpu")
    missing_keys, unexpected_keys = backbone.load_state_dict(state_dict, strict=False)
    assert missing_keys == [] and unexpected_keys == []

    # Linear layer
    head = nn.Linear(embedding, 1000)
    head.weight.data.normal_(mean=0.0, std=0.01)
    head.bias.data.zero_()

    # Model
    model = nn.Sequential(backbone, head)
    model.cuda(gpu)
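    # Linear probing: with "freeze", only the linear head receives gradients.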
    if args.weights == "freeze":
        backbone.requires_grad_(False)
        head.requires_grad_(True)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
    criterion = nn.CrossEntropyLoss().cuda(gpu)

    # Optimizer
    param_groups = [dict(params=head.parameters(), lr=args.lr_head)]
    if args.weights == "finetune":
        param_groups.append(dict(params=backbone.parameters(), lr=args.lr_backbone))
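    # The base learning rate passed to SGD is a placeholder; each parameter group sets its own lr.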
    optimizer = optim.SGD(param_groups, lr=0, momentum=0.9, weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)

    # Automatically resume from checkpoint if it exists
    if args.ckp_path is not None or (args.exp_dir / "checkpoint.pth").is_file():
        # Load the checkpoint from args.ckp_path if it is given;
        # otherwise, load from args.exp_dir / "checkpoint.pth".
        ckp_path = args.ckp_path if args.ckp_path is not None else args.exp_dir / "checkpoint.pth"
        ckpt = torch.load(ckp_path, map_location="cpu")
        print("Load ckp from: ", args.ckp_path)
        start_epoch = ckpt["epoch"]
        best_acc = ckpt["best_acc"]
        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        scheduler.load_state_dict(ckpt["scheduler"])
    else:
        assert not args.only_inference
        print("Checkpoint not found")
        start_epoch = 0
        best_acc = argparse.Namespace(top1=0, top5=0)

    # Loading train data set
    traindir = args.data_dir / "train"
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_dataset = ImageNetSubset(
        traindir,
        transform=transforms.Compose([transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(), normalize]),
        subset=args.subset
    )

    # Loading test dataset
    val_transforms = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(),
                                         normalize])
    if args.val_dataset == "val":
        valdir = args.data_dir / "val"
        val_dataset = datasets.ImageFolder(valdir, val_transforms)
    elif args.val_dataset == "train":
        valdir = args.data_dir / "train"
        val_dataset = ImageNetSubset(valdir, transform=val_transforms, start=args.subset+1, subset=args.val_subset)
    else:
        raise NotImplementedError
    print("Size of test dataset:", len(val_dataset))

    # Select 1% or 10% of labeled images for semi-supervised learning
    if args.train_percent in {1, 10}:
        train_dataset.samples = []
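        # Each line of the split file is an image filename (e.g. "n01440764_10026.JPEG");
        # the class directory is the prefix before the first underscore.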
        with open(Path("imagenet_subsets") / f"{args.train_percent}percent.txt", "rb") as f:
            args.train_files = f.readlines()
            for fname in args.train_files:
                fname = fname.decode().strip()
                cls = fname.split("_")[0]
                train_dataset.samples.append((traindir / cls / fname, train_dataset.class_to_idx[cls]))
    print("Size of train dataset:", len(train_dataset))

    # Dataloader
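    # The global batch size is split across processes; the validation loader keeps the default
    # sampler because evaluation runs on rank 0 only.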
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    kwargs = dict(
        batch_size=args.batch_size // args.world_size,
        num_workers=args.workers,
        pin_memory=True,
    )
    train_loader = torch.utils.data.DataLoader(train_dataset, sampler=train_sampler, **kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset, **kwargs)

    # Only inference on test dataset with the current checkpoint
    start_time = time.time()
    if args.only_inference:
        model.eval()
        if args.rank == 0:
            evaluate(start_epoch, model, val_loader, gpu, best_acc, stats_file)

    # Train + evaluate over several epochs
    else:
        for epoch in range(start_epoch, args.epochs):
            # Train
            if args.weights == "finetune":
                model.train()
            elif args.weights == "freeze":
                model.eval()
            else:
                assert False
            train_sampler.set_epoch(epoch)
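            # `step` is a global step counter across epochs, used for the logging frequency below.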
            for step, (images, target) in enumerate(
                train_loader, start=epoch * len(train_loader)
            ):
                output = model(images.cuda(gpu, non_blocking=True))
                loss = criterion(output, target.cuda(gpu, non_blocking=True))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if step % args.print_freq == 0:
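                    # Average the loss across processes onto rank 0 for logging.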
                    torch.distributed.reduce(loss.div_(args.world_size), 0)
                    if args.rank == 0:
                        pg = optimizer.param_groups
                        lr_head = pg[0]["lr"]
                        lr_backbone = pg[1]["lr"] if len(pg) == 2 else 0
                        stats = dict(
                            epoch=epoch,
                            step=step,
                            lr_backbone=lr_backbone,
                            lr_head=lr_head,
                            loss=loss.item(),
                            time=int(time.time() - start_time),
                        )
                        print(json.dumps(stats))
                        print(json.dumps(stats), file=stats_file)

            # Evaluate on test set
            model.eval()
            if args.rank == 0:
                evaluate(epoch, model, val_loader, gpu, best_acc, stats_file)

            # Scheduler step + save checkpoint
            scheduler.step()
            if args.rank == 0:
                state = dict(
                    epoch=epoch + 1,
                    best_acc=best_acc,
                    model=model.state_dict(),
                    optimizer=optimizer.state_dict(),
                    scheduler=scheduler.state_dict(),
                )
                torch.save(state, args.exp_dir / "checkpoint.pth")


def evaluate(epoch, model, val_loader, gpu, best_acc, stats_file):
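    # Rank-0-only evaluation: compute top-1 / top-5 accuracy over the full validation loader.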
    top1 = AverageMeter("Acc@1")
    top5 = AverageMeter("Acc@5")
    with torch.no_grad():
        for images, target in val_loader:
            output = model(images.cuda(gpu, non_blocking=True))
            acc1, acc5 = accuracy(
                output, target.cuda(gpu, non_blocking=True), topk=(1, 5)
            )
            top1.update(acc1[0].item(), images.size(0))
            top5.update(acc5[0].item(), images.size(0))
    best_acc.top1 = max(best_acc.top1, top1.avg)
    best_acc.top5 = max(best_acc.top5, top5.avg)
    stats = dict(
        epoch=epoch,
        acc1=top1.avg,
        acc5=top5.avg,
        best_acc1=best_acc.top1,
        best_acc5=best_acc.top5,
    )
    print(json.dumps(stats))
    print(json.dumps(stats), file=stats_file)


if __name__ == "__main__":
    main()