import os
import json
import argparse
from shutil import copyfile

import torch
import torch.multiprocessing as mp

# Ask any TensorFlow runtime in this process to allocate GPU memory on demand
# (this has no effect on PyTorch itself).
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

from core.trainer import Trainer
from core.dist import (
    get_world_size,
    get_local_rank,
    get_global_rank,
    get_master_ip,
)

parser = argparse.ArgumentParser(description='VHII efficient')
parser.add_argument('--config', default='configs/places_proposal.json', type=str)
args = parser.parse_args()


def main_worker(rank, config):
    # When launched via mp.spawn, the worker only receives its spawn rank;
    # use it for both the local and the global rank.
    if 'local_rank' not in config:
        config['local_rank'] = config['global_rank'] = rank
    if config['distributed']:
        torch.cuda.set_device(int(config['local_rank']))
        torch.distributed.init_process_group(
            backend='nccl',
            init_method=config['init_method'],
            world_size=config['world_size'],
            rank=config['global_rank'],
            group_name='mtorch')
        print('using GPU {}-{} for training'.format(
            int(config['global_rank']), int(config['local_rank'])))

    # e.g. <save_dir>/<model>_places_proposal for configs/places_proposal.json
    config['save_dir'] = os.path.join(
        config['save_dir'],
        '{}_{}'.format(config['model'],
                       os.path.basename(args.config).split('.')[0]))

    if torch.cuda.is_available():
        config['device'] = torch.device('cuda:{}'.format(config['local_rank']))
    else:
        config['device'] = 'cpu'

    # Only rank 0 creates the save folder and archives a copy of the config.
    if (not config['distributed']) or config['global_rank'] == 0:
        os.makedirs(config['save_dir'], exist_ok=True)
        config_path = os.path.join(config['save_dir'],
                                   os.path.basename(config['config']))
        if not os.path.isfile(config_path):
            copyfile(config['config'], config_path)
        print('[**] create folder {}'.format(config['save_dir']))

    trainer = Trainer(config)
    trainer.train()


if __name__ == '__main__':
    port = '23455'

    # load configs
    with open(args.config) as f:
        config = json.load(f)
    config['config'] = args.config
    print('get_local_rank():', get_local_rank())
    print('get_global_rank():', get_global_rank())

    # set distributed configurations
    config['world_size'] = get_world_size()
    config['init_method'] = f'tcp://{get_master_ip()}:{port}'
    config['distributed'] = config['world_size'] > 1
    print('get_world_size():', get_world_size())

    # set up the distributed parallel training environment
    if get_master_ip() == '127.0.0.2':
        print('PARALLEL')
        # manually launch one distributed process per GPU
        mp.spawn(main_worker, nprocs=config['world_size'], args=(config,))
    else:
        # multiple processes have already been launched by OpenMPI, so the
        # ranks come from the environment and the rank argument is unused
        config['local_rank'] = get_local_rank()
        config['global_rank'] = get_global_rank()
        main_worker(-1, config)
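
# Example invocations (a sketch: the file name "train.py" and the mpirun flags
# are assumptions, and which launch branch runs depends on the value that
# core.dist.get_master_ip() returns on your machine):
#
#   # single-node run; the script spawns one worker per process itself
#   # when get_master_ip() matches the manual-launch sentinel above:
#   python train.py --config configs/places_proposal.json
#
#   # rank-per-process launch through OpenMPI; each worker then picks up
#   # its local/global rank from the environment:
#   mpirun -np 4 python train.py --config configs/places_proposal.json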
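
# For reference, a minimal sketch of the four helpers this script imports from
# core.dist. This is an assumption: core/dist.py is not shown here and the
# real module may differ. The sketch reads the rank/size environment variables
# that OpenMPI exports (OMPI_COMM_WORLD_*), matching the branch above for
# processes that "have been launched by OpenMPI".
#
#   import os
#   import torch
#
#   def get_world_size():
#       # total number of processes; fall back to the local GPU count
#       # (at least 1) so that mp.spawn can start one worker per device
#       if 'OMPI_COMM_WORLD_SIZE' in os.environ:
#           return int(os.environ['OMPI_COMM_WORLD_SIZE'])
#       return max(torch.cuda.device_count(), 1)
#
#   def get_global_rank():
#       # rank across all nodes; 0 when not launched by OpenMPI
#       return int(os.environ.get('OMPI_COMM_WORLD_RANK', 0))
#
#   def get_local_rank():
#       # rank within the current node; 0 when not launched by OpenMPI
#       return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK', 0))
#
#   def get_master_ip():
#       # address of the rank-0 node; defaults to loopback for local runs
#       return os.environ.get('MASTER_ADDR', '127.0.0.1')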