# test_noise_tf.py
import os
import random
import unittest
import numpy as np
import tensorflow as tf
from baobab.data_augmentation import get_noise_sigma2_lenstronomy
from baobab.data_augmentation.noise_tf import NoiseModelTF
from baobab.tests.test_data_augmentation.tf_data_utils import generate_simple_tf_record, parse_example, tf_img_size, tf_y_names, tf_data_size

class TestNoiseTF(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        """Seed randomness

        """
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
        np.random.seed(123)
        random.seed(123)
        cls.img = np.random.randn(3, 3)*3.0 + 6.0
        cls.noise_kwargs = dict(
                                pixel_scale=0.08,
                                exposure_time=100.0,
                                magnitude_zero_point=25.9463, 
                                read_noise=10, 
                                ccd_gain=7.0,
                                sky_brightness=20.1,
                                seeing=0.6, 
                                num_exposures=1, 
                                psf_type='GAUSSIAN', 
                                kernel_point_source=None, 
                                truncation=5,
                                #data_count_unit='ADU', 
                                background_noise=None
                                )

    def test_lenstronomy_vs_tf_ADU(self):
        """Compare the lenstronomy and tf noise variance for ADU units

        """
        lens_sigma2 = get_noise_sigma2_lenstronomy(self.img, data_count_unit='ADU', **self.noise_kwargs)
        img_tf_tensor = tf.cast(self.img, tf.float32)
        noise_model_tf = NoiseModelTF(**self.noise_kwargs)
        tf_sigma2 = {}
        tf_sigma2['sky'] = noise_model_tf.get_sky_noise_sigma2()
        tf_sigma2['readout'] = noise_model_tf.get_readout_noise_sigma2()
        tf_sigma2['poisson'] = noise_model_tf.get_poisson_noise_sigma2(img_tf_tensor)
        np.testing.assert_array_almost_equal(self.img, img_tf_tensor.numpy(), decimal=5, err_msg="image array")
        np.testing.assert_almost_equal(lens_sigma2['sky'], tf_sigma2['sky'], decimal=7, err_msg="sky")
        np.testing.assert_almost_equal(lens_sigma2['readout'], tf_sigma2['readout'], decimal=7, err_msg="readout")
        np.testing.assert_array_almost_equal(lens_sigma2['poisson'], tf_sigma2['poisson'].numpy(), decimal=7, err_msg="poisson")

    def test_lenstronomy_vs_tf_electron(self):
        """Compare the lenstronomy and tf noise variance for electron units

        """
        lens_sigma2 = get_noise_sigma2_lenstronomy(self.img, data_count_unit='e-', **self.noise_kwargs)
        img_tf_tensor = tf.cast(self.img, tf.float32)
        noise_model_tf = NoiseModelTF(data_count_unit='e-', **self.noise_kwargs)
        tf_sigma2 = {}
        tf_sigma2['sky'] = noise_model_tf.get_sky_noise_sigma2()
        tf_sigma2['readout'] = noise_model_tf.get_readout_noise_sigma2()
        tf_sigma2['poisson'] = noise_model_tf.get_poisson_noise_sigma2(img_tf_tensor)
        np.testing.assert_array_almost_equal(self.img, img_tf_tensor.numpy(), decimal=5, err_msg="image array")
        np.testing.assert_almost_equal(lens_sigma2['sky'], tf_sigma2['sky'], decimal=7, err_msg="sky")
        np.testing.assert_almost_equal(lens_sigma2['readout'], tf_sigma2['readout'], decimal=7, err_msg="readout")
        np.testing.assert_array_almost_equal(lens_sigma2['poisson'], tf_sigma2['poisson'].numpy(), decimal=7, err_msg="poisson")

    def test_build_tf_dataset(self):
        """Test whether tf.data.Dataset can be instantiated from tf.data.TFRecordDataset with the data augmentation (noise addition) mapping

        """
        tf_record_path = os.path.abspath('test_ADU')
        batch_size = 2
        n_epochs = 3

        noise_model_tf = NoiseModelTF(**self.noise_kwargs)
        add_noise_func = getattr(noise_model_tf, 'add_noise')
        #print(add_noise_func(tf.ones((3, 3), dtype=tf.float32)))

        generate_simple_tf_record(tf_record_path, tf_y_names)
        tf_dataset = tf.data.TFRecordDataset(tf_record_path).map(parse_example).map(lambda image, label: (add_noise_func(image), label)).repeat(n_epochs).shuffle(buffer_size=tf_data_size + 1).batch(batch_size, drop_remainder=True)
        
        images = [img for img, label in tf_dataset]
        labels = [label for img, label in tf_dataset]
        size = len(labels)
        np.testing.assert_array_equal(images[0].shape, (batch_size, tf_img_size, tf_img_size, 1))
        np.testing.assert_array_equal(labels[0].shape, (batch_size, len(tf_y_names)))
        np.testing.assert_equal(size, (tf_data_size*n_epochs//2))
        # Delete resulting data
        if os.path.exists(tf_record_path):
            os.remove(tf_record_path)

if '__main__' == __name__:
    # Running the file directly (e.g. `python test_noise_tf.py`) runs this module's tests.
    unittest.main(module='__main__')
back to top