Revision 8b22682704c00bd278c44dae1686f726d261b718 authored by Steven Murray on 09 January 2023, 21:13:32 UTC, committed by Steven Murray on 09 January 2023, 21:13:32 UTC
1 parent e377413
__init__.py
# -*- mode: python; coding: utf-8 -*-
# Copyright (c) 2018 Radio Astronomy Software Group
# Licensed under the 2-clause BSD License
"""Setup testing environment, define useful testing functions.
"""
import re
import sys
import warnings
from astropy.utils import iers
import pyuvdata.utils as uvutils
__all__ = ["check_warnings", "clearWarnings", "checkWarnings"]
class WarningsChecker(warnings.catch_warnings):
    """
    A context manager to check raised warnings.

    Adapted from pytest WarningsRecorder and WarningsChecker.

    While the context is active every warning is recorded (apart from a few
    known-noisy ones that are filtered out in ``__enter__``). On a clean exit
    the recorded warnings are compared against the expected categories and
    messages, and an AssertionError is raised on any mismatch.

    Parameters
    ----------
    expected_warning : Warning subclass, list of Warning subclasses, or None
        Expected warning categories. None means no warnings are expected.
    match : str, list of str, or None
        Strings (prefixes or regexes) to match warning messages to. Entry ``i``
        is paired with entry ``i`` of `expected_warning`. None disables the
        message check.

    Raises
    ------
    TypeError
        If `expected_warning` is not made of Warning subclasses or `match` is
        not made of strings.

    """

    def __init__(self, expected_warning, match):
        """Check inputs and initialize object."""
        super().__init__(record=True)
        self._entered = False
        self._list = []

        msg = "exceptions must be derived from Warning, not %s"
        if expected_warning is None:
            expected_warning_list = None
        elif isinstance(expected_warning, list):
            for exc in expected_warning:
                if not issubclass(exc, Warning):
                    raise TypeError(msg % type(exc))
            expected_warning_list = expected_warning
        elif issubclass(expected_warning, Warning):
            expected_warning_list = [expected_warning]
        else:
            raise TypeError(msg % type(expected_warning))

        msg = "match must be a str, not %s"
        if match is None:
            match_list = None
        elif isinstance(match, list):
            for entry in match:
                if not isinstance(entry, str):
                    raise TypeError(msg % type(entry))
            match_list = match
        elif isinstance(match, str):
            match_list = [match]
        else:
            raise TypeError(msg % type(match))

        self.expected_warning = expected_warning_list
        self.match = match_list

    @property
    def warnlist(self):
        """The list of recorded warnings."""
        return self._list

    def __getitem__(self, i: int):
        """Get a recorded warning by index."""
        return self._list[i]

    def __iter__(self):
        """Iterate through the recorded warnings."""
        return iter(self._list)

    def __len__(self):
        """The number of recorded warnings."""
        return len(self._list)

    def pop(self, cls):
        """
        Pop the first recorded warning, raise exception if not exists.

        Parameters
        ----------
        cls : Warning
            Warning class to check for.

        Returns
        -------
        The first recorded warning whose category is a subclass of `cls`.

        Raises
        ------
        AssertionError
            If no recorded warning has a category derived from `cls`.

        """
        for i, w in enumerate(self._list):
            if issubclass(w.category, cls):
                return self._list.pop(i)
        __tracebackhide__ = True
        raise AssertionError("%r not found in warning list" % cls)

    def clear(self) -> None:
        """Clear the list of recorded warnings."""
        self._list[:] = []

    def __enter__(self):
        """Start recording warnings and install the standard test filters."""
        if self._entered:
            __tracebackhide__ = True
            raise RuntimeError("Cannot enter %r twice" % self)
        _list = super().__enter__()
        # With record=True the parent returns the list that warnings are
        # appended to; it is only None when recording is disabled.
        assert _list is not None
        self._list = _list
        warnings.simplefilter("always")
        # Filter annoying Cython warnings that serve no good purpose. see numpy#432
        warnings.filterwarnings("ignore", message="numpy.dtype size changed")
        warnings.filterwarnings("ignore", message="numpy.ufunc size changed")

        # PLP: Temporarily add this filter as well.
        # numpy v1.20 causes h5py to issue warnings. Must keep as is for now to
        # avoid test failures, should change back later.
        warnings.filterwarnings("ignore", message="Passing None into shape arguments")

        # Filter iers warnings if iers.conf.auto_max_age is set to None, as we
        # do in testing if the iers url is down. See conftest.py for more info.
        if iers.conf.auto_max_age is None:
            warnings.filterwarnings("ignore", message="failed to download")
            warnings.filterwarnings("ignore", message="time is out of IERS range")

            # Only hide the LST warning when the caller is not explicitly
            # expecting it. `self.match` is None when messages are not being
            # checked, so guard against iterating None.
            lst_prefix = "LST values stored in "
            expects_lst = any(
                message is not None and message.startswith(lst_prefix)
                for message in (self.match or [])
            )
            if not expects_lst:
                warnings.filterwarnings("ignore", message=lst_prefix)
        return self

    def __exit__(self, exc_type=None, exc_val=None, exc_tb=None):
        """Stop recording and, on a clean exit, verify the recorded warnings."""
        if not self._entered:
            __tracebackhide__ = True
            raise RuntimeError("Cannot exit %r without entering first" % self)

        super().__exit__(exc_type, exc_val, exc_tb)

        # Built-in catch_warnings does not reset entered state so we do it
        # manually here for this context manager to become reusable.
        self._entered = False

        # only check if we're not currently handling an exception
        if exc_type is None and exc_val is None and exc_tb is None:
            __tracebackhide__ = True
            self._verify_recorded()

    def _verify_recorded(self):
        """
        Compare the recorded warnings with the expected ones.

        Note that the check is not one-to-one: a single recorded warning can
        satisfy more than one expected (category, match) pair, as long as the
        total number of recorded warnings equals the number expected.

        Raises
        ------
        AssertionError
            If the number, categories or messages of the recorded warnings do
            not agree with `expected_warning` and `match`.

        """
        __tracebackhide__ = True
        emitted = [each.message for each in self]

        if self.expected_warning is None:
            # Explicit raise (rather than a bare assert) so the failure is
            # informative and survives running under ``python -O``.
            if len(self) != 0:
                raise AssertionError(
                    f"No warnings expected, {len(self)} warnings issued. "
                    f"The list of emitted warnings is: {emitted}."
                )
            return

        if len(self) != len(self.expected_warning):
            raise AssertionError(
                f"{len(self.expected_warning)} warnings expected, "
                f"{len(self)} warnings issued. The list of emitted warnings is: "
                f"{emitted}."
            )

        for warn_i, exp_warn in enumerate(self.expected_warning):
            if not any(issubclass(r.category, exp_warn) for r in self):
                raise AssertionError(
                    "DID NOT WARN. No warnings of type {} was emitted. "
                    "The list of emitted warnings is: {}.".format(
                        self.expected_warning, emitted
                    )
                )
            if self.match is None:
                continue

            pattern = self.match[warn_i]
            for record in self:
                text = str(record.message)
                # Plain prefix comparison first, so messages containing regex
                # metacharacters still match literally; regex search second.
                if (
                    text.startswith(pattern) or re.search(pattern, text)
                ) and issubclass(record.category, exp_warn):
                    break
            else:
                raise AssertionError(
                    f"No warnings of type {exp_warn} "
                    f"matching ('{pattern}') was "
                    "emitted. The list of emitted warnings is: "
                    f"{emitted}."
                )
def check_warnings(expected_warning, match=None, nwarnings=None, *args, **kwargs):
    """
    Assert that code raises a particular set of warnings, used as a context manager.

    Similar to ``pytest.warns``, but allows for specifying multiple warnings.
    It also better matches warning strings when the warning uses f-strings or
    formatting. Can be used as a drop-in replacement for ``pytest.warns`` if
    only one warning is issued (if more are issued they will need to be added
    to the input lists for this to pass).

    Note that unlike the older checkWarnings function, the warnings can be passed
    in any order, they do not have to match the order the warnings are raised
    in the code.

    To assert that there are no warnings raised by some code, set `expected_warning`
    to None (i.e. `with check_warnings(None):`)

    Parameters
    ----------
    expected_warning : list of Warnings or Warning
        List of expected warnings. If a single warning type or a length 1 list,
        will be used for the type of all warnings.
    match : str or regex or list of str or regex, optional
        List of strings or regexes to match warnings to. If a str or a length 1
        list, will be used for all warnings. If None, only the warning types
        (and their number) are checked.
    nwarnings : int, optional
        Option to specify that multiple of a single type of warning is expected.
        Only used if category and match both only have one element.
    *args
        Optional. If given, the first element is a callable that is run inside
        the checking context with the remaining positional arguments.
    **kwargs
        Keyword arguments passed to the callable. Only allowed together with
        a callable in `args`.

    Returns
    -------
    WarningsChecker or object
        The context manager when no callable is given, otherwise the value
        returned by the callable.

    Raises
    ------
    TypeError
        If `expected_warning` or `match` have the wrong type, if keyword
        arguments are given without a callable, or if the first extra
        positional argument is not callable.
    ValueError
        If `expected_warning` and `match` both have more than one element but
        differ in length.

    """
    __tracebackhide__ = True

    # Check for a class first so that a non-class input gets our own message
    # rather than the one issubclass raises.
    if not (
        expected_warning is None
        or isinstance(expected_warning, list)
        or (isinstance(expected_warning, type) and issubclass(expected_warning, Warning))
    ):
        raise TypeError("expected_warning must be a list or be derived from Warning")
    if match is not None and not isinstance(match, (list, str)):
        raise TypeError("match must be a list or a string.")

    if expected_warning is not None and not isinstance(expected_warning, list):
        expected_warning_list = [expected_warning]
    else:
        expected_warning_list = expected_warning
    if match is not None and not isinstance(match, list):
        match_list = [match]
    else:
        match_list = match

    if expected_warning_list is not None:
        n_expected = len(expected_warning_list)
        # match=None means "do not check messages"; treat it as contributing
        # no length constraint instead of calling len() on None.
        n_match = 0 if match_list is None else len(match_list)

        if n_expected > 1 and n_match > 1 and n_expected != n_match:
            raise ValueError(
                "If expected_warning and match both have more than one element, "
                "they must be the same length."
            )

        if n_expected > 1 or n_match > 1:
            nwarnings = max(n_expected, n_match)
        elif nwarnings is None:
            nwarnings = 1

        # Broadcast the shorter input so both describe `nwarnings` warnings.
        if n_expected < nwarnings:
            expected_warning_list = expected_warning_list * nwarnings
        if match_list is not None and n_match < nwarnings:
            match_list = match_list * nwarnings

    if not args:
        if kwargs:
            msg = "Unexpected keyword arguments passed to check_warnings: "
            msg += ", ".join(sorted(kwargs))
            msg += "\nUse context-manager form instead?"
            raise TypeError(msg)
        return WarningsChecker(expected_warning_list, match_list)

    func = args[0]
    if not callable(func):
        raise TypeError(
            "{!r} object (type: {}) must be callable".format(func, type(func))
        )
    with WarningsChecker(expected_warning_list, match_list):
        return func(*args[1:], **kwargs)
# Functions that are useful for testing:
def clearWarnings():
    """
    Empty the warning registry of every loaded module.

    Resetting the per-module ``__warningregistry__`` dictionaries makes
    warnings that were already issued once show up again, which keeps
    warning-based tests reproducible.
    """
    # Snapshot the modules first: attribute access on lazily-loaded modules
    # may trigger imports, which would change sys.modules while we iterate.
    loaded_modules = list(sys.modules.values())
    for module in loaded_modules:
        try:
            registry = getattr(module, "__warningregistry__", None)
        except ImportError:
            # Some lazy modules fail on attribute access; nothing to clear.
            continue
        if not registry:
            continue
        registry.clear()
def checkWarnings(
    func,
    func_args=None,
    func_kwargs=None,
    nwarnings=1,
    category=UserWarning,
    message=None,
    known_warning=None,
):
    """
    Function to check expected warnings in tests.

    Deprecated. Use check_warnings instead.

    Useful for checking that appropriate warnings are raised and to capture
    (and silence) warnings in tests. Warnings are compared in the order they
    are raised: warning ``i`` must have exactly category ``i`` and contain
    message ``i``.

    Parameters
    ----------
    func : function
        Function or method to check warnings for.
    func_args : list, optional
        List of positional parameters to pass `func`
    func_kwargs : dict, optional
        Dict of keyword parameter to pass func. Keys are the parameter names,
        values are the values to pass to the parameters.
    nwarnings : int
        Number of expected warnings.
    category : warning type or list of warning types
        Expected warning type(s). If a scalar is passed and `nwarnings` is
        greater than one, the same category will be expected for all warnings.
    message : str or list of str
        Expected warning string(s). If a scalar is passed and `nwarnings` is
        greater than one, the same warning string will be expected for all warnings.
    known_warning : {'miriad', 'paper_uvfits', 'fhd'}, optional
        Shorthand way to specify one of a standard set of warnings. Overrides
        `nwarnings`, `category` and `message`.

    Returns
    -------
    Value returned by `func`

    Raises
    ------
    AssertionError
        If the warning(s) raised by func do not match the expected values.

    """
    warnings.warn(
        "`checkWarnings` is deprecated, and will be removed in pyuvdata version "
        "2.3. Use `check_warnings` instead.",
        DeprecationWarning,
        # Attribute the deprecation to the caller rather than to this helper.
        stacklevel=2,
    )

    if func_args is None:
        func_args = []
    if func_kwargs is None:
        func_kwargs = {}

    # Broadcast a scalar (or length-1 list) to one entry per expected warning.
    if nwarnings > 1:
        if not isinstance(category, list):
            category = [category] * nwarnings
        elif len(category) == 1:
            category = category * nwarnings

        if not isinstance(message, list):
            message = [message] * nwarnings
        elif len(message) == 1:
            message = message * nwarnings

    if known_warning == "miriad":
        # The default warnings for known telescopes when reading miriad files
        category = [UserWarning]
        message = [
            "Altitude is not present in Miriad file, using known "
            "location values for PAPER."
        ]
        nwarnings = 1
    elif known_warning == "paper_uvfits":
        # The default warnings for known telescopes when reading uvfits files
        category = [UserWarning] * 2
        message = ["Required Antenna keyword 'FRAME'", "telescope_location is not set"]
        nwarnings = 2
    elif known_warning == "fhd":
        category = [UserWarning]
        message = ["Telescope location derived from obs"]
        nwarnings = 1

    # A bare string is itself indexable, so `message[i]` below would compare
    # against a single character. Wrap it so the whole string is checked.
    if isinstance(message, str):
        message = [message]

    category = uvutils._get_iterable(category)
    message = uvutils._get_iterable(message)

    clearWarnings()
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")  # All warnings triggered
        warnings.filterwarnings("ignore", message="numpy.dtype size changed")
        warnings.filterwarnings("ignore", message="numpy.ufunc size changed")

        # Filter iers warnings if iers.conf.auto_max_age is set to None, as we
        # do in testing if the iers url is down. See conftest.py for more info.
        if iers.conf.auto_max_age is None:
            warnings.filterwarnings("ignore", message="failed to download")
            warnings.filterwarnings("ignore", message="time is out of IERS range")

            # Only hide the LST warning if the caller is not expecting it.
            expects_lst = any(
                m is not None and m.startswith("LST values stored in ")
                for m in message
            )
            if not expects_lst:
                warnings.filterwarnings("ignore", message="LST values stored in ")

        retval = func(*func_args, **func_kwargs)  # Run function

    # Verify
    if len(w) != nwarnings:
        report = [
            "wrong number of warnings. Expected number was {nexp}, "
            "actual number was {nact}.".format(nexp=nwarnings, nact=len(w))
        ]
        report.extend(
            "warning {i} is: {w}".format(i=idx, w=wi) for idx, wi in enumerate(w)
        )
        raise AssertionError("\n".join(report))

    for i, w_i in enumerate(w):
        if w_i.category is not category[i]:
            raise AssertionError(
                f"expected category {i} was: {category[i]}\n"
                f"category {i} was: {w_i.category}"
            )
        if message[i] is None or message[i] == "":
            raise AssertionError(
                f"Expected message {i} was None or an empty string\n"
                f"message {i} was: {w_i.message}"
            )
        if message[i] not in str(w_i.message):
            raise AssertionError(
                f"expected message {i} was: {message[i]}\n"
                f"message {i} was: {w_i.message}"
            )
    return retval
Computing file changes ...