Raw File
image_utils.py
import copy
import sys
import numpy as np
# Lenstronomy modules
from lenstronomy.LensModel.lens_model import LensModel
from lenstronomy.LightModel.light_model import LightModel
from lenstronomy.ImSim.image_model import ImageModel
from baobab.sim_utils import mag_to_amp_extended, mag_to_amp_point, get_lensed_total_flux, get_unlensed_total_flux_numerical
from lenstronomy.LensModel.Solver.lens_equation_solver import LensEquationSolver
from lenstronomy.SimulationAPI.data_api import DataAPI
from lenstronomy.PointSource.point_source import PointSource

from baobab.sim_utils import psf_utils


# Public API of this module: the production Imager and the dev-mode Imager2.
__all__ = ['Imager', 'Imager2']

class Imager2:
    """Dev-mode class, more flexible than Imager.

    Images a lens mass model plus an extended source in one or more
    bandpasses of a survey from lenstronomy's ObservationConfig.

    Usage order:

    1. Set `survey_kwargs`. This builds one DataAPI and one ImageModel per bandpass.
    2. Set `lens_kwargs` and `src_kwargs`.
    3. Call `generate_image`.

    Note
    ----
    Accompanying `generate` script doesn't exist yet.

    """
    def __init__(self, lens_model_list, src_model_list,
                 n_pix, pixel_scale,
                 psf_type, psf_kernel_size=None, which_psf_maps=None,
                 kwargs_numerics=None):
        """
        Parameters
        ----------
        lens_model_list : list
            lens model names passed to lenstronomy's `LensModel`
        src_model_list : list
            source light profile names passed to lenstronomy's `LightModel`
        n_pix : int
            number of pixels on each side of the square output image
        pixel_scale : float
            pixel scale forwarded to `psf_utils.get_PSF_model`
        psf_type : str
            PSF type, forwarded to the survey class and to the PSF model
        psf_kernel_size : optional
            forwarded to `psf_utils.get_PSF_model` as `kernel_size`
        which_psf_maps : optional
            forwarded to `psf_utils.get_PSF_model`
        kwargs_numerics : dict, optional
            numerics kwargs for every `ImageModel`; defaults to
            `{'supersampling_factor': 1}`

        """
        # Define models
        self.lens_model = LensModel(lens_model_list=lens_model_list)
        self.src_model = LightModel(light_model_list=src_model_list)
        # Set detector specs
        self.n_pix = n_pix
        self.pixel_scale = pixel_scale
        self.psf_type = psf_type
        self.psf_kernel_size = psf_kernel_size
        self.which_psf_maps = which_psf_maps
        # Build a fresh dict per instance. A dict literal as the default
        # argument would be shared (and mutable) across all instances.
        if kwargs_numerics is None:
            kwargs_numerics = {'supersampling_factor': 1}
        self.kwargs_numerics = kwargs_numerics
        # Initialize kwargs (must be set using setter)
        self._survey_kwargs = None
        self._lens_kwargs = None
        self._src_kwargs = None
        # Per-bandpass simulation objects, populated by the survey_kwargs setter
        self._data_api = []
        self._image_model = []

    @property
    def survey_kwargs(self):
        """Dict containing detector information, or None if not yet set.

        Reads the keys `survey_name` and `bandpass_list`, plus the optional
        keys `coadd_years`, `override_obs_kwargs` and `override_camera_kwargs`.
        One image model is built per entry of `bandpass_list`. Should be set
        before the model kwargs.

        """
        return self._survey_kwargs

    @survey_kwargs.setter
    def survey_kwargs(self, survey_kwargs):
        survey_name = survey_kwargs['survey_name']
        bandpass_list = survey_kwargs['bandpass_list']
        coadd_years = survey_kwargs.get('coadd_years')
        override_obs_kwargs = survey_kwargs.get('override_obs_kwargs', {})
        override_camera_kwargs = survey_kwargs.get('override_camera_kwargs', {})

        import lenstronomy.SimulationAPI.ObservationConfig as ObsConfig
        from importlib import import_module
        # The survey module is imported by its top-level name from the
        # ObservationConfig directory. The class has the same name as the module.
        # Add the directory only once, so repeated assignments don't keep
        # growing sys.path.
        obs_config_dir = ObsConfig.__path__[0]
        if obs_config_dir not in sys.path:
            sys.path.insert(0, obs_config_dir)
        SurveyClass = getattr(import_module(survey_name), survey_name)

        # Build into locals and commit at the end. A failure partway through
        # then leaves the previous state intact instead of a half-built one.
        data_api_list = []
        image_model_list = []
        for bp in bandpass_list:
            survey_obj = SurveyClass(band=bp,
                                     psf_type=self.psf_type,
                                     coadd_years=coadd_years)
            # Override as specified in survey_kwargs
            survey_obj.camera.update(override_camera_kwargs)
            survey_obj.obs.update(override_obs_kwargs)
            # This is what we'll actually use
            kwargs_detector = survey_obj.kwargs_single_band()
            data_api = DataAPI(self.n_pix, **kwargs_detector)
            # NOTE(review): the PSF is built with self.pixel_scale, while the data
            # grid uses kwargs_detector's pixel scale. These only agree if
            # override_camera_kwargs sets it. Confirm this is intended.
            psf_model = psf_utils.get_PSF_model(self.psf_type,
                                                self.pixel_scale,
                                                seeing=kwargs_detector['seeing'],
                                                kernel_size=self.psf_kernel_size,
                                                which_psf_maps=self.which_psf_maps)
            image_model_bp = ImageModel(data_api.data_class,
                                        psf_model,
                                        self.lens_model,
                                        self.src_model,
                                        None,  # no lens light
                                        None,  # no point source
                                        kwargs_numerics=self.kwargs_numerics)
            data_api_list.append(data_api)
            image_model_list.append(image_model_bp)
        self._data_api = data_api_list
        self._image_model = image_model_list
        self._survey_kwargs = survey_kwargs

    @property
    def lens_kwargs(self):
        """Lens mass kwargs passed unchanged to every band's image model."""
        return self._lens_kwargs

    @lens_kwargs.setter
    def lens_kwargs(self, lens_kwargs):
        self._lens_kwargs = lens_kwargs

    @property
    def src_kwargs(self):
        """Source light kwargs, with magnitudes converted to amplitudes."""
        return self._src_kwargs

    @src_kwargs.setter
    def src_kwargs(self, src_kwargs):
        # Entry i is converted with band i's DataAPI. Work on a shallow copy
        # so the caller's list is not modified in place.
        if src_kwargs is not None:
            src_kwargs = list(src_kwargs)
        for i, data_api_bp in enumerate(self._data_api):
            # Convert magnitude to amp recognized by the profile
            if 'magnitude' in src_kwargs[i]:
                src_kwargs[i] = mag_to_amp_extended([src_kwargs[i]],
                                                    self.src_model,
                                                    data_api_bp)[0]
        self._src_kwargs = src_kwargs

    def generate_image(self):
        """Render the lens system in every configured bandpass.

        Returns
        -------
        numpy.ndarray
            array of shape `[n_bands, n_pix, n_pix]`, with negative pixel
            values clipped to 0

        """
        n_filters = len(self._image_model)
        img_canvas = np.empty([n_filters, self.n_pix, self.n_pix])
        for i, image_model_bp in enumerate(self._image_model):
            # NOTE(review): the src_kwargs setter indexes entries by band, but
            # the whole list is passed to each band here. Confirm whether band i
            # should receive only [self.src_kwargs[i]].
            img = image_model_bp.image(self.lens_kwargs,
                                       self.src_kwargs,
                                       None, None,
                                       lens_light_add=False,
                                       point_source_add=False)
            img = np.maximum(0.0, img)  # safeguard against negative pixel values
            img_canvas[i, :, :] = img
        return img_canvas

class Imager:
    """Deterministic utility class for imaging the objects on a pixel grid

    Attributes
    ----------
    components : list
        list of components, e.g. `lens_mass`; `lens_light` and `agn_light`
        toggle the lens light and point-source handling
    kwargs_numerics : dict
        numerics kwargs passed to every lenstronomy `ImageModel`
    min_magnification : float
        images whose total magnification falls below this are rejected
    for_cosmography : bool
        if True, image positions are read from `sample['misc']` instead of
        being solved for with the lens equation solver
    magnification_frac_err : float
        fractional Gaussian noise applied to the point-source magnifications
    img_features : dict
        metadata of the most recently generated image (image positions,
        magnifications, per-band fluxes)

    """
    def __init__(self, components, lens_mass_model, src_light_model, lens_light_model=None, ps_model=None, kwargs_numerics=None, min_magnification=0.0, for_cosmography=False, magnification_frac_err=0.0):
        self.components = components
        # Build a fresh dict per instance. A dict literal as the default
        # argument would be shared (and mutable) across all instances.
        if kwargs_numerics is None:
            kwargs_numerics = {'supersampling_factor': 1}
        self.kwargs_numerics = kwargs_numerics
        self.lens_mass_model = lens_mass_model
        self.src_light_model = src_light_model
        self.lens_light_model = lens_light_model
        self.ps_model = ps_model
        # Point source at the source-plane position, used for unlensed flux
        self.unlensed_ps_model = PointSource(point_source_type_list=['SOURCE_POSITION'], fixed_magnification_list=[False])
        self.lens_eq_solver = LensEquationSolver(self.lens_mass_model)
        self.min_magnification = min_magnification
        self.for_cosmography = for_cosmography
        self.magnification_frac_err = magnification_frac_err
        self.img_features = {}  # metadata of the current image; reset by generate_image

    def _set_sim_api(self, num_pix, kwargs_detector, psf_kernel_size, which_psf_maps):
        """Set the simulation API objects for a single band.

        Parameters
        ----------
        num_pix : int
            number of pixels on each side of the image
        kwargs_detector : dict
            single-band detector kwargs; must contain `pixel_scale`,
            `psf_type` and `seeing`
        psf_kernel_size, which_psf_maps
            forwarded to `psf_utils.get_PSF_model`

        """
        self.data_api = DataAPI(num_pix, **kwargs_detector)
        pixel_scale = kwargs_detector['pixel_scale']
        psf_model = psf_utils.get_PSF_model(kwargs_detector['psf_type'], pixel_scale, seeing=kwargs_detector['seeing'], kernel_size=psf_kernel_size, which_psf_maps=which_psf_maps)
        # Set the precision level of lens equation solver
        self.min_distance = 0.05
        # Search over the full image width
        self.search_window = pixel_scale*num_pix
        self.image_model = ImageModel(self.data_api.data_class, psf_model, self.lens_mass_model, self.src_light_model, self.lens_light_model, self.ps_model, kwargs_numerics=self.kwargs_numerics)
        # Unlensed model (no lens mass, no lens light), used as the magnification reference
        if 'agn_light' in self.components:
            self.unlensed_image_model = ImageModel(self.data_api.data_class, psf_model, None, self.src_light_model, None, self.unlensed_ps_model, kwargs_numerics=self.kwargs_numerics)
        else:
            self.unlensed_image_model = ImageModel(self.data_api.data_class, psf_model, None, self.src_light_model, None, None, kwargs_numerics=self.kwargs_numerics)

    def _load_kwargs(self, sample):
        """Load the lenstronomy kwargs of every component from a sample.

        Sets `kwargs_lens_mass`, `kwargs_src_light`, `kwargs_lens_light`,
        `kwargs_ps` and `kwargs_unlensed_unmagnified_amp_ps`. Components
        absent from `self.components` get None.

        Parameters
        ----------
        sample : dict
            model parameters sampled by a bnn_prior object

        """
        self._load_lens_mass_kwargs(sample['lens_mass'], sample['external_shear'])
        self._load_src_light_kwargs(sample['src_light'])
        if 'lens_light' in self.components:
            self._load_lens_light_kwargs(sample['lens_light'])
        else:
            self.kwargs_lens_light = None
        if 'agn_light' in self.components:
            self._load_agn_light_kwargs(sample)
        else:
            self.kwargs_ps = None
            self.kwargs_unlensed_unmagnified_amp_ps = None

    def _load_lens_mass_kwargs(self, lens_mass_sample, external_shear_sample):
        """Store the lens mass kwargs as [main deflector, external shear]."""
        self.kwargs_lens_mass = [lens_mass_sample, external_shear_sample]

    def _load_src_light_kwargs(self, src_light_sample):
        """Store the source light kwargs, with magnitude converted to amp."""
        kwargs_src_light = [src_light_sample]
        # Convert from mag to amp
        self.kwargs_src_light = mag_to_amp_extended(kwargs_src_light, self.src_light_model, self.data_api)

    def _load_lens_light_kwargs(self, lens_light_sample):
        """Store the lens light kwargs, with magnitude converted to amp."""
        kwargs_lens_light = [lens_light_sample]
        # Convert lens magnitude into amp
        self.kwargs_lens_light = mag_to_amp_extended(kwargs_lens_light, self.lens_light_model, self.data_api)

    def _load_agn_light_kwargs(self, sample):
        """Set the point source kwargs to be ingested by Lenstronomy

        Requires `kwargs_src_light` and `kwargs_lens_mass` to be loaded. The
        AGN is placed at the source light center.

        Note: this is called once per band from `generate_image`, so a new
        magnification-noise realisation is drawn for every band, and
        `img_features` keeps the last band's.

        """
        # When using the image positions for cosmological parameter recovery, the time delays must be computed by evaluating the Fermat potential at these exact positions.
        if self.for_cosmography:
            x_image = sample['misc']['x_image']
            y_image = sample['misc']['y_image']
        # When the precision of the lens equation solver doesn't have to be matched between image positions and time delays, simply solve for the image positions using whatever desired precision.
        else:
            x_image, y_image = self.lens_eq_solver.findBrightImage(self.kwargs_src_light[0]['center_x'],
                                                                   self.kwargs_src_light[0]['center_y'],
                                                                   self.kwargs_lens_mass,
                                                                   min_distance=self.min_distance,
                                                                   search_window=self.search_window,
                                                                   numImages=4,
                                                                   num_iter_max=100, # default is 10 but td_cosmography default is 100
                                                                   precision_limit=10**(-10) # default for both this and td_cosmography
                                                                   )
        agn_light_sample = sample['agn_light']
        unlensed_mag = agn_light_sample['magnitude'] # unlensed agn mag
        # Save the unlensed (source-plane) kwargs in amplitude units
        kwargs_unlensed_unmagnified_mag_ps = [{'ra_source': self.kwargs_src_light[0]['center_x'], 'dec_source': self.kwargs_src_light[0]['center_y'], 'magnitude': unlensed_mag}]
        self.kwargs_unlensed_unmagnified_amp_ps = mag_to_amp_point(kwargs_unlensed_unmagnified_mag_ps, self.unlensed_ps_model, self.data_api) # note unlensed ps model
        # Compute the lensed (image-plane), magnified kwargs in amplitude units
        magnification = self.lens_mass_model.magnification(x_image, y_image, kwargs=self.kwargs_lens_mass)
        # Add fractional Gaussian noise to the magnification; abs() discards the parity sign
        measured_magnification = np.abs(magnification*(1.0 + self.magnification_frac_err*np.random.randn(len(magnification))))
        magnification = np.abs(magnification)
        kwargs_lensed_unmagnified_mag_ps = [{'ra_image': x_image, 'dec_image': y_image, 'magnitude': unlensed_mag}] # note unlensed magnitude
        kwargs_lensed_unmagnified_amp_ps = mag_to_amp_point(kwargs_lensed_unmagnified_mag_ps, self.ps_model, self.data_api) # note unmagnified amp
        self.kwargs_ps = copy.deepcopy(kwargs_lensed_unmagnified_amp_ps)
        for kw in self.kwargs_ps:
            kw.update(point_amp=kw['point_amp']*measured_magnification)
        # Log the solved image positions
        self.img_features.update(x_image=x_image,
                                 y_image=y_image,
                                 magnification=magnification,
                                 measured_magnification=measured_magnification)

    def generate_image(self, sample, num_pix, survey_object_dict):
        """Render the sample in every band of `survey_object_dict`.

        Parameters
        ----------
        sample : dict
            model parameters sampled by a bnn_prior object
        num_pix : int
            number of pixels on each side of the image
        survey_object_dict : dict
            maps bandpass name to a survey object providing
            `kwargs_single_band()`, `psf_kernel_size` and `which_psf_maps`

        Returns
        -------
        tuple
            `(img_canvas, img_features)`. `img_canvas` has shape
            `[n_filters, num_pix, num_pix]` with negative pixels clipped to 0.
            `img_features` is a fresh dict for this call. Returns
            `(None, None)` if the image count is not 2 or 4, or the total
            magnification is below `min_magnification` or NaN.

        """
        # Start from an empty dict. This avoids stale metadata from a previous
        # sample, and keeps previously returned dicts from being mutated.
        self.img_features = {}
        img_canvas = np.empty([len(survey_object_dict), num_pix, num_pix]) # [n_filters, num_pix, num_pix]
        # Loop over bands
        for i, (bp, survey_object) in enumerate(survey_object_dict.items()):
            self._set_sim_api(num_pix, survey_object.kwargs_single_band(), survey_object.psf_kernel_size, survey_object.which_psf_maps)
            self._load_kwargs(sample)
            # Reject nonsensical number of images (due to insufficient numerical precision)
            if ('y_image' in self.img_features) and (len(self.img_features['y_image']) not in (2, 4)):
                return None, None
            # Compute magnification
            lensed_total_flux = get_lensed_total_flux(self.kwargs_lens_mass, self.kwargs_src_light, self.kwargs_ps, self.image_model)
            unlensed_total_flux = get_unlensed_total_flux_numerical(self.kwargs_src_light, self.kwargs_unlensed_unmagnified_amp_ps, self.unlensed_image_model)
            total_magnification = lensed_total_flux/unlensed_total_flux
            # Apply magnification cut
            if (total_magnification < self.min_magnification) or np.isnan(total_magnification):
                return None, None
            # Generate image for export
            img = self.image_model.image(self.kwargs_lens_mass, self.kwargs_src_light, self.kwargs_lens_light, self.kwargs_ps)
            img = np.maximum(0.0, img) # safeguard against negative pixel values
            img_canvas[i, :, :] = img
            # Save remaining image features
            img_features_single_band = {f'total_magnification_{bp}': total_magnification, f'lensed_total_flux_{bp}': lensed_total_flux, f'unlensed_total_flux_{bp}': unlensed_total_flux}
            self.img_features.update(img_features_single_band)
        return img_canvas, self.img_features

    def add_noise(self, image_array):
        """Add noise to the image (deprecated; replaced by the data_augmentation package)

        Currently a no-op that returns None; the old implementation is kept
        below for reference.

        """
        #noise_map = self.data_api.noise_for_model(image_array, background_noise=True, poisson_noise=True, seed=None)
        #image_array += noise_map
        #return image_array
        pass
back to top