# Source: https://github.com/jose13579/variable-hyperparameter-image-impainting.git
# Tip revision: 858fa2db1ef853540b63df5e35d0fe5c5672a8a0 (858fa2d), authored by jose13579 on 03 July 2023, 04:52:19 UTC
# Commit message: refactor models
# File: train.py
import os
import json
import argparse
import datetime
import numpy as np
from shutil import copyfile
import torch
import torch.multiprocessing as mp
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
from core.trainer import Trainer
from core.dist import (
get_world_size,
get_local_rank,
get_global_rank,
get_master_ip,
)
# Command-line interface. The single option points at the JSON experiment
# configuration that drives the whole training run.
parser = argparse.ArgumentParser(description='VHII efficient')
parser.add_argument(
    '--config',
    type=str,
    default='configs/places_proposal.json',
)
# Parsed at import time: main_worker below reads ``args.config`` as a global.
args = parser.parse_args()
def main_worker(rank, config):
    """Prepare one training process and run the Trainer in it.

    Args:
        rank: process index passed by ``torch.multiprocessing.spawn``. It is
            only used when ``config`` has no ``local_rank`` yet; the MPI launch
            path pre-populates the ranks and passes -1.
        config: experiment configuration dict (loaded from the JSON file,
            with ``config['config']`` holding that file's path). It is mutated
            in place: ``local_rank``/``global_rank`` (when missing),
            ``save_dir`` and ``device`` are written.
    """
    # Under mp.spawn the spawn index serves as both local and global rank.
    if 'local_rank' not in config:
        config['local_rank'] = config['global_rank'] = rank

    if config['distributed']:
        # Bind this process to its own GPU before creating the NCCL group.
        torch.cuda.set_device(int(config['local_rank']))
        torch.distributed.init_process_group(
            backend='nccl',
            init_method=config['init_method'],
            world_size=config['world_size'],
            rank=config['global_rank'],
            group_name='mtorch',
        )
        print('using GPU {}-{} for training'.format(
            int(config['global_rank']), int(config['local_rank'])))

    # Per-experiment output folder: <save_dir>/<model>_<config file stem>.
    # splitext strips only the final extension, so configs like "a.v1.json"
    # and "a.v2.json" get distinct folders instead of both mapping to "a".
    config_name = os.path.basename(config['config'])
    config_stem = os.path.splitext(config_name)[0]
    config['save_dir'] = os.path.join(
        config['save_dir'], '{}_{}'.format(config['model'], config_stem))

    if torch.cuda.is_available():
        config['device'] = torch.device("cuda:{}".format(config['local_rank']))
    else:
        config['device'] = 'cpu'

    # Only one process (the sole one, or global rank 0) creates the output
    # folder and keeps a copy of the config file next to the results.
    if (not config['distributed']) or config['global_rank'] == 0:
        os.makedirs(config['save_dir'], exist_ok=True)
        config_path = os.path.join(config['save_dir'], config_name)
        if not os.path.isfile(config_path):
            copyfile(config['config'], config_path)
        print('[**] create folder {}'.format(config['save_dir']))

    trainer = Trainer(config)
    trainer.train()
if __name__ == "__main__":
    # TCP port used for the torch.distributed rendezvous (init_method below).
    port = '23455'

    # Load the JSON experiment config. The context manager closes the file;
    # JSON text is UTF-8, so decode it as such regardless of system locale.
    with open(args.config, encoding='utf-8') as config_file:
        config = json.load(config_file)
    # Record the config path so main_worker can name the output folder after
    # it and copy the file into that folder.
    config['config'] = args.config

    print("get_local_rank(): ", get_local_rank())
    print("get_global_rank(): ", get_global_rank())

    # Distributed settings; each launcher query is made once and reused.
    world_size = get_world_size()
    config['world_size'] = world_size
    master_ip = get_master_ip()
    config['init_method'] = f"tcp://{master_ip}:{port}"
    config['distributed'] = world_size > 1
    print("get_world_size(): ", world_size)

    # NOTE(review): this sentinel must match the address core.dist.get_master_ip()
    # reports when no external launcher is detected; that helper is not shown
    # here - confirm the value, otherwise the local spawn branch never runs.
    if master_ip == "127.0.0.2":
        print("PARALLEL")
        # Launch one worker per process manually; mp.spawn supplies the rank.
        mp.spawn(main_worker, nprocs=world_size, args=(config,))
    else:
        # Processes were already launched externally (openmpi): take the ranks
        # from the environment, so main_worker ignores its rank argument.
        config['local_rank'] = get_local_rank()
        config['global_rank'] = get_global_rank()
        main_worker(-1, config)