https://github.com/jose13579/variable-hyperparameter-image-impainting.git
Tip revision: 858fa2db1ef853540b63df5e35d0fe5c5672a8a0 authored by jose13579 on 03 July 2023, 04:52:19 UTC
refactor models
train.py
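# Training entry point: loads a JSON configuration, sets up single-GPU or
# distributed (NCCL) training, and launches core.trainer.Trainer.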
import os
import json
import argparse
import datetime
import numpy as np
from shutil import copyfile
import torch
import torch.multiprocessing as mp
# TF_FORCE_GPU_ALLOW_GROWTH is read by TensorFlow, not PyTorch; it enables
# on-demand GPU memory allocation for any TensorFlow-backed components.
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

from core.trainer import Trainer
from core.dist import (
    get_world_size,
    get_local_rank,
    get_global_rank,
    get_master_ip,
)
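# The core.dist helpers are assumed to read launcher-provided environment
# variables (e.g. from MPI or a cluster scheduler) to report the world size,
# local/global ranks, and the master address used for the TCP rendezvous.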
parser = argparse.ArgumentParser(description='VHII efficient')
parser.add_argument('--config', default='configs/places_proposal.json', type=str)
args = parser.parse_args()


def main_worker(rank, config):
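    """Per-process training worker.

    Joins the NCCL process group when training is distributed, selects the
    device, prepares the save directory, and starts the Trainer.
    """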
    # when launched via mp.spawn, the spawned process index serves as both
    # the local and the global rank
    if 'local_rank' not in config:
        config['local_rank'] = config['global_rank'] = rank
    if config['distributed']:
        # bind this process to its GPU and join the NCCL process group
        torch.cuda.set_device(int(config['local_rank']))
        torch.distributed.init_process_group(backend='nccl',
                                             init_method=config['init_method'],
                                             world_size=config['world_size'],
                                             rank=config['global_rank'],
                                             group_name='mtorch'
                                             )
        print('using GPU {}-{} for training'.format(
            int(config['global_rank']), int(config['local_rank'])))

    config['save_dir'] = os.path.join(
        config['save_dir'],
        '{}_{}'.format(config['model'], os.path.basename(args.config).split('.')[0]))
    # run on the GPU assigned to this process, or fall back to CPU
    if torch.cuda.is_available():
        config['device'] = torch.device("cuda:{}".format(config['local_rank']))
    else:
        config['device'] = 'cpu'

    # only the main process creates the output directory and archives the config file
    if (not config['distributed']) or config['global_rank'] == 0:
        os.makedirs(config['save_dir'], exist_ok=True)
        config_path = os.path.join(
            config['save_dir'], config['config'].split('/')[-1])
        if not os.path.isfile(config_path):
            copyfile(config['config'], config_path)
        print('[**] created folder {}'.format(config['save_dir']))
    
    trainer = Trainer(config)
    trainer.train()


if __name__ == "__main__":
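    # Load the config, derive the distributed settings (world size and TCP
    # rendezvous address), then either spawn workers locally or reuse the
    # processes created by an external launcher (see core.dist).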
    port = '23455'
    # load the training configuration from the JSON file passed via --config
    with open(args.config) as f:
        config = json.load(f)
    config['config'] = args.config
    
    print("get_local_rank(): ",get_local_rank())
    print("get_global_rank(): ",get_global_rank())

    # setting distributed configurations
    config['world_size'] = get_world_size()
    config['init_method'] = f"tcp://{get_master_ip()}:{port}"
    config['distributed'] = config['world_size'] > 1

    print("get_world_size(): ",get_world_size())
    # setup distributed parallel training environments
    if get_master_ip() == "127.0.0.2":
        print("PARALLEL")
        # manually launch distributed processes 
        mp.spawn(main_worker, nprocs=config['world_size'], args=(config,))
    else:
        # worker processes were already launched externally (e.g. by OpenMPI),
        # so take the ranks from the environment
        config['local_rank'] = get_local_rank()
        config['global_rank'] = get_global_rank()
        # the rank argument is unused here because the ranks are already in config
        main_worker(-1, config)
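
# Typical single-machine usage (assuming the default config shipped with the repo):
#   python train.py --config configs/places_proposal.json
# Multi-GPU runs rely on core.dist: either the mp.spawn path above or an
# external MPI-style launcher that provides the rank/world-size environment.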