https://github.com/MEPP-team/SAE
Revision ab5ce43 ("specify pytorch-lightning version"), Clément Lemeunier, 01 November 2023
train.py
import argparse
import importlib
import json
import os
import sys

import numpy as np
import torch
import webdataset as wds
import pytorch_lightning as pl

from pytorch_lightning import seed_everything
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar
from models.utils.init_weights import init_weights


def get_opt(job_id, load):
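    """Load options for a job.

    With load=True, re-read checkpoints/<job_id>/infos.json; otherwise start
    from default_options.json, let the CLI override any key, and persist the
    merged options next to the new job's checkpoints.
    """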
    if load:
        print('Loading options with job_id')
        with open('checkpoints/' + str(job_id) + '/infos.json', "r") as infile:
            opt = json.load(infile)
    else:
        print('Loading default options for new job')

        with open('default_options.json', "r") as infile:
            opt = json.load(infile)

        parser = argparse.ArgumentParser()

        parser.add_argument('--job_id', required=True)

        # expose every default option as a CLI flag so it can be overridden;
        # argparse's type=bool treats any non-empty string as True, so
        # booleans are parsed explicitly
        str2bool = lambda s: str(s).lower() in ('1', 'true', 'yes')
        for key, value in opt.items():
            arg_type = str2bool if isinstance(value, bool) else type(value)
            parser.add_argument('--' + key, default=value, type=arg_type)

        opt = vars(parser.parse_args())

        try:
            os.makedirs('checkpoints/' + str(opt['job_id']))
        except FileExistsError:
            print('Folder for job ' + str(opt['job_id']) + ' already exists.')
            sys.exit(1)

        with open('checkpoints/' + str(opt['job_id']) + '/infos.json', "w") as outfile:
            json.dump(opt, outfile, sort_keys=True, indent=4)

    print('Options:\n', json.dumps(opt, sort_keys=True, indent=4), end="\n\n")

    return opt
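
# Usage sketch for a new job (flag names beyond --job_id depend on the keys in
# default_options.json, which is not shown here, so they are illustrative):
#   python train.py --job_id 1234 --some_option value
# Every key in default_options.json becomes an optional CLI flag; values given
# on the command line override the JSON defaults, and the merged options are
# saved to checkpoints/1234/infos.json for later reloading.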


def get_dataloader(opt, dataset_type, batch_size, wanted_index=None, shuffle=True):
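    """Build a WebLoader over <path>/<dataset_type>.tar.

    Train samples are (infos.json, coeffs.pth) tuples; eval samples also carry
    vertices.pth. When wanted_index is given, only that slice of the stream is
    kept; otherwise the train split is optionally shuffled.
    """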
    dataloader_slice = wanted_index is not None

    url = opt['path'] + "/" + dataset_type + ".tar"

    if dataset_type == 'train':
        tuple_type = ("infos.json", "coeffs.pth")

        if dataloader_slice:
            dataset = wds.WebDataset(url).slice(wanted_index, opt['nb_train']).decode().to_tuple(*tuple_type)
        else:
            if shuffle:
                dataset = wds.WebDataset(url).shuffle(opt['nb_train'], initial=opt['nb_train']).decode().to_tuple(*tuple_type)
            else:
                dataset = wds.WebDataset(url).decode().to_tuple(*tuple_type)
    else:
        tuple_type = ("infos.json", "coeffs.pth", "vertices.pth")
        if dataloader_slice:
            dataset = wds.WebDataset(url).slice(wanted_index, opt['nb_evals']).decode().to_tuple(*tuple_type)
        else:
            dataset = wds.WebDataset(url).decode().to_tuple(*tuple_type)

    return wds.WebLoader(
        dataset,
        num_workers=opt["num_workers"],
        batch_size=batch_size,
        pin_memory=True,
    )
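
# Expected archive layout (an assumption inferred from the tuple keys above;
# WebDataset groups files that share a basename into one sample):
#   train.tar: <key>.infos.json + <key>.coeffs.pth per sample
#   test.tar:  <key>.infos.json + <key>.coeffs.pth + <key>.vertices.pth per sample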


def load_trainer(job_id=None, profiler="simple"):
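    """Build everything needed to train: options, dataloaders, model, trainer.

    With no job_id, a new job is configured from the CLI; with a job_id, the
    saved options and the job's last checkpoint are reloaded.
    """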
    load = job_id is not None

    seed_everything(42, workers=True)
    print()

    opt = get_opt(job_id, load)

    with open(opt['path'] + '/infos.json', "r") as f:
        opt_dataset = json.load(f)

    # merge dataset metadata into opt without overriding existing job options
    for key, value in opt_dataset.items():
        if key not in opt:
            opt[key] = value

    if opt['loss_type'] == "MSE":
        opt['loss'] = torch.nn.MSELoss()
    else:
        print('Loss type ' + opt['loss_type'] + ' not implemented.')
        sys.exit(1)

    # fixed template mesh: 6890 vertices (the SMPL body-mesh vertex count);
    # faces in mesh.triv are presumably 1-indexed, hence the -1
    opt['nb_vertices'] = 6890
    opt['TRIV'] = np.loadtxt(opt['path'] + '/../mesh.triv', dtype='int32') - 1

    path_evecs = opt['path'] + '/../mesh.evecs_6890'

    print('Loading eigenvectors...', end='')
    opt['evecs'] = torch.from_numpy(np.load(path_evecs)).float().to(opt['device'])[:, :opt['nb_freq']]
    print(' done.')
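    # evecs is a (nb_vertices, nb_freq) spectral basis truncated to the first
    # nb_freq frequencies; presumably the spatial reconstruction used by the
    # models is evecs @ coeffs (an assumption; the model classes are not shown
    # here).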

    # datasets
    opt['dataloader_train'] = get_dataloader(opt, 'train', opt['train_batch_size'])
    opt['dataloader_test'] = get_dataloader(opt, 'test', opt['test_batch_size'])

    # model
    exec("from models." + opt['model_type'] + " import " + opt['model_type'])
    opt['model'] = eval(opt['model_type'])(opt).to(opt['device'])
    opt['model'].apply(init_weights)

    nb_params = sum(p.numel() for p in opt['model'].parameters() if p.requires_grad)

    print('Number of parameters:', str(nb_params), end="\n\n")

    # pytorch lightning
    logger = TensorBoardLogger(
        save_dir="tb_logs",
        name="",
        version=str(opt['job_id']),
    )

    checkpoint_callback = ModelCheckpoint(
        dirpath='checkpoints/' + str(opt['job_id']),
        filename='{epoch:02d}_{spatial_validation_loss:.2f}',
        every_n_epochs=opt['check_val_every_n_epoch'],
        save_top_k=0,
        save_last=True,
        monitor='spatial_validation_loss',
        mode='min'
    )

    progress_bar = TQDMProgressBar(
        refresh_rate=1
    )

    pl_trainer = pl.Trainer(
        accelerator='gpu', devices=1,
        profiler=profiler,
        max_epochs=opt['num_iterations'],
        check_val_every_n_epoch=opt['check_val_every_n_epoch'],
        logger=logger,
        precision=32,
        default_root_dir='checkpoints/',
        callbacks=[checkpoint_callback, progress_bar],
        deterministic=False,
        benchmark=True,
    )
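    # note: benchmark=True enables cuDNN autotuning and deterministic=False
    # allows non-deterministic kernels, trading exact reproducibility
    # (despite seed_everything above) for speed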

    if load:
        best_checkpoint_filename = 'checkpoints/' + str(job_id) + "/last.ckpt"
        print('Loading checkpoint:', best_checkpoint_filename)
        opt['model'] = opt['model'].load_from_checkpoint(best_checkpoint_filename, opt=opt).to(opt["device"])

    return pl_trainer, opt


if __name__ == "__main__":
    trainer, opt = load_trainer()  # without arguments for new job
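    # to resume an existing job instead (sketch; the id is illustrative):
    # trainer, opt = load_trainer(job_id=1234)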

    trainer.fit(
        opt['model'],
        opt['dataloader_train'],
        opt['dataloader_test']
    )