https://github.com/RadioAstronomySoftwareGroup/pyuvdata
Raw File
Tip revision: 0b1b091eeec941c6081327cfd1d65f689feafd8c authored by Paul La Plante on 08 August 2020, 00:16 UTC
Merge branch 'Smithsonian-submillimeter_array' into master
Tip revision: 0b1b091
__init__.py
# -*- mode: python; coding: utf-8 -*-
# Copyright (c) 2018 Radio Astronomy Software Group
# Licensed under the 2-clause BSD License

"""Setup testing environment, define useful testing functions.

"""
import re
import sys
import warnings

from astropy.utils import iers

import pyuvdata.utils as uvutils

# Public testing helpers exported by a star import of this module.
__all__ = ["check_warnings", "clearWarnings", "checkWarnings"]


class WarningsChecker(warnings.catch_warnings):
    """
    A context manager to check raised warnings.

    Adapted from pytest WarningsRecorder and WarningsChecker.

    While active, every warning is recorded (the filter is set to "always",
    apart from a few known-noisy messages that are ignored). On a clean exit
    (no exception propagating out of the ``with`` block) the recorded warnings
    are compared against `expected_warning` and `match`, and an AssertionError
    is raised on any mismatch.

    Parameters
    ----------
    expected_warning : list of Warnings, Warning or None
        List of expected warnings. A single Warning subclass is treated as a
        length-1 list. If None, no warnings may be emitted.
    match : list of str or regex, str or None
        List of strings to match warnings to, paired by index with
        `expected_warning`. A single string is treated as a length-1 list.
        If None, only the warning categories are checked.

    Raises
    ------
    TypeError
        If `expected_warning` is not (a list of) Warning subclasses, or if
        `match` is not (a list of) strings.

    """

    def __init__(self, expected_warning, match):
        """Check inputs and initialize object."""
        super().__init__(record=True)
        # Whether we are currently inside the ``with`` block.
        self._entered = False
        # Recorded warnings; replaced by the catch_warnings record list on entry.
        self._list = []

        msg = "exceptions must be derived from Warning, not %s"
        if expected_warning is None:
            expected_warning_list = expected_warning
        elif isinstance(expected_warning, list):
            for exc in expected_warning:
                if not issubclass(exc, Warning):
                    raise TypeError(msg % type(exc))
            expected_warning_list = expected_warning
        elif issubclass(expected_warning, Warning):
            expected_warning_list = [expected_warning]
        else:
            raise TypeError(msg % type(expected_warning))

        msg = "match must be a str, not %s"
        if match is None:
            match_list = None
        elif isinstance(match, list):
            for pattern in match:
                if not isinstance(pattern, str):
                    raise TypeError(msg % type(pattern))
            match_list = match
        elif isinstance(match, str):
            match_list = [match]
        else:
            raise TypeError(msg % type(match))

        self.expected_warning = expected_warning_list
        self.match = match_list

    @property
    def warnlist(self):
        """The list of recorded warnings."""
        return self._list

    def __getitem__(self, i: int):
        """Get a recorded warning by index."""
        return self._list[i]

    def __iter__(self):
        """Iterate through the recorded warnings."""
        return iter(self._list)

    def __len__(self):
        """The number of recorded warnings."""
        return len(self._list)

    def pop(self, cls):
        """
        Remove and return the first recorded warning whose category is `cls`.

        Parameters
        ----------
        cls : Warning
            Warning class to check for (subclasses also match).

        Raises
        ------
        AssertionError
            If no recorded warning has a matching category.
        """
        for i, w in enumerate(self._list):
            if issubclass(w.category, cls):
                return self._list.pop(i)
        __tracebackhide__ = True
        raise AssertionError("%r not found in warning list" % cls)

    def clear(self) -> None:
        """Clear the list of recorded warnings."""
        self._list[:] = []

    @staticmethod
    def _expects_lst_warning(match_list):
        """Return True if any match string expects the "LST values stored in" warning."""
        if match_list is None:
            return False
        return any(
            message is not None and message.startswith("LST values stored in ")
            for message in match_list
        )

    @staticmethod
    def _message_matches(pattern, message):
        """Return True if `message` starts with `pattern` or regex-searches it."""
        if message.startswith(pattern):
            return True
        try:
            return re.search(pattern, message) is not None
        except re.error:
            # The pattern is a literal string that is not a valid regex
            # (e.g. an unbalanced parenthesis); fall back to a substring test.
            return pattern in message

    def __enter__(self):
        """Start recording warnings and install the testing filters."""
        if self._entered:
            __tracebackhide__ = True
            raise RuntimeError("Cannot enter %r twice" % self)
        _list = super().__enter__()
        # record=True means it's not None.
        assert _list is not None
        self._list = _list
        warnings.simplefilter("always")
        # Filter annoying Cython warnings that serve no good purpose. see numpy#432
        warnings.filterwarnings("ignore", message="numpy.dtype size changed")
        warnings.filterwarnings("ignore", message="numpy.ufunc size changed")

        # Filter iers warnings if iers.conf.auto_max_age is set to None, as we
        # do in testing if the iers url is down. See conftest.py for more info.
        if iers.conf.auto_max_age is None:
            warnings.filterwarnings("ignore", message="failed to download")
            warnings.filterwarnings("ignore", message="time is out of IERS range")

            # Only silence the LST warning if the caller is not explicitly
            # checking for it. ``self.match`` may be None (categories only).
            if not self._expects_lst_warning(self.match):
                warnings.filterwarnings("ignore", message="LST values stored in ")

        return self

    def __exit__(
        self, exc_type=None, exc_val=None, exc_tb=None,
    ):
        """Stop recording and, if no exception is in flight, verify the warnings."""
        if not self._entered:
            __tracebackhide__ = True
            raise RuntimeError("Cannot exit %r without entering first" % self)

        super().__exit__(exc_type, exc_val, exc_tb)

        # Built-in catch_warnings does not reset entered state so we do it
        # manually here for this context manager to become reusable.
        self._entered = False

        # only check if we're not currently handling an exception
        if exc_type is None and exc_val is None and exc_tb is None:
            self._check_recorded()

    def _check_recorded(self):
        """
        Compare the recorded warnings against the expected ones.

        Raises
        ------
        AssertionError
            If the number, categories or messages of the recorded warnings do
            not match `expected_warning` / `match`. Raised explicitly (not via
            ``assert``) so the check still runs under ``python -O``.
        """
        __tracebackhide__ = True
        emitted = [each.message for each in self]

        if self.expected_warning is None:
            if len(self) != 0:
                raise AssertionError(
                    f"No warnings expected, {len(self)} warnings issued. "
                    f"The list of emitted warnings is: {emitted}."
                )
            return

        if len(self) != len(self.expected_warning):
            raise AssertionError(
                f"{len(self.expected_warning)} warnings expected, "
                f"{len(self)} warnings issued. The list of emitted warnings is: "
                f"{emitted}."
            )

        for warn_i, exp_warn in enumerate(self.expected_warning):
            if not any(issubclass(r.category, exp_warn) for r in self):
                raise AssertionError(
                    "DID NOT WARN. No warnings of type {} was emitted. "
                    "The list of emitted warnings is: {}.".format(exp_warn, emitted)
                )
            if self.match is None:
                # Categories only; no message check requested.
                continue
            pattern = self.match[warn_i]
            if not any(
                issubclass(record.category, exp_warn)
                and self._message_matches(pattern, str(record.message))
                for record in self
            ):
                raise AssertionError(
                    f"No warnings of type {exp_warn} "
                    f"matching ('{pattern}') was "
                    "emitted. The list of emitted warnings is: "
                    f"{emitted}."
                )


def _normalize_check_warnings_args(expected_warning, match, nwarnings):
    """
    Validate the ``check_warnings`` inputs and broadcast them to equal lengths.

    Parameters
    ----------
    expected_warning : list of Warnings, Warning or None
        Expected warning categories.
    match : list of str, str or None
        Expected message strings/regexes.
    nwarnings : int or None
        Number of warnings to expect when both inputs have a single element.

    Returns
    -------
    expected_warning_list : list of Warnings or None
        None only when `expected_warning` is None.
    match_list : list of str or None
        None when `match` is None (messages are then not checked).

    Raises
    ------
    TypeError
        If `expected_warning` or `match` has an unsupported type.
    ValueError
        If both inputs have more than one element but different lengths.
    """
    # Guard with isinstance so non-class inputs get our message rather than
    # issubclass's own "arg 1 must be a class" TypeError.
    is_warning_class = isinstance(expected_warning, type) and issubclass(
        expected_warning, Warning
    )
    if not (
        expected_warning is None
        or isinstance(expected_warning, list)
        or is_warning_class
    ):
        raise TypeError("expected_warning must be a list or be derived from Warning")
    if match is not None and not isinstance(match, (list, str)):
        raise TypeError("match must be a list or a string.")

    expected_warning_list = [expected_warning] if is_warning_class else expected_warning
    match_list = [match] if isinstance(match, str) else match

    if expected_warning_list is None:
        # No warnings expected: nothing to broadcast.
        return expected_warning_list, match_list

    n_expected = len(expected_warning_list)
    # match=None means only categories are checked; it never drives the
    # broadcast length and is passed through unchanged.
    n_match = 0 if match_list is None else len(match_list)

    if n_expected > 1 and n_match > 1 and n_expected != n_match:
        raise ValueError(
            "If expected_warning and match both have more than one element, "
            "they must be the same length."
        )

    if n_expected > 1 or n_match > 1:
        nwarnings = max(n_expected, n_match)
    elif nwarnings is None:
        nwarnings = 1

    if n_expected < nwarnings:
        expected_warning_list = expected_warning_list * nwarnings
    if match_list is not None and n_match < nwarnings:
        match_list = match_list * nwarnings

    return expected_warning_list, match_list


def check_warnings(expected_warning, match=None, nwarnings=None, *args, **kwargs):
    """
    Assert that code raises a particular set of warnings, used as a context manager.

    Similar to ``pytest.warns``, but allows for specifying multiple warnings.
    It also better matches warning strings when the warning uses f-strings or
    formatting. Can be used as a drop-in replacement for ``pytest.warns`` if
    only one warning is issued (if more are issued they will need to be added
    to the input lists for this to pass).

    Note that unlike the older checkWarnings function, the warnings can be passed
    in any order, they do not have to match the order the warnings are raised
    in the code.

    To assert that there are no warnings raised by some code, set `expected_warning`
    to None (i.e. `with check_warnings(None):`)

    Parameters
    ----------
    expected_warning : list of Warnings or Warning or None
        List of expected warnings. If a single warning type or a length 1 list,
        will be used for the type of all warnings.
    match : str or regex or list of str or regex, optional
        List of strings or regexes to match warnings to. If a str or a length 1
        list, will be used for all warnings. If None, only the warning
        categories are checked.
    nwarnings : int, optional
        Option to specify that multiple of a single type of warning is expected.
        Only used if `expected_warning` and `match` both have at most one
        element. Defaults to 1.
    *args
        If given, the first element must be a callable; it is called with the
        remaining positional arguments (and `**kwargs`) while warnings are
        being checked.
    **kwargs
        Keyword arguments for the callable. Only allowed together with `*args`.

    Returns
    -------
    WarningsChecker or object
        The context manager when no callable is given, otherwise the value
        returned by the callable.

    Raises
    ------
    TypeError
        If the inputs have unsupported types, keyword arguments are given
        without a callable, or the first positional extra is not callable.
    ValueError
        If `expected_warning` and `match` both have more than one element but
        different lengths.

    """
    __tracebackhide__ = True

    expected_warning_list, match_list = _normalize_check_warnings_args(
        expected_warning, match, nwarnings
    )

    if not args:
        if kwargs:
            msg = "Unexpected keyword arguments passed to check_warnings: "
            msg += ", ".join(sorted(kwargs))
            msg += "\nUse context-manager form instead?"
            raise TypeError(msg)
        return WarningsChecker(expected_warning_list, match_list)

    func = args[0]
    if not callable(func):
        raise TypeError(
            "{!r} object (type: {}) must be callable".format(func, type(func))
        )
    with WarningsChecker(expected_warning_list, match_list):
        return func(*args[1:], **kwargs)


# Functions that are useful for testing:
def clearWarnings():
    """
    Quick code to make warnings reproducible.

    Empties the ``__warningregistry__`` of every module currently in
    ``sys.modules`` so that warnings which were already shown once can be
    emitted (and recorded) again.
    """
    for module in list(sys.modules.values()):
        try:
            registry = getattr(module, "__warningregistry__", None)
        except ImportError:
            # Attribute access on some modules can raise ImportError
            # (presumably lazy loaders); skip those.
            continue
        if not registry:
            continue
        registry.clear()


def checkWarnings(
    func,
    func_args=None,
    func_kwargs=None,
    nwarnings=1,
    category=UserWarning,
    message=None,
    known_warning=None,
):
    """
    Function to check expected warnings in tests.

    Deprecated. Use check_warnings instead.

    Useful for checking that appropriate warnings are raised and to capture
    (and silence) warnings in tests. The i-th emitted warning is compared to
    the i-th expected category and message: the category must match exactly
    (a subclass does not count) and the expected message must be a substring
    of the emitted one.

    Parameters
    ----------
    func : function
        Function or method to check warnings for.
    func_args : list, optional
        List of positional parameters to pass `func`
    func_kwargs : dict, optional
        Dict of keyword parameter to pass func. Keys are the parameter names,
        values are the values to pass to the parameters.
    nwarnings : int
        Number of expected warnings.
    category : warning type or list of warning types
        Expected warning type(s). If a scalar is passed and `nwarnings` is
        greater than one, the same category will be expected for all warnings.
    message : str or list of str
        Expected warning string(s). If a scalar is passed and `nwarnings` is
        greater than one, the same warning string will be expected for all
        warnings. A None or empty expected message always fails the check.
    known_warning : {'miriad', 'paper_uvfits', 'fhd'}, optional
        Shorthand way to specify one of a standard set of warnings. Overrides
        `category`, `message` and `nwarnings`.

    Returns
    -------
    Value returned by `func`

    Raises
    ------
    AssertionError
        If the warning(s) raised by func do not match the expected values.
    """
    warnings.warn(
        "`checkWarnings` is deprecated, and will be removed in pyuvdata version "
        "2.3. Use `check_warnings` instead.",
        DeprecationWarning,
    )
    if func_args is None:
        func_args = []
    if func_kwargs is None:
        func_kwargs = {}

    # Broadcast a scalar (or length-1 list) category/message to nwarnings.
    if (not isinstance(category, list) or len(category) == 1) and nwarnings > 1:
        if isinstance(category, list):
            category = category * nwarnings
        else:
            category = [category] * nwarnings

    if (not isinstance(message, list) or len(message) == 1) and nwarnings > 1:
        if isinstance(message, list):
            message = message * nwarnings
        else:
            message = [message] * nwarnings

    if known_warning == "miriad":
        # The default warnings for known telescopes when reading miriad files
        category = [UserWarning]
        message = [
            "Altitude is not present in Miriad file, using known "
            "location values for PAPER."
        ]
        nwarnings = 1
    elif known_warning == "paper_uvfits":
        # The default warnings for known telescopes when reading uvfits files
        category = [UserWarning] * 2
        message = ["Required Antenna frame keyword", "telescope_location is not set"]
        nwarnings = 2
    elif known_warning == "fhd":
        category = [UserWarning]
        message = ["Telescope location derived from obs"]
        nwarnings = 1

    category = uvutils._get_iterable(category)
    # Wrap a bare string explicitly: a str is itself iterable, so if
    # ``_get_iterable`` hands it back unchanged, ``message[i]`` below would
    # compare single characters instead of the whole expected message.
    if isinstance(message, str):
        message = [message]
    message = uvutils._get_iterable(message)

    clearWarnings()
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")  # All warnings triggered
        warnings.filterwarnings("ignore", message="numpy.dtype size changed")
        warnings.filterwarnings("ignore", message="numpy.ufunc size changed")

        # Filter iers warnings if iers.conf.auto_max_age is set to None, as we
        # do in testing if the iers url is down. See conftest.py for more info.
        if iers.conf.auto_max_age is None:
            warnings.filterwarnings("ignore", message="failed to download")
            warnings.filterwarnings("ignore", message="time is out of IERS range")

            # Only silence the LST warning if it is not one of the expected messages.
            expects_lst = any(
                m is not None and m.startswith("LST values stored in ")
                for m in message
            )
            if not expects_lst:
                warnings.filterwarnings("ignore", message="LST values stored in ")

        retval = func(*func_args, **func_kwargs)  # Run function

        # Verify. Failures raise AssertionError explicitly (rather than
        # ``assert False``) so the checks still run under ``python -O``.
        if len(w) != nwarnings:
            err = (
                "wrong number of warnings. Expected number was {nexp}, "
                "actual number was {nact}.".format(nexp=nwarnings, nact=len(w))
            )
            print(err)
            for idx, wi in enumerate(w):
                print("warning {i} is: {w}".format(i=idx, w=wi))
            raise AssertionError(err)

        for i, w_i in enumerate(w):
            if w_i.category is not category[i]:
                print("expected category " + str(i) + " was: ", category[i])
                print("category " + str(i) + " was: ", str(w_i.category))
                raise AssertionError(
                    f"warning {i} has category {w_i.category}, "
                    f"expected {category[i]}"
                )
            if message[i] is None or message[i] == "":
                print("Expected message " + str(i) + " was None or an empty string")
                print("message " + str(i) + " was: ", str(w_i.message))
                raise AssertionError(
                    f"expected message {i} was None or an empty string"
                )
            if message[i] not in str(w_i.message):
                print("expected message " + str(i) + " was: ", message[i])
                print("message " + str(i) + " was: ", str(w_i.message))
                raise AssertionError(
                    f"warning {i} message {str(w_i.message)!r} does not "
                    f"contain {message[i]!r}"
                )
        return retval
back to top