# grad_check.py
import numpy as np
from py_diff_stokes_flow.common.common import print_error, print_ok
# var_filter
def check_gradients(loss_and_grad, x0, eps=1e-6, rtol=1e-2, atol=1e-4, verbose=True,
                    loss_only=None, grad_only=None, skip_var=None):
    """Compare an analytic gradient against central finite differences.

    Every entry of ``x0`` (in flattened, C order) is perturbed by ``+eps`` and
    ``-eps``. The numeric derivative ``(loss(x+eps) - loss(x-eps)) / 2 / eps``
    is compared to the matching entry of the analytic gradient with
    ``np.isclose(analytic, numeric, rtol=rtol, atol=atol)``.

    Args:
        loss_and_grad: callable ``x -> (loss, grad)``. It supplies the loss
            when ``loss_only`` is None and the gradient when ``grad_only``
            is None.
        x0: point at which the gradient is checked; any shape. Non-floating
            input (int/bool) is promoted to float so that the +/- eps
            perturbation is not truncated away.
        eps: finite-difference half step.
        rtol, atol: tolerances forwarded to ``np.isclose``.
        verbose: if True, report every checked variable via
            ``print_ok``/``print_error`` and keep checking after a mismatch.
            If False, print nothing and return False at the first mismatch.
        loss_only: optional callable ``x -> loss`` used instead of
            ``loss_and_grad`` for the perturbed evaluations.
        grad_only: optional callable ``x -> grad`` used instead of
            ``loss_and_grad`` for the analytic gradient.
        skip_var: optional predicate on the flat variable index; variables
            for which it returns True are not checked.

    Returns:
        True if every checked variable matched, False otherwise.
    """
    x0 = np.asarray(x0)
    if not np.issubdtype(x0.dtype, np.inexact):
        # Writing x0[i] +/- eps back into an integer array truncates it,
        # which would make the finite difference meaningless.
        x0 = x0.astype(float)

    if grad_only is None:
        _, grad_analytic = loss_and_grad(x0)
    else:
        grad_analytic = grad_only(x0)
    # Flatten so that index i refers to the same element as x.flat[i],
    # whatever the shape of x0.
    grad_analytic = np.ravel(grad_analytic)

    grads_equal = True
    for i in range(x0.size):
        if skip_var is not None and skip_var(i):
            continue
        # Fresh copies so the callee never sees a mutated x0.
        x_pos = np.copy(x0)
        x_neg = np.copy(x0)
        x_pos.flat[i] += eps
        x_neg.flat[i] -= eps
        if loss_only is None:
            loss_pos, _ = loss_and_grad(x_pos)
            loss_neg, _ = loss_and_grad(x_neg)
        else:
            loss_pos = loss_only(x_pos)
            loss_neg = loss_only(x_neg)
        grad_numeric = (loss_pos - loss_neg) / 2 / eps
        if not np.isclose(grad_analytic[i], grad_numeric, rtol=rtol, atol=atol):
            grads_equal = False
            if verbose:
                print_error('Variable {}: analytic {}, numeric {}'.format(i, grad_analytic[i], grad_numeric))
            else:
                # Quiet mode: the answer is already known, stop early.
                return grads_equal
        elif verbose:
            print_ok('Variable {} seems good: analytic {}, numeric {}'.format(i, grad_analytic[i], grad_numeric))
    return grads_equal