https://github.com/RadioAstronomySoftwareGroup/pyuvdata
Raw File
Tip revision: f9cbf410113a1077ff84769beafddb7bf30f44b4 authored by Bryna Hazelton on 29 March 2024, 15:47:47 UTC
address some of the review comments
Tip revision: f9cbf41
test_initializers.py
# -*- mode: python; coding: utf-8 -*-
# Copyright (c) 2024 Radio Astronomy Software Group
# Licensed under the 2-clause BSD License

"""Tests of in-memory initialization of UVData objects."""
from __future__ import annotations

import copy
from typing import Any

import numpy as np
import pytest
from astropy.coordinates import EarthLocation

from pyuvdata import UVData
from pyuvdata.tests.test_utils import selenoids
from pyuvdata.utils import polnum2str
from pyuvdata.uvdata.initializers import (
    configure_blt_rectangularity,
    get_antenna_params,
    get_freq_params,
    get_spw_params,
    get_time_params,
)


@pytest.fixture(scope="function")
def simplest_working_params() -> dict[str, Any]:
    """Provide a small, valid set of ``UVData.new`` keyword arguments.

    The array is located on Earth at geodetic (0, 0, 0). It has three antennas
    whose positions differ only in their third coordinate, 100 channels from
    1e8 to 2e8, the "xx" and "yy" polarizations, and 20 times evenly spaced
    across one day.
    """
    positions = {ant: [0.0, 0.0, float(ant)] for ant in range(3)}
    return dict(
        freq_array=np.linspace(1e8, 2e8, 100),
        polarization_array=["xx", "yy"],
        antenna_positions=positions,
        telescope_location=EarthLocation.from_geodetic(0, 0, 0),
        telescope_name="test",
        times=np.linspace(2459855, 2459856, 20),
    )


@pytest.fixture
def lunar_simple_params() -> dict[str, Any]:
    """Provide ``UVData.new`` keyword arguments for an array sited on the Moon.

    The inputs match the Earth-based fixture, except that the telescope is at
    selenodetic (0, 0, 0). The requesting test is skipped when the optional
    ``lunarsky`` package cannot be imported.
    """
    pytest.importorskip("lunarsky")
    from pyuvdata.utils import MoonLocation

    positions = {ant: [0.0, 0.0, float(ant)] for ant in range(3)}
    return dict(
        freq_array=np.linspace(1e8, 2e8, 100),
        polarization_array=["xx", "yy"],
        antenna_positions=positions,
        telescope_location=MoonLocation.from_selenodetic(0, 0, 0),
        telescope_name="test",
        times=np.linspace(2459855, 2459856, 20),
    )


def test_simplest_new_uvdata(simplest_working_params: dict[str, Any]):
    """Check the dimensions derived from the minimal fixture inputs."""
    uvd = UVData.new(**simplest_working_params)

    expected_dims = {
        "Nfreqs": 100,
        "Npols": 2,
        "Nants_data": 3,
        # Six baselines from three antennas (presumably 3 cross pairs + 3 autos).
        "Nbls": 6,
        "Ntimes": 20,
        # Six baselines at each of the 20 times.
        "Nblts": 120,
        "Nspws": 1,
    }
    for attr, value in expected_dims.items():
        assert getattr(uvd, attr) == value, attr
    assert uvd.Nspws == 1


@pytest.mark.parametrize("selenoid", selenoids)
def test_lunar_simple_new_uvdata(lunar_simple_params: dict[str, Any], selenoid: str):
    """Check that a Moon-sited object keeps the MCMF frame and chosen ellipsoid."""
    uvd = UVData.new(**lunar_simple_params, ellipsoid=selenoid)
    location = uvd.telescope._location

    assert location.frame == "mcmf"
    assert location.ellipsoid == selenoid


def test_bad_inputs(simplest_working_params: dict[str, Any]):
    """Reject an unknown ``vis_units`` value and an unrecognized keyword."""
    bad_cases = [
        ({"vis_units": "foo"}, "vis_units must be one of"),
        ({"derp": "foo"}, "Keyword argument derp is not a valid UVData attribute"),
    ]
    for extra_kwargs, expected_msg in bad_cases:
        with pytest.raises(ValueError, match=expected_msg):
            UVData.new(**simplest_working_params, **extra_kwargs)


def test_bad_antenna_inputs(simplest_working_params: dict[str, Any]):
    """Exercise the validation errors raised for malformed antenna inputs."""
    # Most cases pass their own antenna_positions, so drop the fixture's entry
    # once. ``**`` unpacking builds a fresh kwargs dict and never mutates this
    # one, so it is safe to reuse for every case.
    no_positions = {
        key: val
        for key, val in simplest_working_params.items()
        if key != "antenna_positions"
    }

    # An array of positions with neither numbers nor names.
    with pytest.raises(
        ValueError, match="Either antenna_numbers or antenna_names must be provided"
    ):
        UVData.new(
            antenna_positions=np.array([[0, 0, 0], [0, 0, 1], [0, 0, 2]]),
            antenna_numbers=None,
            antenna_names=None,
            **no_positions,
        )

    # A positions dict whose keys mix int and str.
    mixed_key_msg = (
        "antenna_positions must be a dictionary with keys that are all type int "
        "or all type str"
    )
    with pytest.raises(ValueError, match=mixed_key_msg):
        UVData.new(antenna_positions={1: [0, 1, 2], "2": [3, 4, 5]}, **no_positions)

    # Names only (no numbers), and the names are not integer-like.
    with pytest.raises(ValueError, match="Antenna names must be integers"):
        UVData.new(
            antenna_positions=np.array([[0, 0, 0], [0, 0, 1], [0, 0, 2]]),
            antenna_numbers=None,
            antenna_names=["foo", "bar", "baz"],
            **no_positions,
        )

    # Positions given as a plain string.
    with pytest.raises(ValueError, match="antenna_positions must be a numpy array"):
        UVData.new(
            antenna_positions="foo",
            antenna_numbers=[0, 1, 2],
            antenna_names=["foo", "bar", "baz"],
            **no_positions,
        )

    # A 1D position array where a 2D array is required.
    with pytest.raises(ValueError, match="antenna_positions must be a 2D array"):
        UVData.new(
            antenna_positions=np.array([0, 0, 0]),
            antenna_numbers=np.array([0]),
            **no_positions,
        )

    # A repeated name, alongside the fixture's own positions.
    with pytest.raises(ValueError, match="Duplicate antenna names found"):
        UVData.new(antenna_names=["foo", "bar", "foo"], **simplest_working_params)

    # A repeated antenna number.
    with pytest.raises(ValueError, match="Duplicate antenna numbers found"):
        UVData.new(
            antenna_positions=np.array([[0, 0, 0], [0, 0, 1], [0, 0, 2]]),
            antenna_numbers=[0, 1, 0],
            antenna_names=["foo", "bar", "baz"],
            **no_positions,
        )

    # Only two names for the fixture's three antennas.
    with pytest.raises(
        ValueError, match="antenna_numbers and antenna_names must have the same length"
    ):
        UVData.new(antenna_names=["foo", "bar"], **simplest_working_params)


def test_bad_time_inputs(simplest_working_params: dict[str, Any]):
    """Check the errors ``get_time_params`` raises for malformed time inputs."""
    location = simplest_working_params["telescope_location"]
    times = simplest_working_params["times"]

    # time_array given as a string.
    with pytest.raises(ValueError, match="time_array must be a numpy array"):
        get_time_params(
            telescope_location=location, time_array="hello this is a string"
        )

    # integration_time given as a dict.
    with pytest.raises(
        TypeError, match="integration_time must be array_like of floats"
    ):
        get_time_params(
            telescope_location=location,
            integration_time={"a": "dict"},
            time_array=times,
        )

    # One more integration time than there are times.
    with pytest.raises(
        ValueError, match="integration_time must be the same shape as time_array"
    ):
        get_time_params(
            integration_time=np.ones(len(times) + 1),
            telescope_location=location,
            time_array=times,
        )


def test_bad_freq_inputs(simplest_working_params: dict[str, Any]):
    """Check the errors raised for malformed frequency / channel-width inputs."""
    without_freqs = {
        key: val
        for key, val in simplest_working_params.items()
        if key != "freq_array"
    }
    with pytest.raises(ValueError, match="freq_array must be a numpy array"):
        UVData.new(freq_array="hello this is a string", **without_freqs)

    # The fixture defines no channel_width, so a plain copy of it is all the
    # remaining cases need.
    params = dict(simplest_working_params)

    with pytest.raises(TypeError, match="channel_width must be array_like of floats"):
        UVData.new(channel_width={"a": "dict"}, **params)

    # One more channel width than there are frequencies.
    with pytest.raises(
        ValueError, match="channel_width must be the same shape as freq_array"
    ):
        UVData.new(
            channel_width=np.ones(len(params["freq_array"]) + 1),
            **params,
        )


def test_bad_rectangularity_inputs():
    """Check each error ``configure_blt_rectangularity`` raises for bad inputs."""

    def _case(expected_msg, times, antpairs, **flags):
        # Bundle one expected message with the keyword arguments that trigger it.
        kwargs = {"times": np.array(times), "antpairs": np.array(antpairs), **flags}
        return expected_msg, kwargs

    cases = [
        # Lengths differ and times repeat.
        _case(
            "If times and antpairs differ in length, times must all be unique",
            [0, 1, 2, 3, 3],
            [(0, 1), (0, 2), (1, 2), (0, 1)],
        ),
        # Lengths differ and antpairs repeat.
        _case(
            "If times and antpairs differ in length, antpairs must all be unique",
            [0, 1, 2, 3, 4],
            [(0, 1), (0, 2), (1, 2), (1, 2)],
        ),
        # Equal lengths with an ambiguous layout.
        _case(
            "It is impossible to determine",
            [0, 1, 2, 3],
            [(0, 1), (0, 2), (1, 2), (0, 3)],
        ),
        # Outer product requested but times repeat.
        _case(
            "times must be unique if do_blt_outer is True",
            [0, 1, 2, 3, 3],
            [(0, 1), (0, 2), (1, 2), (0, 1), (0, 2)],
            do_blt_outer=True,
        ),
        # Outer product requested but antpairs repeat.
        _case(
            "antpairs must be unique if do_blt_outer is True",
            [0, 1, 2, 3, 4],
            [(0, 1), (0, 2), (1, 2), (0, 1), (0, 1)],
            do_blt_outer=True,
        ),
        # Declared rectangular but the lengths disagree.
        _case(
            "blts_are_rectangular is True, but times and antpairs",
            [0, 1, 2],
            [(0, 1), (0, 2), (1, 2), (0, 2)],
            blts_are_rectangular=True,
            do_blt_outer=False,
        ),
    ]

    for expected_msg, kwargs in cases:
        with pytest.raises(ValueError, match=expected_msg):
            configure_blt_rectangularity(**kwargs)


def test_alternate_antenna_inputs():
    """Check that equivalent antenna specifications give identical results.

    A dict keyed by antenna number, explicit position/number/name arrays, and
    a dict keyed by antenna name must all produce the same positions, names,
    and numbers.
    """
    by_number = {num: np.array([0.0, 0.0, float(num)]) for num in range(3)}
    from_numbers = get_antenna_params(antenna_positions=by_number)

    ref_pos, ref_names, ref_nums = get_antenna_params(
        antenna_positions=np.array([[0, 0, 0], [0, 0, 1], [0, 0, 2]], dtype=float),
        antenna_numbers=np.array([0, 1, 2]),
        antenna_names=np.array(["000", "001", "002"]),
    )

    def _assert_matches_reference(result):
        pos, names, nums = result
        assert np.allclose(pos, ref_pos)
        assert np.all(names == ref_names)
        assert np.all(nums == ref_nums)

    _assert_matches_reference(from_numbers)

    # Zero-padded string keys ("000", "001", "002") with integer positions.
    by_name = {f"{num:03d}": np.array([0, 0, num]) for num in range(3)}
    _assert_matches_reference(get_antenna_params(antenna_positions=by_name))


def test_alternate_time_inputs():
    """Check that integration_time may be a scalar, an array, or omitted."""
    location = EarthLocation.from_geodetic(0, 0, 0)
    jd = np.linspace(2459855, 2459856, 20)
    # Spacing between consecutive times, converted from days to seconds.
    step_seconds = (jd[1] - jd[0]) * 24 * 60 * 60

    ref_times, ref_ints = get_time_params(
        time_array=jd, integration_time=step_seconds, telescope_location=location
    )

    variants = [
        # One integration time per time sample.
        {"integration_time": step_seconds * np.ones_like(jd)},
        # Omitted: expected to match the spacing-derived value.
        {},
    ]
    for extra in variants:
        times, ints = get_time_params(
            time_array=jd, telescope_location=location, **extra
        )
        assert np.allclose(ref_times, times)
        assert np.allclose(ref_ints, ints)

    # A single time has no spacing to infer from; the test expects a warning
    # and a fallback integration time of 1.0.
    with pytest.warns(
        UserWarning, match="integration_time not provided, and cannot be inferred"
    ):
        _, single_ints = get_time_params(
            time_array=jd[:1], telescope_location=location
        )
    assert np.allclose(single_ints, 1.0)


def test_alternate_freq_inputs():
    """Check that channel_width may be a scalar, an array, or omitted."""
    nu = np.linspace(1e8, 2e8, 15)
    spacing = nu[1] - nu[0]

    ref_freqs, ref_widths = get_freq_params(nu, channel_width=spacing)

    variants = [
        # One width per channel.
        {"channel_width": spacing * np.ones_like(nu)},
        # Omitted: expected to match the channel spacing.
        {},
    ]
    for extra in variants:
        freqs, widths = get_freq_params(freq_array=nu, **extra)
        assert np.allclose(ref_freqs, freqs)
        assert np.allclose(ref_widths, widths)

    # A lone channel has no spacing to infer from; the test expects a warning
    # and a fallback width of 1.0.
    with pytest.warns(
        UserWarning, match="channel_width not provided, and cannot be inferred"
    ):
        _, lone_widths = get_freq_params(freq_array=nu[:1])
    assert np.allclose(lone_widths, 1.0)


def test_empty(simplest_working_params: dict[str, Any]):
    """Check the placeholder arrays created when ``empty=True``.

    Data should be all zeros, nothing flagged, and every nsample equal to 1.
    """
    uvd = UVData.new(empty=True, **simplest_working_params)
    data, flags, nsamples = uvd.data_array, uvd.flag_array, uvd.nsample_array

    assert data.shape == (uvd.Nblts, uvd.Nfreqs, uvd.Npols)
    assert flags.shape == data.shape == nsamples.shape
    assert not np.any(flags)
    assert np.all(nsamples == 1)
    assert np.all(data == 0)


def test_passing_data(simplest_working_params: dict[str, Any]):
    """Check that supplied arrays are kept and omitted ones are defaulted.

    When only data is given, flags are expected to be all False and nsamples
    all 1. Explicitly supplied flag/nsample arrays must be preserved.
    """
    shape = UVData.new(empty=True, **simplest_working_params).data_array.shape

    def _new(**arrays):
        return UVData.new(**arrays, **simplest_working_params)

    # Data only: flags and nsamples are filled in.
    uvd = _new(data_array=np.zeros(shape, dtype=complex))
    assert np.all(uvd.data_array == 0)
    assert np.all(uvd.flag_array == 0)
    assert np.all(uvd.nsample_array == 1)

    # Explicit all-True flags, with and without explicit nsamples.
    extras = [
        {"flag_array": np.ones(shape, dtype=bool)},
        {
            "flag_array": np.ones(shape, dtype=bool),
            "nsample_array": np.ones(shape, dtype=float),
        },
    ]
    for extra in extras:
        uvd = _new(data_array=np.zeros(shape, dtype=complex), **extra)
        assert np.all(uvd.data_array == 0)
        assert np.all(uvd.flag_array)
        assert np.all(uvd.nsample_array == 1)


def test_passing_bad_data(simplest_working_params: dict[str, Any]):
    """Check that data, flag, and nsample arrays of the wrong shape are rejected."""
    shape = UVData.new(empty=True, **simplest_working_params).data_array.shape
    wrong_shape = (1, 2, 3)

    with pytest.raises(ValueError, match="Data array shape"):
        UVData.new(
            data_array=np.zeros(wrong_shape, dtype=float), **simplest_working_params
        )

    with pytest.raises(ValueError, match="Flag array shape"):
        UVData.new(
            data_array=np.zeros(shape, dtype=complex),
            flag_array=np.ones(wrong_shape, dtype=float),
            **simplest_working_params,
        )

    with pytest.raises(ValueError, match="nsample array shape"):
        UVData.new(
            data_array=np.zeros(shape, dtype=complex),
            flag_array=np.ones(shape, dtype=bool),
            nsample_array=np.ones(wrong_shape, dtype=float),
            **simplest_working_params,
        )


def test_passing_kwargs(simplest_working_params: dict[str, Any]):
    """Check that an extra UVData attribute given to ``new`` is set on the object."""
    requested_order = ("time", "baseline")
    uvd = UVData.new(blt_order=requested_order, **simplest_working_params)

    assert uvd.blt_order == requested_order


def test_blt_rect():
    """Check the layout ``configure_blt_rectangularity`` infers for blt inputs.

    Covers unique times/antpairs (expanded by the function) and pre-expanded
    baseline-time arrays in both orderings.
    """
    utimes = np.linspace(2459855, 2459856, 20)
    uaps = np.array([[1, 1], [1, 2], [2, 3]])

    def _check(*, rect_expected, time_fastest, first_rows_share_time, **kwargs):
        # Run the function and verify the counts, flags, and ordering of the
        # first two rows.
        nbls, ntimes, rect, axis, times, bls, _ = configure_blt_rectangularity(
            **kwargs
        )
        assert nbls == 3
        assert ntimes == 20
        assert bool(rect) is rect_expected
        assert bool(axis) is time_fastest
        assert len(times) == len(bls)
        assert bool(times[1] == times[0]) is first_rows_share_time

    # Unique inputs, expanded so consecutive rows share a time.
    _check(
        rect_expected=True,
        time_fastest=False,
        first_rows_share_time=True,
        times=utimes,
        antpairs=uaps,
        time_axis_faster_than_bls=False,
    )
    # Unique inputs, expanded with time varying fastest.
    _check(
        rect_expected=True,
        time_fastest=True,
        first_rows_share_time=False,
        times=utimes,
        antpairs=uaps,
        time_axis_faster_than_bls=True,
    )

    # Pre-expanded: each time repeated once per baseline.
    _check(
        rect_expected=True,
        time_fastest=False,
        first_rows_share_time=True,
        times=np.repeat(utimes, len(uaps)),
        antpairs=np.tile(uaps, (len(utimes), 1)),
        blts_are_rectangular=True,
    )

    # Pre-expanded: each baseline repeated once per time.
    time_fast_times = np.tile(utimes, len(uaps))
    time_fast_bls = np.repeat(uaps, len(utimes), axis=0)

    _check(
        rect_expected=True,
        time_fastest=True,
        first_rows_share_time=False,
        times=time_fast_times,
        antpairs=time_fast_bls,
        blts_are_rectangular=True,
    )
    # The same layout explicitly declared non-rectangular.
    _check(
        rect_expected=False,
        time_fastest=False,
        first_rows_share_time=False,
        times=time_fast_times,
        antpairs=time_fast_bls,
        blts_are_rectangular=False,
    )
    # Rectangularity left unspecified.
    _check(
        rect_expected=True,
        time_fastest=True,
        first_rows_share_time=False,
        times=time_fast_times,
        antpairs=time_fast_bls,
    )


def test_set_phase_params(simplest_working_params):
    """Check catalog round-tripping and the multi-entry catalog requirement."""
    catalog = UVData.new(**simplest_working_params).phase_center_catalog

    # Passing an object's own catalog back in reproduces it.
    rebuilt = UVData.new(phase_center_catalog=catalog, **simplest_working_params)
    assert rebuilt.phase_center_catalog == catalog

    # A second entry, cloned from entry 0 and renamed as a driftscan.
    two_entries = copy.deepcopy(catalog)
    second = copy.deepcopy(two_entries[0])
    second.update(cat_type="driftscan", cat_name="another_unprojected")
    two_entries[1] = second

    expected_msg = (
        "If phase_center_catalog has more than one key, "
        "phase_center_id_array must be provided"
    )
    with pytest.raises(ValueError, match=expected_msg):
        UVData.new(phase_center_catalog=two_entries, **simplest_working_params)


def test_get_spw_params():
    """Check how ``get_spw_params`` derives spectral-window information.

    Each call is unpacked as ``(_id, spw)``. ``_id`` is a per-channel window-id
    array and ``spw`` is the array of spectral windows.
    """
    freq = np.linspace(0, 1, 5)

    # Every channel has id 0, so every returned window is 0.
    idarray = np.array([0, 0, 0, 0, 0])
    _id, spw = get_spw_params(flex_spw_id_array=idarray, freq_array=freq)
    assert np.all(spw == 0)

    # Two distinct ids give windows [0, 1].
    idarray = np.array([0, 0, 0, 0, 1])
    _id, spw = get_spw_params(flex_spw_id_array=idarray, freq_array=freq)
    assert np.all(spw == [0, 1])

    # Several windows but no per-channel ids is ambiguous.
    with pytest.raises(
        ValueError,
        match="If spw_array has more than one entry, flex_spw_id_array must be",
    ):
        get_spw_params(spw_array=np.array([0, 1]))

    # A single window id should be applied to every channel.
    _id, spw = get_spw_params(spw_array=np.array([1]), freq_array=np.zeros(10))
    assert len(_id) == 10
    assert len(spw) == 1
    # Compare element-wise *before* reducing. The previous form,
    # ``np.all(_id) == 1``, only checked that no id was zero (True == 1).
    assert np.all(_id == 1)

    # Passing both spw_array and flex_spws, but getting them right
    _id, spw = get_spw_params(
        spw_array=np.array([0, 1]),
        flex_spw_id_array=np.concatenate(
            (np.zeros(10, dtype=int), np.ones(10, dtype=int))
        ),
    )
    assert len(spw) == 2

    # Passing both spw_array and flex_spws, but getting them wrong
    with pytest.raises(
        ValueError,
        match="spw_array and flex_spw_id_array must have the same number of unique",
    ):
        get_spw_params(
            spw_array=np.array([0, 1]), flex_spw_id_array=np.zeros(10, dtype=int)
        )


@pytest.mark.parametrize("xorient", ["e", "n", "east", "NORTH"])
def test_passing_xorient(simplest_working_params, xorient):
    """Check that abbreviated or mixed-case x_orientation values are normalized."""
    expected = "east" if xorient.lower().startswith("e") else "north"

    uvd = UVData.new(x_orientation=xorient, **simplest_working_params)
    assert uvd.x_orientation == expected


def test_passing_directional_pols(simplest_working_params):
    """Check that an "ee" polarization is only accepted with an x_orientation."""
    kw = {**simplest_working_params, "polarization_array": ["ee"]}

    # Without an x_orientation, "ee" cannot be converted.
    with pytest.raises(KeyError, match="'ee'"):
        UVData.new(**kw)

    uvd = UVData.new(x_orientation="east", **kw)
    assert polnum2str(uvd.polarization_array[0], x_orientation="east") == "ee"
back to top