# -*- coding: utf-8 -*-
"""Generating the training data.
This script generates the training data according to the config specifications.
Example
-------
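To run this script, pass the path to the desired config file as the positional argument (``--n_data`` and ``--dest_dir`` optionally override the corresponding config values)::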
$ generate baobab/configs/tdlmc_diagonal_config.py --n_data 1000
"""
import os, sys
import random
import argparse
import gc
from types import SimpleNamespace
from tqdm import tqdm
import numpy as np
import pandas as pd
# Lenstronomy modules
import lenstronomy
print("Lenstronomy path being used: {:s}".format(lenstronomy.__path__[0]))
from lenstronomy.LensModel.lens_model import LensModel
from lenstronomy.LightModel.light_model import LightModel
from lenstronomy.PointSource.point_source import PointSource
# Baobab modules
from baobab.configs import BaobabConfig
import baobab.bnn_priors as bnn_priors
from baobab.sim_utils import Imager, Selection
def parse_args(argv=None):
    """Parse command-line arguments.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse, excluding the program name. When ``None``
        (the default), argparse reads ``sys.argv[1:]``.

    Returns
    -------
    argparse.Namespace
        With attributes ``config`` (str, path to the Baobab config file),
        ``n_data`` (int or None) and ``dest_dir`` (str or None). The two
        optional values override the config file when not ``None``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('config', help='Baobab config file path')
    parser.add_argument('--n_data', default=None, dest='n_data', type=int,
                        help='size of dataset to generate (overrides config file)')
    parser.add_argument('--dest_dir', default=None, dest='dest_dir', type=str,
                        help='destination for output folder (overrides config file)')
    # ArgumentParser.parse_args never returns None (it exits with a usage
    # error instead), so no fallback rerouting of sys.argv is needed.
    return parser.parse_args(argv)
def _build_models(cfg):
    """Instantiate the lenstronomy models for the components in ``cfg``.

    Returns
    -------
    tuple
        ``(lens_mass_model, src_light_model, lens_light_model, ps_model)``.
        ``lens_light_model`` is ``None`` unless ``'lens_light'`` is in
        ``cfg.components``; ``ps_model`` is ``None`` unless ``'agn_light'`` is.
    """
    lens_mass_model = LensModel(lens_model_list=[cfg.bnn_omega.lens_mass.profile,
                                                 cfg.bnn_omega.external_shear.profile])
    src_light_model = LightModel(light_model_list=[cfg.bnn_omega.src_light.profile])
    lens_light_model = None
    if 'lens_light' in cfg.components:
        lens_light_model = LightModel(light_model_list=[cfg.bnn_omega.lens_light.profile])
    ps_model = None
    if 'agn_light' in cfg.components:
        ps_model = PointSource(point_source_type_list=[cfg.bnn_omega.agn_light.profile],
                               fixed_magnification_list=[False])
    return lens_mass_model, src_light_model, lens_light_model, ps_model


def _build_meta(cfg, sample, img_features, img_filename):
    """Flatten one accepted sample and its image features into a metadata row.

    Component parameters are keyed ``'<component>_<param>'``. For the
    ``EmpiricalBNNPrior`` and ``DiagonalCosmoBNNPrior`` classes, the entries
    of ``sample['misc']`` are added under their own names. All of
    ``img_features`` is merged in, and when ``'agn_light'`` is a component the
    point-source arrays are stored as plain lists (via ``.tolist()``) together
    with the image count ``n_img``.
    """
    meta = {}
    for comp in cfg.components:
        for param_name, param_value in sample[comp].items():
            meta['{:s}_{:s}'.format(comp, param_name)] = param_value
    if cfg.bnn_prior_class in ['EmpiricalBNNPrior', 'DiagonalCosmoBNNPrior']:
        meta.update(sample['misc'])
    meta.update(img_features)
    if 'agn_light' in cfg.components:
        meta['x_image'] = img_features['x_image'].tolist()
        meta['y_image'] = img_features['y_image'].tolist()
        meta['n_img'] = len(img_features['y_image'])
        meta['magnification'] = img_features['magnification'].tolist()
        meta['measured_magnification'] = img_features['measured_magnification'].tolist()
    meta['img_filename'] = img_filename
    return meta


def _append_metadata(rows, metadata_path, columns):
    """Append ``rows`` (a list of dicts) to the CSV at ``metadata_path``.

    Rows are written without a header and in ``columns`` order, so they line
    up with the header written by the first export. Keys missing from a row
    become empty cells, and keys not in ``columns`` are dropped. Does nothing
    when ``rows`` is empty.
    """
    if not rows:
        return
    chunk = pd.DataFrame(rows, columns=columns)
    chunk.to_csv(metadata_path, index=False, mode='a', header=False)


def main():
    """Generate ``cfg.n_data`` images plus a ``metadata.csv`` label file.

    Loads the config named on the command line (with optional ``--n_data`` /
    ``--dest_dir`` overrides) and creates a fresh output folder. It then
    repeatedly samples the BNN prior and renders the samples that pass
    selection. Each accepted image is saved as ``X_<7-digit index>.npy`` and
    its parameters are appended to ``metadata.csv``, flushed every
    ``cfg.checkpoint_interval`` images.

    Raises
    ------
    OSError
        If the destination folder already exists.
    """
    args = parse_args()
    cfg = BaobabConfig.from_file(args.config)
    if args.n_data is not None:
        cfg.n_data = args.n_data
    if args.dest_dir is not None:
        # NOTE(review): only ``cfg.out_dir`` is read below. Confirm that
        # BaobabConfig derives ``out_dir`` from ``destination_dir`` after this
        # assignment; if ``out_dir`` is fixed at load time, --dest_dir is ignored.
        cfg.destination_dir = args.dest_dir
    # Seed both RNGs for reproducibility
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)
    # Create the data directory; refuse to reuse an existing one. Creating it
    # directly (instead of checking first) avoids a check-then-create race.
    save_dir = cfg.out_dir
    try:
        os.makedirs(save_dir)
    except FileExistsError:
        raise OSError("Destination folder already exists.") from None
    print("Destination folder path: {:s}".format(save_dir))
    # Instantiate density models
    lens_mass_model, src_light_model, lens_light_model, ps_model = _build_models(cfg)
    # Instantiate Selection object
    selection = Selection(cfg.selection, cfg.components)
    # Cosmography mode is enabled when velocity dispersions or time delays
    # are requested; it changes how the Imager and the prior are set up.
    for_cosmography = bool(cfg.bnn_omega.kinematics.calculate_vel_disp
                           or cfg.bnn_omega.time_delays.calculate_time_delays)
    imager = Imager(cfg.components, lens_mass_model, src_light_model,
                    lens_light_model=lens_light_model, ps_model=ps_model,
                    kwargs_numerics=cfg.numerics,
                    min_magnification=cfg.selection.magnification.min,
                    for_cosmography=for_cosmography,
                    magnification_frac_err=cfg.bnn_omega.magnification.frac_error_sigma)
    # Initialize BNN prior; in cosmography mode it also receives
    # lens-equation solver settings whose search window spans the full image
    # (pixel_scale * num_pix).
    bnn_prior_cls = getattr(bnn_priors, cfg.bnn_prior_class)
    if for_cosmography:
        kwargs_lens_eqn_solver = {'min_distance': 0.05,
                                  'search_window': cfg.instrument['pixel_scale']*cfg.image['num_pix'],
                                  'num_iter_max': 100}
        bnn_prior = bnn_prior_cls(cfg.bnn_omega, cfg.components, kwargs_lens_eqn_solver)
    else:
        bnn_prior = bnn_prior_cls(cfg.bnn_omega, cfg.components)

    metadata_path = os.path.join(save_dir, 'metadata.csv')
    # Rows are accumulated in a plain list and turned into a DataFrame only on
    # export: DataFrame.append was removed in pandas 2.0 and was quadratic.
    rows = []  # metadata rows pending export (current checkpoint chunk)
    columns = None  # column order fixed by the header of the first export
    current_idx = 0  # running idx of dataset
    with tqdm(total=cfg.n_data) as pbar:
        while current_idx < cfg.n_data:
            sample = bnn_prior.sample()  # FIXME: sampling in batches
            if selection.reject_initial(sample):  # select on sampled model parameters
                continue
            # Generate the image
            img, img_features = imager.generate_image(sample, cfg.image.num_pix, cfg.survey_object_dict)
            if img is None:  # select on stats computed while rendering the image
                continue
            # Save image file
            if cfg.image.squeeze_bandpass_dimension:
                img = np.squeeze(img)
            img_filename = 'X_{0:07d}.npy'.format(current_idx)
            np.save(os.path.join(save_dir, img_filename), img)
            # Save labels
            rows.append(_build_meta(cfg, sample, img_features, img_filename))
            if current_idx == 0:
                # First export writes the header, with columns sorted
                # lexicographically; later chunks reuse this exact order.
                columns = sorted(rows[0])
                pd.DataFrame(rows, columns=columns).to_csv(metadata_path, index=False)
                rows = []
                gc.collect()
            # Export metadata every checkpoint interval
            if (current_idx + 1) % cfg.checkpoint_interval == 0:
                _append_metadata(rows, metadata_path, columns)
                rows = []
                gc.collect()
            current_idx += 1
            pbar.update(1)
    # Export whatever remains after the last checkpoint
    _append_metadata(rows, metadata_path, columns)
if __name__ == '__main__':
    # Generate the dataset only when run as a script, not when imported.
    main()