Raw File
import numpy as np
import scipy.stats as stats
import astropy.units as u
from addict import Dict
import lenstronomy.Util.param_util as param_util
from .base_bnn_prior import BaseBNNPrior
from .base_cosmo_bnn_prior import BaseCosmoBNNPrior
from . import kinematics_models, parameter_models

class EmpiricalBNNPrior(BaseBNNPrior, BaseCosmoBNNPrior):
    """BNN prior that encodes physical correlations between parameters

    Parameters tied together by empirical relations (Einstein radius, power-law
    slope, axis ratios, magnitudes, effective radii) are derived jointly in
    `sample` from sampled redshifts and a sampled velocity dispersion. All
    remaining parameters are drawn independently from their config-specified
    distributions.

    """
    def __init__(self, bnn_omega, components):
        """
        Note
        ----
        The dictionary attributes are copies of the config corresponding to each component.
        The number of attributes depends on the number of components.

        Attributes
        ----------
        bnn_omega : dict
            copy of `cfg.bnn_omega`
        components : list
            list of components, e.g. `lens_mass`

        """
        BaseBNNPrior.__init__(self, bnn_omega, components)
        BaseCosmoBNNPrior.__init__(self, bnn_omega)

        # (component, parameter) pairs that `sample` derives from the empirical
        # relations; they are passed to `set_params_list` so that they are not
        # also drawn independently.
        self.params_to_exclude = [
            ('lens_mass', 'theta_E'),
            ('lens_mass', 'gamma'),
            ('lens_mass', 'q'),
            ('lens_light', 'magnitude'),
            ('lens_light', 'R_sersic'),
            ('lens_light', 'q'),
            ('src_light', 'magnitude'),
            ('src_light', 'R_sersic'),
            ('agn_light', 'magnitude'),
        ]
        self.set_params_list(self.params_to_exclude)
        self.set_comps_qphi_to_e1e2()

        # TODO: AGN parameters are sampled even when it's not rendered on image (not in self.components)
        # NOTE(review): `self.kinematics`, `self.lens_mass`, etc. are presumably
        # populated by the base-class initializers above -- confirm.
        self._define_kinematics_models(self.kinematics)
        self._define_parameter_models(self.lens_mass, self.lens_light, self.src_light, self.agn_light)

    def _define_kinematics_models(self, kinematics_cfg):
        """Set the empirical models related to the kinematics, based on the config

        Parameters
        ----------
        kinematics_cfg : dict
            Copy of `cfg.bnn_omega.kinematics`

        """
        # Looked up by name in `kinematics_models`; evaluated on a grid in `sample_vel_disp`
        self.vel_disp_function = getattr(kinematics_models, kinematics_cfg.vel_disp.model)

    def _define_parameter_models(self, lens_mass_cfg, lens_light_cfg, src_light_cfg, agn_light_cfg):
        """Set the empirical models, with which to generate all the training samples,
        based on each component key in `cfg.bnn_omega`

        Parameters
        ----------
        lens_mass_cfg : dict
            Copy of `cfg.bnn_omega.lens_mass`
        lens_light_cfg : dict
            Copy of `cfg.bnn_omega.lens_light`
        src_light_cfg : dict
            Copy of `cfg.bnn_omega.src_light`
        agn_light_cfg : dict
            Copy of `cfg.bnn_omega.agn_light`

        Note
        ----
        Models given with `model_kwargs` are instantiated here and a bound method is stored.
        Models without `model_kwargs` (`theta_E`, source magnitude, source size) are stored
        as the bare callables looked up in `parameter_models`.

        """
        # lens_mass
        gamma_cfg = lens_mass_cfg.gamma
        self.gamma_model = getattr(parameter_models, gamma_cfg.model)(**gamma_cfg.model_kwargs).get_gamma_from_vel_disp
        self.theta_E_model = getattr(parameter_models, lens_mass_cfg.theta_E.model)
        # The axis ratio model is configured under `lens_light.q` and is used by
        # `sample` for both the lens mass and the lens light axis ratios.
        lens_q_cfg = lens_light_cfg.q
        self.lens_axis_ratio_model = getattr(parameter_models, lens_q_cfg.model)(**lens_q_cfg.model_kwargs).get_axis_ratio

        # lens_light
        lens_mag_cfg = lens_light_cfg.magnitude
        self.lens_luminosity_model = getattr(parameter_models, lens_mag_cfg.model)(**lens_mag_cfg.model_kwargs).get_luminosity
        lens_size_cfg = lens_light_cfg.R_sersic
        self.lens_light_size_model = getattr(parameter_models, lens_size_cfg.model)(**lens_size_cfg.model_kwargs).get_effective_radius

        # src_light
        self.src_luminosity_model = getattr(parameter_models, src_light_cfg.magnitude.model)
        self.src_light_size_model = getattr(parameter_models, src_light_cfg.R_sersic.model)

        # agn_light
        agn_mag_cfg = agn_light_cfg.magnitude
        self.agn_luminosity_model = getattr(parameter_models, agn_mag_cfg.model)(**agn_mag_cfg.model_kwargs).sample_agn_luminosity

    def sample_vel_disp(self, vel_disp_cfg):
        """Sample velocity dispersion from the config-specified model,
        on a grid with the range and resolution specified in the config

        Parameters
        ----------
        vel_disp_cfg : dict
            Copy of `cfg.bnn_omega.kinematics.vel_disp`. Its `grid` entry is
            unpacked as keyword arguments to `np.arange`.

        Returns
        -------
        float
            a realization of velocity dispersion

        """
        vel_disp_grid = np.arange(**vel_disp_cfg.grid)
        dn = self.vel_disp_function(vel_disp_grid)
        # Normalize the function values into a discrete probability mass over the grid
        dn_normed = dn/np.sum(dn)
        sampled_vel_disp = np.random.choice(vel_disp_grid, None, replace=True, p=dn_normed)
        return sampled_vel_disp

    def get_lens_absolute_magnitude(self, vel_disp):
        """Get the lens absolute magnitude from the Faber-Jackson relation
        given the realized velocity dispersion, with some scatter

        Parameters
        ----------
        vel_disp : float
            the velocity dispersion in km/s

        Returns
        -------
        float
            the V-band absolute magnitude

        """
        log_L_V = self.lens_luminosity_model(vel_disp)
        M_V_sol = 4.84 # V-band absolute magnitude adopted for the Sun
        # Luminosity (log, relative to solar) to absolute magnitude
        M_V = -2.5 * log_L_V + M_V_sol
        return M_V

    def get_lens_apparent_magnitude(self, M_lens, z_lens):
        """Convert the lens absolute magnitude into apparent magnitude
        using the distance modulus at the lens redshift

        Parameters
        ----------
        M_lens : float
            the V-band absolute magnitude of lens
        z_lens : float
            the lens redshift

        Note
        ----
        Does not account for peculiar velocity or dust. K-correction is approximate and implicit,
        as the absolute magnitude is in the V-band (480nm ~ 650nm) and, for z ~ 2-3, this portion
        of the SED roughly lands in the IR.

        Returns
        -------
        float
            the apparent magnitude in the IR

        """
        # FIXME: I could grab some template SEDs and K-correct explicitly, accounting for band throughput
        # for IR WF F140W. Should I do this?
        dist_mod = self.cosmo.distmod(z_lens).value
        # FIXME: Enter good model for dust?
        A_V = 0.0 # V-band dust attenuation along LOS
        apmag = M_lens + dist_mod - A_V
        return apmag

    def get_lens_size(self, vel_disp, z_lens, m_V):
        """Get the lens V-band effective radius from the Fundamental Plane relation
        given the realized velocity dispersion and apparent magnitude, with some scatter

        Parameters
        ----------
        vel_disp : float
            the velocity dispersion in km/s
        z_lens : float
            redshift
        m_V : float
            V-band apparent magnitude

        Returns
        -------
        tuple
            the effective radius in kpc and arcsec

        """
        R_eff = self.lens_light_size_model(vel_disp, m_V) # in kpc
        # The effective radius is a physical (proper) length, so the angular size
        # follows from the angular diameter distance, i.e. the proper (not comoving)
        # transverse scale. The comoving scale would be too small by (1 + z).
        r_eff = R_eff * self.cosmo.arcsec_per_kpc_proper(z_lens).value # in arcsec
        return R_eff, r_eff

    def get_src_absolute_magnitude(self, z_src):
        """Sample the UV absolute magnitude from the luminosity function for the given redshift

        Parameters
        ----------
        z_src : float
            the source redshift

        Note
        ----
        The luminosity function is evaluated on a fixed magnitude grid starting at -23.0
        in steps of 0.2 up to about -17.8, and normalized into a discrete distribution.

        Returns
        -------
        float
            the absolute magnitude at 1500A

        """
        M_grid = np.arange(-23.0, -17.8, 0.2)
        nM_dM1500 = self.src_luminosity_model(z_src, M_grid)
        nM_dM1500_normed = nM_dM1500/np.sum(nM_dM1500)
        M1500_src = np.random.choice(M_grid, None, replace=True, p=nM_dM1500_normed)
        return M1500_src

    def get_src_apparent_magnitude(self, M_src, z_src):
        """Convert the source absolute magnitude into apparent magnitude

        Parameters
        ----------
        M_src : float
            the source absolute magnitude
        z_src : float
            the source redshift

        Note
        ----
        Does not account for peculiar velocity or dust. K-correction is approximate and implicit,
        as the absolute magnitude is at 150nm and, for z ~ 5-9, this portion
        of the SED roughly lands in the IR.

        Returns
        -------
        float
            the apparent magnitude in the IR

        """
        dust = 0.0 # placeholder; no dust attenuation is applied
        dist_mod = self.cosmo.distmod(z_src).value
        m_src = M_src + dist_mod - dust
        return m_src

    def get_src_size(self, z_src, M_V_src):
        """Get the effective radius of the source from its empirical relation with V-band absolute
        magnitude and redshift

        Parameters
        ----------
        z_src : float
            source redshift
        M_V_src : float
            V-band absolute magnitude of the source.
            NOTE(review): `sample` passes the magnitude returned by
            `get_src_absolute_magnitude` (documented as 1500A) -- confirm the band
            expected by the configured size model.

        Returns
        -------
        tuple
            tuple of the effective radius in kpc and arcsec

        """
        R_eff = self.src_light_size_model(z_src, M_V_src)
        # Physical size -> angular size uses the proper transverse scale
        # (angular diameter distance), not the comoving one.
        r_eff = R_eff * self.cosmo.arcsec_per_kpc_proper(z_src).value # in arcsec
        return R_eff, r_eff

    def get_agn_absolute_magnitude(self, z_src):
        """Get the AGN absolute magnitude at 1450A, sampled from the luminosity function for its redshift bin

        Parameters
        ----------
        z_src : float
            the AGN redshift

        Returns
        -------
        float
            AGN absolute magnitude at 1450A

        """
        M_agn = self.agn_luminosity_model(z_src)
        return M_agn

    def sample(self):
        """Gets kwargs of sampled parameters to be passed to lenstronomy

        Returns
        -------
        dict
            dictionary of config-specified components (e.g. lens mass), itself
            a dictionary of sampled parameters corresponding to the config-specified
            profile of that component. An extra `misc` entry carries the redshifts,
            velocity dispersion, physical effective radii, and absolute magnitudes.

        """
        kwargs = Dict()
        # Sample redshifts
        z_lens, z_src = self.sample_redshifts(redshifts_cfg=self.redshift)
        # Sample velocity dispersion
        vel_disp_iso = self.sample_vel_disp(vel_disp_cfg=self.kinematics.vel_disp)
        # Sample lens_mass and lens_light parameters
        abmag_lens = self.get_lens_absolute_magnitude(vel_disp_iso)
        apmag_lens = self.get_lens_apparent_magnitude(abmag_lens, z_lens)
        theta_E = self.theta_E_model(vel_disp_iso, z_lens, z_src, self.cosmo)
        R_eff_lens, r_eff_lens = self.get_lens_size(vel_disp_iso, z_lens, apmag_lens)
        gamma = self.gamma_model(R_eff_lens)
        # The axis ratio model is called once per component, so the lens mass and
        # the lens light each receive their own call's value.
        kwargs['lens_mass'] = dict(
                                   theta_E=theta_E,
                                   gamma=gamma,
                                   q=self.lens_axis_ratio_model(vel_disp_iso),
                                   )
        kwargs['lens_light'] = dict(
                                    magnitude=apmag_lens,
                                    R_sersic=r_eff_lens,
                                    q=self.lens_axis_ratio_model(vel_disp_iso),
                                    )
        # Filled in below: independently sampled params, then the shear origin
        kwargs['external_shear'] = {}

        # Sample src_light parameters
        abmag_src = self.get_src_absolute_magnitude(z_src)
        apmag_src = self.get_src_apparent_magnitude(abmag_src, z_src)
        R_eff_src, r_eff_src = self.get_src_size(z_src, abmag_src)
        kwargs['src_light'] = dict(
                                   magnitude=apmag_src,
                                   R_sersic=r_eff_src,
                                   )

        # Sample AGN_light parameters
        if 'agn_light' in self.components:
            abmag_agn = self.get_agn_absolute_magnitude(z_src)
            # Same distance-modulus conversion as for the source galaxy
            apmag_agn = self.get_src_apparent_magnitude(abmag_agn, z_src)
            kwargs['agn_light'] = dict(
                                       magnitude=apmag_agn,
                                       )

        # Miscellaneous other parameters to export
        kwargs['misc'] = dict(
                              z_lens=z_lens,
                              z_src=z_src,
                              vel_disp_iso=vel_disp_iso,
                              lens_light_R_eff=R_eff_lens,
                              src_light_R_eff=R_eff_src,
                              lens_light_abmag=abmag_lens,
                              src_light_abmag=abmag_src,
                              )

        # Sample remaining parameters, not constrained by the above empirical relations,
        # independently from their (marginally) diagonal BNN prior
        for comp, param_name in self.params_to_realize:
            # Copy so that sampling cannot modify the stored config
            hyperparams = getattr(self, comp)[param_name].copy()
            kwargs[comp][param_name] = self.sample_param(hyperparams)

        # Convert any q, phi into e1, e2 as required by lenstronomy
        for comp in self.comps_qphi_to_e1e2: # e.g. 'lens_mass'
            q = kwargs[comp].pop('q')
            phi = kwargs[comp].pop('phi')
            e1, e2 = param_util.phi_q2_ellipticity(phi, q)
            kwargs[comp]['e1'] = e1
            kwargs[comp]['e2'] = e2

        # Source pos is defined wrt the lens pos
        kwargs['src_light']['center_x'] += kwargs['lens_mass']['center_x']
        kwargs['src_light']['center_y'] += kwargs['lens_mass']['center_y']

        # Ext shear is defined wrt the lens center
        kwargs['external_shear']['ra_0'] = kwargs['lens_mass']['center_x']
        kwargs['external_shear']['dec_0'] = kwargs['lens_mass']['center_y']

        if 'lens_light' in self.components:
            # Lens light shares center with lens mass
            kwargs['lens_light']['center_x'] = kwargs['lens_mass']['center_x']
            kwargs['lens_light']['center_y'] = kwargs['lens_mass']['center_y']

        return kwargs
back to top