In [None]:
import os, sys
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import scipy.stats as stats
import corner
import lenstronomy.Util.param_util as param_util
from baobab import bnn_priors
from baobab.configs import BaobabConfig, tdlmc_cov_config, gamma_cov_config
%matplotlib inline
%load_ext autoreload
%autoreload 2

# Visualizing the input prior PDF in the CovBNNPrior and the resulting samples
__Author:__ Ji Won Park
 
__Created:__ 8/30/19
 
__Last run:__ 9/05/19

__Goals:__
Plot the marginal distributions, and a few pairwise joint distributions, of the parameters sampled from the covariate BNN prior (`CovBNNPrior`), in which parameters follow a multivariate (log)normal distribution.

__Before running this notebook:__
Generate some data. At the root of the `baobab` repo, run:
```
generate baobab/configs/tdlmc_cov_config.py --n_data 1000
```
This generates 1000 samples using `CovBNNPrior` at the location this notebook expects.

In [None]:
# TODO add description

In [None]:
# Load the baobab config, the sample metadata written by `generate`, and the
# BNN prior object whose PDFs are compared against the samples below.
cfg_path = tdlmc_cov_config.__file__
# Alternative config (TDLMC diagonal-prior training set), kept for reference:
#cfg_path = os.path.join('..', '..', 'time_delay_lens_modeling_challenge', 'data', 'baobab_configs', 'train_tdlmc_diagonal_config.py')
cfg = BaobabConfig.from_file(cfg_path)

# `cfg.out_dir` is resolved relative to the parent directory (presumably the
# repo root where `generate` was run -- confirm if the notebook is moved).
out_data_dir = os.path.join('..', cfg.out_dir)
#out_data_dir = os.path.join('..', '..', 'time_delay_lens_modeling_challenge', cfg.out_dir)
print(out_data_dir)
metadata_path = os.path.join(out_data_dir, 'metadata.csv')
meta = pd.read_csv(metadata_path, index_col=None)

# Instantiate the prior class named in the config with its hyperparameters.
prior_class = getattr(bnn_priors, cfg.bnn_prior_class)
bnn_prior = prior_class(cfg.bnn_omega, cfg.components)

Here are the available parameters (the metadata columns):

In [None]:
# Metadata column names in alphabetical order.
sorted(meta.columns)

In [None]:
# Add shear and ellipticity modulus and angle.
# Whichever parameterization is present in the metadata (detected via the
# gamma_ext / e1 column), derive the other one from it so both can be plotted:
#   external shear:      (gamma_ext, psi_ext) <-> (gamma1, gamma2)
#   elliptical profiles: (e1, e2)             <-> (q, phi)
if 'external_shear_gamma_ext' not in meta.columns.values:
    # Cartesian shear components are present -> add modulus and angle.
    gamma1 = meta['external_shear_gamma1'].values
    gamma2 = meta['external_shear_gamma2'].values
    psi_ext, gamma_ext = param_util.ellipticity2phi_gamma(gamma1, gamma2)
    meta['external_shear_gamma_ext'] = gamma_ext
    meta['external_shear_psi_ext'] = psi_ext
else:
    # Modulus and angle are present -> add the Cartesian components.
    gamma_ext = meta['external_shear_gamma_ext'].values
    psi_ext = meta['external_shear_psi_ext'].values
    gamma1, gamma2 = param_util.phi_gamma_ellipticity(psi_ext, gamma_ext)
    meta['external_shear_gamma1'] = gamma1
    meta['external_shear_gamma2'] = gamma2

ELLIPTICAL_COMPONENTS = ('lens_mass', 'src_light', 'lens_light')
for comp in cfg.components:
    if comp not in ELLIPTICAL_COMPONENTS:
        continue
    e1_col, e2_col = comp + '_e1', comp + '_e2'
    q_col, phi_col = comp + '_q', comp + '_phi'
    if e1_col in meta.columns.values:
        # (e1, e2) present -> add axis ratio q and position angle phi.
        e1 = meta[e1_col].values
        e2 = meta[e2_col].values
        phi, q = param_util.ellipticity2phi_q(e1, e2)
        meta[q_col] = q
        meta[phi_col] = phi
    else:
        # (q, phi) present -> add (e1, e2).
        q = meta[q_col].values
        phi = meta[phi_col].values
        e1, e2 = param_util.phi_q2_ellipticity(phi, q)
        meta[e1_col] = e1
        meta[e2_col] = e2

In [None]:
# Add source gal positional offset: radial distance of the source light center
# from the coordinate origin.
# NOTE(review): this is measured from (0, 0), not from the lens mass center
# (lens_mass_center_x/y) -- confirm that is the intended definition.
src_center_x = meta['src_light_center_x']
src_center_y = meta['src_light_center_y']
meta['src_pos_offset'] = np.sqrt(src_center_x**2.0 + src_center_y**2.0)

In [None]:
def plot_prior_samples(eval_at, component, param, unit, binning=None):
    """Overlay the prior PDF of one independently sampled parameter on a histogram of its samples.

    The PDF comes from `bnn_prior.eval_param_pdf` with the parameter's
    hyperparameters in `cfg.bnn_omega`. The samples come from column
    '<component>_<param>' of `meta`. The hyperparameters are also printed.

    Parameters
    ----------
    eval_at : array-like
        Grid at which the prior PDF is evaluated. Its first and last entries
        also set the default histogram range.
    component : str
        Component name, e.g. 'lens_mass' or 'src_light'.
    param : str
        Parameter name within `component`.
    unit : str
        Unit string shown in the x-axis label.
    binning : int or array-like, optional
        Passed to `plt.hist` as `bins`. Defaults to 50 evenly spaced bin edges
        spanning [eval_at[0], eval_at[-1]].

    Raises
    ------
    NotImplementedError
        For 'src_light_center_x' / 'src_light_center_y', and for
        (component, param) pairs listed in `bnn_prior.params_to_exclude`.
        Plot those with `plot_derived_quantities` instead.
    """
    param_key = '{:s}_{:s}'.format(component, param)
    # For the source position offsets, evaluate the PDF with the
    # hyperparameters stored under the src_light center_x / center_y keys.
    offset_hyperparam_keys = {
        'src_light_pos_offset_x': 'center_x',
        'src_light_pos_offset_y': 'center_y',
    }
    if param_key in offset_hyperparam_keys:
        hyperparams = cfg.bnn_omega['src_light'][offset_hyperparam_keys[param_key]]
    elif param_key in ('src_light_center_x', 'src_light_center_y'):
        raise NotImplementedError("Use `plot_derived_quantities` instead.")
    elif (component, param) in bnn_prior.params_to_exclude:
        raise NotImplementedError("This parameter wasn't sampled independently. Please use `plot_derived_quantities` instead.")
    else:
        hyperparams = cfg.bnn_omega[component][param]
    # Shallow-copy in every branch so that modifications to `hyperparams`
    # (e.g. inside eval_param_pdf) cannot alter the shared config dict.
    hyperparams = hyperparams.copy()
    pdf_eval = bnn_prior.eval_param_pdf(eval_at, hyperparams)
    plt.plot(eval_at, pdf_eval, 'r-', lw=2, alpha=0.6, label='PDF')
    if binning is None:
        binning = np.linspace(eval_at[0], eval_at[-1], 50)
    plt.hist(meta[param_key], bins=binning, edgecolor='k', density=True, align='mid', label='sampled')
    print(hyperparams)
    plt.xlabel("{:s} ({:s})".format(param_key, unit))
    plt.ylabel("density")
    plt.legend()

In [None]:
def plot_derived_quantities(param_key, unit, binning=None):
    """Plot a normalized histogram of the sampled values in metadata column `param_key`.

    Unlike `plot_prior_samples`, no prior PDF is overlaid. This makes it
    suitable for columns computed after sampling (e.g. the q/phi columns
    derived from e1/e2) and for parameters that were not sampled independently.

    Parameters
    ----------
    param_key : str
        Column of `meta` to histogram. Also used in the x-axis label.
    unit : str
        Unit string shown in the x-axis label.
    binning : int or array-like, optional
        Passed to `plt.hist` as `bins`. Defaults to 30 bins.
    """
    if binning is None:
        binning = 30
    plt.hist(meta[param_key], bins=binning, edgecolor='k', density=True,
             align='mid', label='sampled')
    plt.xlabel("{:s} ({:s})".format(param_key, unit))
    plt.ylabel("density")
    plt.legend()

## Lens mass params

In [None]:
plot_prior_samples(np.linspace(0.5, 1.5, 100), 'lens_mass', 'theta_E', 'arcsec')

In [None]:
plot_prior_samples(np.linspace(-0.04, 0.04, 100), 'lens_mass', 'center_x', 'arcsec')

In [None]:
plot_prior_samples(np.linspace(-0.04, 0.04, 100), 'lens_mass', 'center_y', 'arcsec')

In [None]:
plot_derived_quantities('lens_mass_gamma', 'dimensionless')

In [None]:
plot_prior_samples(np.linspace(-1.0, 1.0, 100), 'lens_mass', 'e1', 'dimensionless')

In [None]:
plot_prior_samples(np.linspace(-1.0, 1.0, 100), 'lens_mass', 'e2', 'dimensionless')

In [None]:
plot_derived_quantities('lens_mass_q', 'dimensionless')

In [None]:
plot_derived_quantities('lens_mass_phi', 'rad')

## External shear params

In [None]:
plot_prior_samples(np.linspace(0, 1.0, 100), 'external_shear', 'gamma_ext', 'no unit')

In [None]:
plot_prior_samples(np.linspace(-0.5*np.pi, 0.5*np.pi, 100), 'external_shear', 'psi_ext', 'rad')

In [None]:
plot_derived_quantities('external_shear_gamma1', 'dimensionless')

In [None]:
plot_derived_quantities('external_shear_gamma2', 'dimensionless')

## Lens light params

In [None]:
plot_derived_quantities('lens_light_magnitude', 'mag')

In [None]:
plot_derived_quantities('lens_light_n_sersic', 'dimensionless')

In [None]:
plot_prior_samples(np.linspace(0.0, 2.0, 100), 'lens_light', 'R_sersic', 'arcsec')

In [None]:
plot_prior_samples(np.linspace(-1.0, 1.0, 100), 'lens_light', 'e1', 'dimensionless')

In [None]:
plot_prior_samples(np.linspace(-1.0, 1.0, 100), 'lens_light', 'e2', 'dimensionless')

In [None]:
plot_derived_quantities('lens_light_q', 'dimensionless')

In [None]:
plot_derived_quantities('lens_light_phi', 'rad')

## Source light params

In [None]:
plot_derived_quantities('src_light_magnitude', 'mag')

In [None]:
plot_prior_samples(np.linspace(0.0, 6.0, 100), 'src_light', 'n_sersic', 'dimensionless')

In [None]:
plot_prior_samples(np.linspace(0.0, 2.0, 100), 'src_light', 'R_sersic', 'arcsec')

In [None]:
plot_prior_samples(np.linspace(-1, 1, 100), 'src_light', 'pos_offset_x', 'arcsec')

In [None]:
plot_prior_samples(np.linspace(-1, 1, 100), 'src_light', 'pos_offset_y', 'arcsec')

In [None]:
plot_derived_quantities('src_light_center_x', 'arcsec')

In [None]:
plot_derived_quantities('src_light_center_y', 'arcsec')

In [None]:
plot_prior_samples(np.linspace(-1.0, 1.0, 100), 'src_light', 'e1', 'dimensionless')

In [None]:
plot_prior_samples(np.linspace(-1.0, 1.0, 100), 'src_light', 'e2', 'dimensionless')

In [None]:
plot_derived_quantities('src_light_q', 'dimensionless')

In [None]:
plot_derived_quantities('src_light_phi', 'rad')

## AGN light params

In [None]:
# AGN magnitude: histogram of samples only
plot_derived_quantities(param_key='agn_light_magnitude', unit='mag')

## Total magnification

In [None]:
# Total magnification: 30 bin edges over [0, 300]; larger values fall outside the plotted range
plot_derived_quantities(param_key='total_magnification', unit='dimensionless', binning=np.linspace(0, 300, 30))

## Pairwise distributions

In [None]:
def plot_pairwise_dist(df, cols, fig=None):
    """Draw a corner plot of the marginal and pairwise joint distributions of `cols`.

    Parameters
    ----------
    df : pandas.DataFrame
        Table holding the samples. Only the columns listed in `cols` are plotted.
    cols : list of str
        Column names to include. Also used as the axis labels.
    fig : matplotlib.figure.Figure, optional
        Existing figure to draw into, passed through to `corner.corner`.

    Returns
    -------
    The object returned by `corner.corner` (the figure).
    """
    n_params = len(cols)
    # Plot the frame that was passed in, not the notebook-global `meta`.
    plot = corner.corner(df[cols],
                         color='tab:blue',
                         smooth=1.0,
                         labels=cols,
                         show_titles=True,
                         fill_contours=True,
                         levels=[0.68, 0.95, 0.997],
                         fig=fig,
                         # Per-axis bounds chosen to contain 99% of the samples.
                         range=[0.99]*n_params,
                         hist_kwargs=dict(density=True))
    return plot

In [None]:
cols = ['src_pos_offset', 'total_magnification',
 'external_shear_gamma_ext', 'external_shear_psi_ext',
 'lens_mass_q', 'lens_mass_theta_E',
 'src_light_q', ]
_ = plot_pairwise_dist(meta, cols)

In [None]:
cols = ['lens_mass_gamma', 'lens_light_n_sersic' ]
_ = plot_pairwise_dist(meta, cols)