Raw File
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import scipy.stats as stats\n",
    "import corner\n",
    "import lenstronomy.Util.param_util as param_util\n",
    "from baobab import bnn_priors\n",
    "from baobab.configs import BaobabConfig, tdlmc_empirical_config, gamma_empirical_config\n",
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualizing the input prior PDF in the EmpiricalBNNPrior and the resulting samples\n",
    "__Author:__ Ji Won Park\n",
    "    \n",
    "__Created:__ 9/24/19\n",
    "    \n",
    "__Last run:__ 9/29/19\n",
    "\n",
    "__Goals:__\n",
    "Plot the (marginal) distributions of the parameters sampled from the empirical BNN prior, in which parameters follow physically reasonable relations.\n",
    "\n",
    "__Before running this notebook:__\n",
    "Generate some data. At the root of the `baobab` repo, run:\n",
    "```\n",
    "generate baobab/configs/tdlmc_empirical_config.py --n_data 1000\n",
    "```\n",
    "This generates 1000 samples using `EmpiricalBNNPrior` at the location this notebook expects."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg_path = tdlmc_empirical_config.__file__\n",
    "#cfg_path = os.path.join('..', '..', 'time_delay_lens_modeling_challenge', 'data', 'baobab_configs', 'train_tdlmc_diagonal_config.py')\n",
    "cfg = BaobabConfig.from_file(cfg_path)\n",
    "#out_data_dir = os.path.join('..', '..', 'time_delay_lens_modeling_challenge', cfg.out_dir)\n",
    "out_data_dir = os.path.join('..', cfg.out_dir)\n",
    "print(out_data_dir)\n",
    "meta = pd.read_csv(os.path.join(out_data_dir, 'metadata.csv'), index_col=None)\n",
    "bnn_prior = getattr(bnn_priors, cfg.bnn_prior_class)(cfg.bnn_omega, cfg.components)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here are the parameters available. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "sorted(meta.columns.values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add shear and ellipticity modulus and angle\n",
    "if 'external_shear_gamma_ext' in meta.columns.values:\n",
    "    gamma_ext = meta['external_shear_gamma_ext'].values\n",
    "    psi_ext = meta['external_shear_psi_ext'].values\n",
    "    gamma1, gamma2 = param_util.phi_gamma_ellipticity(psi_ext, gamma_ext)\n",
    "    meta['external_shear_gamma1'] = gamma1\n",
    "    meta['external_shear_gamma2'] = gamma2\n",
    "else:\n",
    "    gamma1 = meta['external_shear_gamma1'].values\n",
    "    gamma2 = meta['external_shear_gamma2'].values\n",
    "    psi_ext, gamma_ext = param_util.ellipticity2phi_gamma(gamma1, gamma2)\n",
    "    meta['external_shear_gamma_ext'] = gamma_ext\n",
    "    meta['external_shear_psi_ext'] = psi_ext\n",
    "for comp in cfg.components:\n",
    "    if comp in ['lens_mass', 'src_light', 'lens_light']:\n",
    "        if '{:s}_e1'.format(comp) in meta.columns.values:\n",
    "            e1 = meta['{:s}_e1'.format(comp)].values\n",
    "            e2 = meta['{:s}_e2'.format(comp)].values\n",
    "            phi, q = param_util.ellipticity2phi_q(e1, e2)\n",
    "            meta['{:s}_q'.format(comp)] = q\n",
    "            meta['{:s}_phi'.format(comp)] = phi\n",
    "        else:\n",
    "            q = meta['{:s}_q'.format(comp)].values\n",
    "            phi = meta['{:s}_phi'.format(comp)].values\n",
    "            e1, e2 = param_util.phi_q2_ellipticity(phi, q)\n",
    "            meta['{:s}_e1'.format(comp)] = e1\n",
    "            meta['{:s}_e2'.format(comp)] = e2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add source gal positional offset\n",
    "meta['src_pos_offset'] = np.sqrt(meta['src_light_center_x']**2.0 + meta['src_light_center_y']**2.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_prior_samples(eval_at, component, param, unit):\n",
    "    param_key = '{:s}_{:s}'.format(component, param)\n",
    "    if param_key == 'src_light_pos_offset_x':\n",
    "        hyperparams = cfg.bnn_omega['src_light']['center_x']\n",
    "    elif param_key == 'src_light_pos_offset_y':\n",
    "        hyperparams = cfg.bnn_omega['src_light']['center_y']\n",
    "    elif (param_key == 'src_light_center_x') or (param_key == 'src_light_center_y'):\n",
    "        raise NotImplementedError(\"Use `plot_derived_quantities` instead.\")\n",
    "    elif (component, param) in bnn_prior.params_to_exclude:\n",
    "        raise NotImplementedError(\"This parameter wasn't sampled independently. Please use `plot_derived_quantities` instead.\")\n",
    "    else:\n",
    "        hyperparams = cfg.bnn_omega[component][param].copy()\n",
    "    pdf_eval = bnn_prior.eval_param_pdf(eval_at, hyperparams)\n",
    "    plt.plot(eval_at, pdf_eval, 'r-', lw=2, alpha=0.6, label='PDF')\n",
    "    binning = np.linspace(eval_at[0], eval_at[-1], 50)\n",
    "    plt.hist(meta[param_key], bins=binning, edgecolor='k', density=True, align='mid', label='sampled')\n",
    "    print(hyperparams)\n",
    "    plt.xlabel(\"{:s} ({:s})\".format(param_key, unit))\n",
    "    plt.ylabel(\"density\")\n",
    "    plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_derived_quantities(param_key, unit, binning=None):\n",
    "    binning = 30 if binning is None else binning\n",
    "    _ = plt.hist(meta[param_key], bins=binning, edgecolor='k', density=True, align='mid', label='sampled')\n",
    "    plt.xlabel(\"{:s} ({:s})\".format(param_key, unit))\n",
    "    plt.ylabel(\"density\")\n",
    "    plt.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Lens mass params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_derived_quantities('lens_mass_theta_E', 'arcsec')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_prior_samples(np.linspace(-0.04, 0.04, 100), 'lens_mass', 'center_x', 'arcsec')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_prior_samples(np.linspace(-0.04, 0.04, 100), 'lens_mass', 'center_y', 'arcsec')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_derived_quantities('lens_mass_gamma', 'dimensionless')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_prior_samples(np.linspace(-1.0, 1.0, 100), 'lens_mass', 'e1', 'dimensionless')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_prior_samples(np.linspace(-1.0, 1.0, 100), 'lens_mass', 'e2', 'dimensionless')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_derived_quantities('lens_mass_q', 'dimensionless')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_derived_quantities('lens_mass_phi', 'rad')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## External shear params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_prior_samples(np.linspace(0, 1.0, 100), 'external_shear', 'gamma_ext', 'no unit')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_prior_samples(np.linspace(-0.5*np.pi, 0.5*np.pi, 100), 'external_shear', 'psi_ext', 'rad')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_derived_quantities('external_shear_gamma1', 'dimensionless')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_derived_quantities('external_shear_gamma2', 'dimensionless')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Lens light params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_derived_quantities('lens_light_magnitude', 'mag')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_prior_samples(np.linspace(2, 6, 100), 'lens_light', 'n_sersic', 'dimensionless')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_derived_quantities('lens_light_R_sersic', 'arcsec')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_derived_quantities('lens_light_e1', 'dimensionless')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_derived_quantities('lens_light_e2', 'dimensionless')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_derived_quantities('lens_light_q', 'dimensionless')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_derived_quantities('lens_light_phi', 'rad')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Source light params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_derived_quantities('src_light_magnitude', 'mag')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_prior_samples(np.linspace(0.0, 6.0, 100), 'src_light', 'n_sersic', 'dimensionless')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_derived_quantities('src_light_R_sersic', 'arcsec')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_prior_samples(np.linspace(-1, 1, 100), 'src_light', 'pos_offset_x', 'arcsec')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_prior_samples(np.linspace(-1, 1, 100), 'src_light', 'pos_offset_y', 'arcsec')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_derived_quantities('src_light_center_x', 'arcsec')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_derived_quantities('src_light_center_y', 'arcsec')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_derived_quantities('src_light_e1', 'dimensionless', 20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_derived_quantities('src_light_e2', 'dimensionless', 20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_derived_quantities('src_light_q', 'dimensionless', 20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_derived_quantities('src_light_phi', 'rad', 20)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## AGN light params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_derived_quantities('agn_light_magnitude', 'mag', 30)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Total magnification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_derived_quantities('total_magnification', 'dimensionless', binning=np.linspace(0, 300, 30))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Other quantities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_derived_quantities('z_lens', 'dimensionless', 20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_derived_quantities('z_src', 'dimensionless', 20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "meta.columns.values"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pairwise distributions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_pairwise_dist(df, cols, fig=None):\n",
    "    n_params = len(cols)\n",
    "    plot = corner.corner(meta[cols],\n",
    "                        color='tab:blue', \n",
    "                        smooth=1.0, \n",
    "                        labels=cols,\n",
    "                        show_titles=True,\n",
    "                        fill_contours=True,\n",
    "                        levels=[0.68, 0.95, 0.997],\n",
    "                        fig=fig,\n",
    "                        range=[0.99]*n_params,\n",
    "                        hist_kwargs=dict(density=True, ))\n",
    "    return plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cols = ['src_pos_offset', 'total_magnification',\n",
    "        'external_shear_gamma_ext', 'external_shear_psi_ext',\n",
    "        'lens_mass_q', 'lens_mass_theta_E',\n",
    "        'src_light_q', 'src_light_R_sersic']\n",
    "_ = plot_pairwise_dist(meta, cols)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cols = ['lens_mass_gamma', 'lens_light_n_sersic' ]\n",
    "_ = plot_pairwise_dist(meta, cols)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python3 (baobab)",
   "language": "python",
   "name": "baobab"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
back to top