# Copyright 2014-2025 The ODL contributors
#
# This file is part of ODL.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
"""Standardized tests for `Operator`'s."""
import numpy as np
from odl.core.diagnostics.examples import samples
from odl.core.operator import power_method_opnorm
from odl.core.util.testutils import fail_counter
__all__ = ('OperatorTest',)
[docs]
class OperatorTest:
    """Automated tests for `Operator` implementations.

    This class allows users to automatically test various
    features of an Operator such as linearity, the adjoint definition and
    definition of the derivative.
    """

    def __init__(self, operator, operator_norm=None, verbose=True, tol=1e-5):
        """Initialize a new instance.

        Parameters
        ----------
        operator : `Operator`
            The operator to run tests on
        operator_norm : float, optional
            The norm of the operator, used for error estimates. If
            ``None`` is given, the norm is estimated during
            initialization.
        verbose : bool, optional
            If ``True``, print additional info text.
        tol : float, optional
            Tolerance parameter used as a base for the actual tolerance
            in the tests. Depending on the expected accuracy, the actual
            tolerance used in a test can be a factor times this number.
        """
        self.operator = operator

        # Keep `log` silent while the norm is estimated; the requested
        # verbosity is only applied afterwards.
        self.verbose = False
        if operator_norm is None:
            self.operator_norm = self.norm()
        else:
            self.operator_norm = float(operator_norm)

        self.verbose = bool(verbose)
        self.tol = float(tol)

    def log(self, message):
        """Print message if ``self.verbose == True``."""
        if self.verbose:
            print(message)

    def norm(self):
        """Estimate the operator norm of the operator.

        For every example element of the domain except the one named
        ``'Zero'``, `power_method_opnorm` is called with ``maxiter=2``
        and that element as ``xstart``. The largest value obtained is
        used as the estimate; it is also stored in
        ``self.operator_norm``.

        Returns
        -------
        norm : float
            Estimate of operator norm

        References
        ----------
        Wikipedia article on `Operator norm
        <https://en.wikipedia.org/wiki/Operator_norm>`_.
        """
        self.log("\n== Calculating operator norm ==\n")

        operator_norm = max(
            power_method_opnorm(self.operator, maxiter=2, xstart=x)
            for name, x in samples(self.operator.domain)
            if name != 'Zero'
        )

        self.log(f"Norm is at least: {operator_norm}")
        self.operator_norm = operator_norm
        return operator_norm

    def self_adjoint(self):
        """Verify ``<Ax, y> == <x, Ay>``.

        Each pair of examples whose relative mismatch exceeds
        ``self.tol`` is reported through the fail counter. Afterwards a
        linear fit of the right-hand against the left-hand inner
        products is logged as a hint about a possible scaling error.
        """
        left_inner_vals = []
        right_inner_vals = []

        with fail_counter(
            test_name="Verifying the identity <Ax, y> = <x, Ay>",
            err_msg="error = |<Ax, y> - <x, Ay>| / ||A|| ||x|| ||y||",
            logger=self.log,
        ) as counter:
            for [name_x, x], [name_y, y] in samples(self.operator.domain,
                                                    self.operator.range):
                x_norm = x.norm()
                y_norm = y.norm()

                l_inner = self.operator(x).inner(y)
                r_inner = x.inner(self.operator(y))

                # A zero denominator (e.g. a zero element) counts as
                # "no error" instead of dividing by zero.
                denom = self.operator_norm * x_norm * y_norm
                error = 0 if denom == 0 else abs(l_inner - r_inner) / denom

                if error > self.tol:
                    counter.fail(f"x={name_x:25s} y={name_y:25s} : error={error:6.5f}")

                left_inner_vals.append(l_inner)
                right_inner_vals.append(r_inner)

        # Slope of the degree-1 fit right ~ scale * left + offset.
        scale = np.polyfit(left_inner_vals, right_inner_vals, 1)[0]
        self.log("\nThe adjoint seems to be scaled according to:")
        self.log(f"(x, Ay) / (Ax, y) = {scale}. Should be 1.0")

    def _adjoint_definition(self):
        """Verify ``<Ax, y> == <x, A^* y>``.

        Works like `self_adjoint`, but uses ``self.operator.adjoint``
        for the right-hand inner product.
        """
        left_inner_vals = []
        right_inner_vals = []

        with fail_counter(
            test_name="Verifying the identity <Ax, y> = <x, A^T y>",
            err_msg="error = |<Ax, y> - <x, A^* y>| / ||A|| ||x|| ||y||",
            logger=self.log,
        ) as counter:
            for [name_x, x], [name_y, y] in samples(self.operator.domain,
                                                    self.operator.range):
                x_norm = x.norm()
                y_norm = y.norm()

                l_inner = self.operator(x).inner(y)
                r_inner = x.inner(self.operator.adjoint(y))

                # A zero denominator counts as "no error".
                denom = self.operator_norm * x_norm * y_norm
                error = 0 if denom == 0 else abs(l_inner - r_inner) / denom

                if error > self.tol:
                    counter.fail(f"x={name_x:25s} y={name_y:25s} : error={error:6.5f}")

                left_inner_vals.append(l_inner)
                right_inner_vals.append(r_inner)

        # Slope of the degree-1 fit right ~ scale * left + offset.
        scale = np.polyfit(left_inner_vals, right_inner_vals, 1)[0]
        self.log("\nThe adjoint seems to be scaled according to:")
        self.log(f"(x, A^T y) / (Ax, y) = {scale}. Should be 1.0")

    def _adjoint_of_adjoint(self):
        """Verify ``(A^*)^* == A``.

        If the adjoint of the adjoint is the operator object itself,
        nothing more is checked. Otherwise both are applied to the
        example elements of the domain and compared.
        """
        # `adjoint` (below) treats NotImplementedError as "not available",
        # so accept that here as well as a missing attribute.
        try:
            adj_adj = self.operator.adjoint.adjoint
        except (AttributeError, NotImplementedError):
            print("A^* has no adjoint")
            return

        if adj_adj is self.operator:
            self.log("(A^*)^* == A")
            return

        with fail_counter(
            test_name="\nVerifying the identity Ax = (A^*)^* x",
            err_msg="error = ||Ax - (A^*)^* x|| / ||A|| ||x||",
            logger=self.log,
        ) as counter:
            for [name_x, x] in self.operator.domain.examples:
                opx = self.operator(x)
                op_adj_adj_x = adj_adj(x)

                denom = self.operator_norm * x.norm()
                if denom == 0:
                    error = 0
                else:
                    error = (opx - op_adj_adj_x).norm() / denom

                if error > self.tol:
                    counter.fail(f"x={name_x:25s} : error={error:6.5f}")

    def adjoint(self):
        """Verify that `Operator.adjoint` works appropriately.

        Checks that domain and range of the adjoint match range and
        domain of the operator, then verifies the adjoint identity and
        that the adjoint of the adjoint agrees with the operator.
        Problems that abort the check are printed regardless of
        ``self.verbose``.

        References
        ----------
        Wikipedia article on `Adjoint
        <https://en.wikipedia.org/wiki/Adjoint>`_.
        """
        try:
            self.operator.adjoint
        except NotImplementedError:
            print("Operator has no adjoint")
            return

        self.log("\n== Verifying operator adjoint ==\n")

        domain_range_ok = True
        if self.operator.domain != self.operator.adjoint.range:
            print("*** ERROR: A.domain != A.adjoint.range ***")
            domain_range_ok = False

        if self.operator.range != self.operator.adjoint.domain:
            print("*** ERROR: A.range != A.adjoint.domain ***")
            domain_range_ok = False

        if domain_range_ok:
            self.log("Domain and range of adjoint are OK.")
        else:
            print("Domain and range of adjoint are not OK, exiting.")
            return

        self._adjoint_definition()
        self._adjoint_of_adjoint()

    def _derivative_convergence(self):
        """Verify that the derivative is a first-order approximation.

        The code verifies if

            ``||A(x+c*p) - A(x) - A'(x)(c*p)|| / c = o(c)``

        for ``c --> 0``.
        """
        with fail_counter(
            test_name="Verifying that derivative is a first-order approximation",
            err_msg="error = inf_c ||A(x+c*p)-A(x)-A'(x)(c*p)|| / c",
            logger=self.log
        ) as counter:
            for [name_x, x], [name_dx, dx] in samples(self.operator.domain,
                                                      self.operator.domain):
                # Precompute some values
                deriv = self.operator.derivative(x)
                derivdx = deriv(dx)
                opx = self.operator(x)

                c = 1e-4  # initial step
                derivative_ok = False
                minerror = float('inf')

                # Shrink the step by a factor of 10 until the error is
                # small enough or the step drops to 1e-14 or below.
                while c > 1e-14:
                    exact_step = self.operator(x + dx * c) - opx
                    expected_step = c * derivdx
                    err = (exact_step - expected_step).norm() / c

                    # Need to be slightly more generous here due to possible
                    # numerical instabilities.
                    # TODO: perform more tests to find a good threshold here.
                    if err < 10 * self.tol:
                        derivative_ok = True
                        break
                    else:
                        minerror = min(minerror, err)

                    c /= 10.0

                if not derivative_ok:
                    counter.fail(f"x={name_x:15s} p={name_dx:15s}, error={minerror}")

    def derivative(self):
        """Verify that `Operator.derivative` works appropriately.

        The code verifies if

            ``||A(x+c*p) - A(x) - A'(x)(c*p)|| / c = o(c)``

        for ``c --> 0`` using a selection of elements ``x`` and ``p``.

        The convergence check is skipped if the derivative at zero is
        not linear, if the operator raises ``NotImplementedError``, or
        if the operator is linear and its derivative is the operator
        itself.

        References
        ----------
        Wikipedia article on `Derivative
        <https://en.wikipedia.org/wiki/Derivative>`_.

        Wikipedia article on `Frechet derivative
        <https://en.wikipedia.org/wiki/Fr%C3%A9chet_derivative>`_.
        """
        self.log("\n== Verifying operator derivative ==")

        try:
            deriv = self.operator.derivative(self.operator.domain.zero())

            if not deriv.is_linear:
                print("Derivative is not a linear operator")
                return
        except NotImplementedError:
            print("Operator has no derivative")
            return

        if self.operator.is_linear and deriv is self.operator:
            self.log("A is linear and A.derivative is A")
            return

        self._derivative_convergence()

    def _scale_invariance(self):
        """Verify ``A(c*x) = c * A(x)``."""
        with fail_counter(
            test_name="Verifying homogeneity under scalar multiplication",
            err_msg="error = ||A(c*x)-c*A(x)|| / |c| ||A|| ||x||",
            logger=self.log,
        ) as counter:
            for [name_x, x], [_, scale] in samples(self.operator.domain,
                                                   self.operator.domain.field):
                opx = self.operator(x)
                scaled_opx = self.operator(scale * x)

                # Normalize by |c| as stated in `err_msg`. With the signed
                # scalar the error became negative for c < 0, so those
                # violations could never exceed the tolerance.
                denom = self.operator_norm * abs(scale) * x.norm()
                error = (0 if denom == 0
                         else (scaled_opx - opx * scale).norm() / denom)

                if error > self.tol:
                    counter.fail(
                        f"x={name_x:25s} scale={scale:7.2f} error={error:6.5f}"
                    )

    def _addition_invariance(self):
        """Verify ``A(x+y) = A(x) + A(y)``."""
        with fail_counter(
            test_name="Verifying distributivity under vector addition",
            err_msg="error = ||A(x+y) - A(x) - A(y)|| / " "||A||(||x|| + ||y||)",
            logger=self.log,
        ) as counter:
            for [name_x, x], [name_y, y] in samples(self.operator.domain,
                                                    self.operator.domain):
                opx = self.operator(x)
                opy = self.operator(y)
                opxy = self.operator(x + y)

                # A zero denominator counts as "no error".
                denom = self.operator_norm * (x.norm() + y.norm())
                error = (0 if denom == 0
                         else (opxy - opx - opy).norm() / denom)

                if error > self.tol:
                    counter.fail(f"x={name_x:25s} y={name_y:25s} error={error:6.5f}")

    def linear(self):
        """Verify that the operator is actually linear.

        Does nothing except printing a message if the operator does not
        claim to be linear via ``is_linear``. Otherwise checks that zero
        is mapped to zero, followed by homogeneity and additivity.
        """
        if not self.operator.is_linear:
            print("Operator is not linear")
            return

        self.log("\n== Verifying operator linearity ==\n")

        # Test if zero gives zero
        result = self.operator(self.operator.domain.zero())
        result_norm = result.norm()
        if result_norm != 0.0:
            print(f"||A(0)||={result_norm:6.5f}. Should be 0.0000")

        self._scale_invariance()
        self._addition_invariance()

    def run_tests(self):
        """Run all tests on this operator.

        The operator norm is re-estimated first via `norm`, which
        replaces the value stored in ``self.operator_norm`` (including
        one passed at construction). Linear operators then get the
        linearity and adjoint checks, all others the derivative check.
        """
        print("\n== RUNNING ALL TESTS ==")
        print(f"Operator = {self.operator}")

        self.norm()

        if self.operator.is_linear:
            self.linear()
            self.adjoint()
        else:
            self.derivative()

    def __str__(self):
        """Return the same text as ``repr(self)``."""
        return repr(self)

    def __repr__(self):
        """Return ``ClassName(operator)`` using the operator's string form."""
        return f"{self.__class__.__name__}({self.operator})"
if __name__ == '__main__':
    # Small demonstration: run the test suite on one linear and one
    # nonlinear operator defined on the same space.
    import odl

    demo_space = odl.uniform_discr([0, 0], [1, 1], [3, 3])

    # Linear case: the identity operator, with info text switched off
    identity_op = odl.IdentityOperator(demo_space)
    identity_test = OperatorTest(identity_op, verbose=False)
    identity_test.run_tests()

    # Nonlinear case: op(x) = x**4, with the default verbosity
    power_op = odl.PowerOperator(demo_space, 4)
    power_test = OperatorTest(power_op)
    power_test.run_tests()