# Raw file: base_cosmo_bnn_prior.py
import numpy as np
from astropy.cosmology import FlatLambdaCDM
from abc import ABC, abstractmethod

class BaseCosmoBNNPrior(ABC):
    """Abstract base class for a cosmology-aware BNN prior

    On construction, the `cosmology`, `redshift`, and `kinematics` sections of
    the config are validated and stored as attributes, an astropy
    `FlatLambdaCDM` instance is built from the `cosmology` section, and
    `sample_redshifts` is bound to the `sample_redshifts_from_<model>` method
    selected by `redshift.model`.

    """
    # Config sections that every cosmology-aware prior needs
    _cosmo_components = ('cosmology', 'redshift', 'kinematics')

    def __init__(self, bnn_omega):
        """Validate the config and set up the cosmology and redshift sampler

        Parameters
        ----------
        bnn_omega : dict
            Config containing the `cosmology`, `redshift`, and `kinematics`
            sections. The `redshift` section must support attribute access
            (`redshift.model` is read here).

        Raises
        ------
        ValueError
            If any of the required sections is missing from `bnn_omega`

        """
        self._check_cosmology_config_validity(bnn_omega)
        self._define_cosmology(bnn_omega['cosmology'])
        for cosmo_comp in self._cosmo_components:
            setattr(self, cosmo_comp, bnn_omega[cosmo_comp])

        # Dispatch on the configured redshift model name
        self.sample_redshifts = getattr(self, 'sample_redshifts_from_{:s}'.format(self.redshift.model))

    def _raise_config_error(self, missing_key, parent_config_key, bnn_prior_class):
        """Convenience function for raising errors related to config values

        Parameters
        ----------
        missing_key : str
            Name of the key that was expected but not found
        parent_config_key : str
            Name of the config section that should have contained `missing_key`
        bnn_prior_class : str
            Name of the BNN prior class that requires the key

        Raises
        ------
        ValueError
            Always

        """
        raise ValueError("{:s} must be specified in the config inside {:s} for {:s}".format(missing_key, parent_config_key, bnn_prior_class))

    def _check_cosmology_config_validity(self, bnn_omega):
        """Check whether the config file specified the hyperparameters for all the fields
        required for cosmology-aware BNN priors, e.g. cosmology, redshift, galaxy kinematics

        Parameters
        ----------
        bnn_omega : dict
            Config to validate

        Raises
        ------
        ValueError
            If `cosmology`, `redshift`, or `kinematics` is missing from `bnn_omega`

        """
        for possible_missing_key in self._cosmo_components:
            if possible_missing_key not in bnn_omega:
                # Report the concrete subclass so the user knows which prior needs the key
                self._raise_config_error(possible_missing_key, 'bnn_omega', type(self).__name__)

    def _define_cosmology(self, cosmology_cfg):
        """Set the cosmology, with which to generate all the training samples, based on the config

        Parameters
        ----------
        cosmology_cfg : dict
            Copy of `cfg.bnn_omega.cosmology`, unpacked as keyword arguments
            into `astropy.cosmology.FlatLambdaCDM`

        """
        self.cosmo = FlatLambdaCDM(**cosmology_cfg)

    def sample_param(self, hyperparams):
        """Assigns a sampling distribution and draws a sample from it

        Parameters
        ----------
        hyperparams : dict
            Must contain the key `dist`, naming the distribution. The remaining
            items are passed as keyword arguments to
            `baobab.distributions.sample_<dist>`. Not modified by this call.

        Returns
        -------
        object
            Whatever `baobab.distributions.sample_<dist>` returns

        """
        # Work on a shallow copy so that popping `dist` does not alter the caller's config
        hyperparams = dict(hyperparams)
        dist = hyperparams.pop('dist')
        # NOTE(review): `baobab` is not imported in the visible part of this file -- confirm
        # that `baobab.distributions` is brought into scope elsewhere.
        return getattr(baobab.distributions, 'sample_{:s}'.format(dist))(**hyperparams)

    def eval_param_pdf(self, eval_at, hyperparams):
        """Assigns and evaluates the PDF

        Parameters
        ----------
        eval_at : float or array-like
            Point(s) at which to evaluate the PDF
        hyperparams : dict
            Must contain the key `dist`, naming the distribution. The remaining
            items are passed as keyword arguments to
            `baobab.distributions.eval_<dist>_pdf`. Not modified by this call.

        Returns
        -------
        object
            Whatever `baobab.distributions.eval_<dist>_pdf` returns

        """
        # Work on a shallow copy so that popping `dist` does not alter the caller's config
        hyperparams = dict(hyperparams)
        dist = hyperparams.pop('dist')
        # The evaluation point is forwarded positionally.
        # NOTE(review): assumes the `eval_*_pdf` functions take the evaluation point as
        # their first positional argument -- confirm against `baobab.distributions`.
        return getattr(baobab.distributions, 'eval_{:s}_pdf'.format(dist))(eval_at, **hyperparams)

    def sample_redshifts_from_differential_comoving_volume(self, redshifts_cfg):
        """Sample redshifts from the differential comoving volume,
        on a grid with the range and resolution specified in the config

        Parameters
        ----------
        redshifts_cfg : dict
            Copy of `cfg.bnn_omega.redshift`. `redshifts_cfg.grid` is unpacked
            into `np.arange` and `redshifts_cfg.min_diff` is the minimum
            allowed difference between source and lens redshifts.

        Returns
        -------
        tuple
            the tuple of floats that are the realized z_lens, z_src

        Raises
        ------
        ValueError
            If the grid has fewer than two points, or spans less than
            `min_diff` so that no valid (z_lens, z_src) pair exists

        """
        z_grid = np.arange(**redshifts_cfg.grid)
        min_diff = redshifts_cfg.min_diff
        if z_grid.size < 2:
            raise ValueError("The redshift grid must contain at least two points to sample z_lens and z_src.")
        # Without this guard the rejection loop below would never terminate
        if np.max(z_grid) < np.min(z_grid) + min_diff:
            raise ValueError("The redshift grid spans less than min_diff, so no valid (z_lens, z_src) pair exists.")
        dVol_dz = self.cosmo.differential_comoving_volume(z_grid).value
        dVol_dz_normed = dVol_dz/np.sum(dVol_dz)
        # Rejection sampling: redraw both redshifts until they are at least min_diff apart
        while True:
            sampled_z = np.random.choice(z_grid, 2, replace=False, p=dVol_dz_normed)
            z_lens = np.min(sampled_z)
            z_src = np.max(sampled_z)
            if z_src >= z_lens + min_diff:
                return z_lens, z_src

    def sample_redshifts_from_independent_dist(self, redshifts_cfg):
        """Sample lens and source redshifts from independent distributions, while enforcing that the lens redshift is smaller than source redshift

        Parameters
        ----------
        redshifts_cfg : dict
            Copy of `cfg.bnn_omega.redshift`. `redshifts_cfg.z_lens` and
            `redshifts_cfg.z_src` are the hyperparameters passed to
            `sample_param`, and `redshifts_cfg.min_diff` is the minimum
            allowed difference between source and lens redshifts.

        Returns
        -------
        tuple
            the tuple of floats that are the realized z_lens, z_src

        """
        min_diff = redshifts_cfg.min_diff
        # Rejection sampling: redraw both redshifts until they are at least min_diff apart
        while True:
            # Pass copies because the config is reused on every draw and must stay intact
            # even if `self.sample_param` resolves to an implementation that mutates its input
            z_lens = self.sample_param(redshifts_cfg.z_lens.copy())
            z_src = self.sample_param(redshifts_cfg.z_src.copy())
            if z_src >= z_lens + min_diff:
                return z_lens, z_src
# back to top