# -*- coding: utf-8 -*-
"""Generating the training data.

This script generates the training data according to the config specifications.

Example
-------
To run this script, pass in the desired config file as argument::

    $ generate baobab/configs/tdlmc_diagonal_config.py --n_data 1000

"""
import os
import random
import argparse
import gc
from tqdm import tqdm
import numpy as np
import pandas as pd
# Lenstronomy modules
import lenstronomy
print("Lenstronomy path being used: {:s}".format(lenstronomy.__path__[0]))
from lenstronomy.LensModel.lens_model import LensModel
from lenstronomy.LightModel.light_model import LightModel
from lenstronomy.PointSource.point_source import PointSource
# Baobab modules
from baobab.configs import BaobabConfig
import baobab.bnn_priors as bnn_priors
from baobab.sim_utils import Imager, Selection


def parse_args():
    """Parse command-line arguments

    Returns
    -------
    argparse.Namespace
        parsed arguments with attributes ``config``, ``n_data``, ``dest_dir``

    """
    parser = argparse.ArgumentParser()
    parser.add_argument('config', help='Baobab config file path')
    parser.add_argument('--n_data', default=None, dest='n_data', type=int,
                        help='size of dataset to generate (overrides config file)')
    parser.add_argument('--dest_dir', default=None, dest='dest_dir', type=str,
                        help='destination for output folder (overrides config file)')
    # NOTE(review): argparse never returns None from parse_args(), so the old
    # "sys.argv rerouting" fallback was unreachable dead code (and indexed
    # sys.argv[0], the program name, as the config path). It has been removed.
    return parser.parse_args()


def _instantiate_models(cfg):
    """Instantiate the lenstronomy mass/light/point-source models from config.

    Parameters
    ----------
    cfg : BaobabConfig
        the config specifying which profiles to use for each component

    Returns
    -------
    tuple
        (lens_mass_model, src_light_model, lens_light_model, ps_model);
        the latter two are None when the corresponding component is absent
        from ``cfg.components``

    """
    kwargs_model = dict(
                        lens_model_list=[cfg.bnn_omega.lens_mass.profile, cfg.bnn_omega.external_shear.profile],
                        source_light_model_list=[cfg.bnn_omega.src_light.profile],
                        )
    lens_mass_model = LensModel(lens_model_list=kwargs_model['lens_model_list'])
    src_light_model = LightModel(light_model_list=kwargs_model['source_light_model_list'])
    if 'lens_light' in cfg.components:
        kwargs_model['lens_light_model_list'] = [cfg.bnn_omega.lens_light.profile]
        lens_light_model = LightModel(light_model_list=kwargs_model['lens_light_model_list'])
    else:
        lens_light_model = None
    if 'agn_light' in cfg.components:
        kwargs_model['point_source_model_list'] = [cfg.bnn_omega.agn_light.profile]
        ps_model = PointSource(point_source_type_list=kwargs_model['point_source_model_list'],
                               fixed_magnification_list=[False])
    else:
        ps_model = None
    return lens_mass_model, src_light_model, lens_light_model, ps_model


def _build_metadata_row(cfg, sample, img_features, img_filename):
    """Assemble the flat dict of labels (one metadata.csv row) for one image.

    Parameters
    ----------
    cfg : BaobabConfig
        the config (used for the component list and prior class name)
    sample : dict
        the sampled model parameters, keyed by component
    img_features : dict
        statistics computed while rendering the image
    img_filename : str
        filename of the saved image array

    Returns
    -------
    dict
        flattened labels for this image

    """
    meta = {}
    for comp in cfg.components:
        # Log model parameters, prefixed by component name
        for param_name, param_value in sample[comp].items():
            meta['{:s}_{:s}'.format(comp, param_name)] = param_value
    if cfg.bnn_prior_class in ['EmpiricalBNNPrior', 'DiagonalCosmoBNNPrior']:
        # Log other stats carried in the sample's 'misc' entry
        for misc_name, misc_value in sample['misc'].items():
            meta['{:s}'.format(misc_name)] = misc_value
    meta.update(img_features)
    if 'agn_light' in cfg.components:
        # Image positions/magnifications are arrays; store as lists so they
        # serialize cleanly to csv
        meta['x_image'] = img_features['x_image'].tolist()
        meta['y_image'] = img_features['y_image'].tolist()
        meta['n_img'] = len(img_features['y_image'])
        meta['magnification'] = img_features['magnification'].tolist()
        meta['measured_magnification'] = img_features['measured_magnification'].tolist()
    meta['img_filename'] = img_filename
    return meta


def _flush_metadata(buffer, metadata_path, column_order):
    """Append the buffered metadata rows to metadata.csv on disk.

    On the first flush (``column_order`` is None) the csv is created with a
    lexicographically sorted header; every later flush reindexes its chunk to
    that same column order before appending headerless, so the rows stay
    aligned with the header. (Previously, appended chunks were written in
    dict-insertion column order, which could misalign against the sorted
    header.) An empty buffer is a no-op.

    Parameters
    ----------
    buffer : list of dict
        metadata rows accumulated since the last flush
    metadata_path : str
        path of the output csv
    column_order : list of str or None
        the established csv column order, or None on the first flush

    Returns
    -------
    list of str
        the (possibly newly established) column order

    """
    if not buffer:
        return column_order
    chunk = pd.DataFrame(buffer)
    if column_order is None:
        column_order = sorted(chunk.columns)  # sort columns lexicographically
        chunk.reindex(columns=column_order).to_csv(metadata_path, index=None)
    else:
        chunk.reindex(columns=column_order).to_csv(metadata_path, index=None,
                                                   mode='a', header=None)
    return column_order


def main():
    """Generate the dataset of images and labels specified by the config."""
    args = parse_args()
    cfg = BaobabConfig.from_file(args.config)
    if args.n_data is not None:
        cfg.n_data = args.n_data
    if args.dest_dir is not None:
        cfg.destination_dir = args.dest_dir
    # Seed for reproducibility
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)
    # Create data directory; refuse to overwrite an existing dataset
    save_dir = cfg.out_dir
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
        print("Destination folder path: {:s}".format(save_dir))
    else:
        raise OSError("Destination folder already exists.")
    # Instantiate density models
    lens_mass_model, src_light_model, lens_light_model, ps_model = _instantiate_models(cfg)
    # Instantiate Selection object
    selection = Selection(cfg.selection, cfg.components)
    # Velocity dispersion and time delays both require solving the lens
    # equation, so the imager must run in cosmography mode for either
    for_cosmography = (cfg.bnn_omega.kinematics.calculate_vel_disp
                       or cfg.bnn_omega.time_delays.calculate_time_delays)
    # Instantiate Imager object
    imager = Imager(cfg.components,
                    lens_mass_model,
                    src_light_model,
                    lens_light_model=lens_light_model,
                    ps_model=ps_model,
                    kwargs_numerics=cfg.numerics,
                    min_magnification=cfg.selection.magnification.min,
                    for_cosmography=for_cosmography,
                    magnification_frac_err=cfg.bnn_omega.magnification.frac_error_sigma)
    # Initialize BNN prior
    if for_cosmography:
        kwargs_lens_eqn_solver = {'min_distance': 0.05,
                                  'search_window': cfg.instrument['pixel_scale']*cfg.image['num_pix'],
                                  'num_iter_max': 100}
        bnn_prior = getattr(bnn_priors, cfg.bnn_prior_class)(cfg.bnn_omega, cfg.components,
                                                             kwargs_lens_eqn_solver)
    else:
        bnn_prior = getattr(bnn_priors, cfg.bnn_prior_class)(cfg.bnn_omega, cfg.components)
    # Metadata is buffered as a list of dicts and flushed in chunks
    # (DataFrame.append was removed in pandas 2.0 and was quadratic per-row)
    metadata_path = os.path.join(save_dir, 'metadata.csv')
    metadata_rows = []  # buffered rows for the current checkpoint chunk
    column_order = None  # established on the first flush
    current_idx = 0  # running idx of dataset
    pbar = tqdm(total=cfg.n_data)
    while current_idx < cfg.n_data:
        sample = bnn_prior.sample()  # FIXME: sampling in batches
        if selection.reject_initial(sample):
            # select on sampled model parameters
            continue
        # Generate the image
        img, img_features = imager.generate_image(sample, cfg.image.num_pix,
                                                  cfg.survey_object_dict)
        if img is None:
            # select on stats computed while rendering the image
            continue
        # Save image file
        if cfg.image.squeeze_bandpass_dimension:
            img = np.squeeze(img)
        img_filename = 'X_{0:07d}.npy'.format(current_idx)
        np.save(os.path.join(save_dir, img_filename), img)
        # Save labels
        metadata_rows.append(_build_metadata_row(cfg, sample, img_features, img_filename))
        # Flush immediately after the first row (creates the csv and its
        # header) and then at every checkpoint interval
        if current_idx == 0 or (current_idx + 1) % cfg.checkpoint_interval == 0:
            column_order = _flush_metadata(metadata_rows, metadata_path, column_order)
            metadata_rows = []  # init empty buffer for next checkpoint chunk
            gc.collect()
        # Update progress
        current_idx += 1
        pbar.update(1)
    # Export any rows remaining since the last checkpoint
    _flush_metadata(metadata_rows, metadata_path, column_order)
    pbar.close()


if __name__ == '__main__':
    main()