# Public API of this module: the ``Callbacks`` base class and the built-in
# callbacks that subclass it.
__all__ = [
    "Callbacks",
    "CostNanInfCallback",
    "CostToDataCallback",
    "CostToInitialCallback",
    "MetricsCallback",
]
from collections.abc import Sequence
from typing import TYPE_CHECKING, Optional
import numpy as np
from pylops.utils.metrics import mae, mse, psnr, snr
from pylops.utils.typing import NDArray
if TYPE_CHECKING:
from pylops.linearoperator import LinearOperator
from pylops.optimization.basesolver import Solver
class Callbacks:
    r"""Callbacks

    This is a template class which a user must subclass when implementing callbacks for a solver.
    This class comprises the following methods:

    - ``on_setup_begin``: a method that is invoked at the start of the setup method of the solver
    - ``on_setup_end``: a method that is invoked at the end of the setup method of the solver
    - ``on_step_begin``: a method that is invoked at the start of the step method of the solver
    - ``on_step_end``: a method that is invoked at the end of the step method of the solver
    - ``on_run_begin``: a method that is invoked at the start of the run method of the solver
    - ``on_run_end``: a method that is invoked at the end of the run method of the solver

    All methods take two input parameters: the solver itself, and the vector ``x``.
    The base implementations do nothing, so a subclass only needs to override
    the hooks it is interested in.

    Moreover, some callbacks may be used to implement custom stopping criteria for the solver.
    This can be done by adding a boolean attribute ``stop`` to the callback object, which will
    be initially set to ``False``. As soon as the callback sets this attribute to ``True``, the
    ``run`` method of the solver will stop iterating and return the current model vector.

    Examples
    --------
    >>> import numpy as np
    >>> from pylops.basicoperators import MatrixMult
    >>> from pylops.optimization.basic import CG
    >>> from pylops.optimization.callback import Callbacks
    >>>
    >>> class StoreIterCallback(Callbacks):
    ...     def __init__(self):
    ...         self.stored = []
    ...     def on_step_end(self, solver, x):
    ...         self.stored.append(solver.iiter)
    >>>
    >>> Aop = MatrixMult(np.random.normal(0., 1., 36).reshape(6, 6))
    >>> Aop = Aop.H @ Aop
    >>> y = Aop @ np.ones(6)
    >>> cb_sto = StoreIterCallback()
    >>> cgsolve = CG(Aop, callbacks=[cb_sto, ])
    >>> xest = cgsolve.solve(y=y, x0=np.zeros(6), tol=0, niter=6, show=False)[0]
    >>> xest, cb_sto.stored
    (array([1., 1., 1., 1., 1., 1.]), [1, 2, 3, 4, 5, 6])

    """

    def __init__(self) -> None:
        pass

    def on_setup_begin(self, solver: "Solver", x0: NDArray) -> None:
        """Callback before setup

        Parameters
        ----------
        solver : :obj:`pylops.optimization.basesolver.Solver`
            Solver object
        x0 : :obj:`numpy.ndarray`
            Initial guess (when present as one of the inputs of the solver
            setup method)

        """
        pass

    def on_setup_end(self, solver: "Solver", x: NDArray) -> None:
        """Callback after setup

        Parameters
        ----------
        solver : :obj:`pylops.optimization.basesolver.Solver`
            Solver object
        x : :obj:`numpy.ndarray`
            Current model vector

        """
        pass

    def on_step_begin(self, solver: "Solver", x: NDArray) -> None:
        """Callback before step of solver

        Parameters
        ----------
        solver : :obj:`pylops.optimization.basesolver.Solver`
            Solver object
        x : :obj:`numpy.ndarray`
            Current model vector

        """
        pass

    def on_step_end(self, solver: "Solver", x: NDArray) -> None:
        """Callback after step of solver

        Parameters
        ----------
        solver : :obj:`pylops.optimization.basesolver.Solver`
            Solver object
        x : :obj:`numpy.ndarray`
            Current model vector

        """
        pass

    def on_run_begin(self, solver: "Solver", x: NDArray) -> None:
        """Callback before entire solver run

        Parameters
        ----------
        solver : :obj:`pylops.optimization.basesolver.Solver`
            Solver object
        x : :obj:`numpy.ndarray`
            Current model vector

        """
        pass

    def on_run_end(self, solver: "Solver", x: NDArray) -> None:
        """Callback after entire solver run

        Parameters
        ----------
        solver : :obj:`pylops.optimization.basesolver.Solver`
            Solver object
        x : :obj:`numpy.ndarray`
            Current model vector

        """
        pass
class CostToDataCallback(Callbacks):
    """Cost to data callback

    This callback can be used to stop the solver when the ``cost`` parameter
    of the solver is below a certain threshold defined as a fraction of the
    Euclidean norm of the data.

    Note that the meaning of ``cost`` can change from solver to solver - e.g.,
    it can represent the misfit of the data term or the total cost function.

    Parameters
    ----------
    rtol : :obj:`float`
        Fraction of the Euclidean norm of the data below which the solver
        will stop iterating. For example, if ``rtol`` is 0.1, the solver
        will stop when the cost is below 10% of the Euclidean norm of
        the data.

    """

    def __init__(self, rtol: float) -> None:
        self.rtol = rtol
        self.stop = False

    def on_setup_end(self, solver: "Solver", x: NDArray) -> None:
        # The data vector does not change during a run, so its norm is
        # computed once here and reused at every step.
        self.ynorm = solver.ncp.linalg.norm(solver.y)
        # Clear a stop flag left over from a previous run, otherwise reusing
        # this callback in a new solve would halt it after the first step.
        self.stop = False

    def on_step_end(self, solver: "Solver", x: NDArray) -> None:
        if solver.cost[-1] < self.rtol * self.ynorm:
            self.stop = True
class CostToInitialCallback(Callbacks):
    """Cost to initial callback

    This callback can be used to stop the solver when the ``cost``
    parameter of the solver is below a certain threshold defined as a
    fraction of the initial cost (i.e., the first entry of ``solver.cost``).

    Note that the meaning of ``cost`` can change from solver to solver - e.g.,
    it can represent the misfit of the data term or the total cost function.

    Parameters
    ----------
    rtol : :obj:`float`
        Fraction of the initial cost below which the solver
        will stop iterating. For example, if ``rtol`` is 0.1, the solver
        will stop when the cost is below 10% of the initial
        cost.

    """

    def __init__(self, rtol: float) -> None:
        self.rtol = rtol
        self.stop = False

    def on_setup_end(self, solver: "Solver", x: NDArray) -> None:
        # Clear a stop flag left over from a previous run, otherwise reusing
        # this callback in a new solve would halt it after the first step.
        self.stop = False

    def on_step_end(self, solver: "Solver", x: NDArray) -> None:
        if solver.cost[-1] < self.rtol * solver.cost[0]:
            self.stop = True
class CostNanInfCallback(Callbacks):
    """Cost Nan/Inf callback

    This callback can be used to stop the solver when the ``cost``
    becomes either ``np.nan`` or ``np.inf``.

    """

    def __init__(self) -> None:
        self.stop = False

    def on_setup_end(self, solver: "Solver", x: NDArray) -> None:
        # Clear a stop flag left over from a previous run, otherwise reusing
        # this callback in a new solve would halt it after the first step.
        self.stop = False

    def on_step_end(self, solver: "Solver", x: NDArray) -> None:
        # A value is non-finite exactly when it is NaN or +/-Inf.
        if not np.isfinite(solver.cost[-1]):
            self.stop = True
class MetricsCallback(Callbacks):
    r"""Metrics callback

    This callback can be used to store different metrics from the
    ``pylops.utils.metrics`` module during iterations.

    Parameters
    ----------
    xtrue : :obj:`numpy.ndarray`
        True model vector
    Op : :obj:`pylops.LinearOperator`, optional
        Operator to apply to the solution prior to comparing it with ``xtrue``
    which : :obj:`tuple`, optional
        Metrics to compute (currently available: ``"mae"``, ``"mse"``,
        ``"snr"``, and ``"psnr"``). A single metric may also be passed as
        a plain string (e.g., ``which="psnr"``). Names outside this set
        are ignored.

    Attributes
    ----------
    metrics : :obj:`dict`
        Mapping from each requested metric name to the list of its values,
        with one value appended at the end of every solver step.

    """

    # Supported metric names, in the order in which they are stored and computed.
    _metric_names = ("mae", "mse", "snr", "psnr")

    def __init__(
        self,
        xtrue: NDArray,
        Op: Optional["LinearOperator"] = None,
        which: Sequence[str] = ("mae", "mse", "snr", "psnr"),
    ):
        self.xtrue = xtrue
        self.Op = Op
        # A bare string is itself a Sequence[str]; without wrapping it, the
        # membership tests below would match substrings (e.g. "snr" in "psnr").
        if isinstance(which, str):
            which = (which,)
        self.which = which
        self.metrics: dict[str, list] = {
            name: [] for name in self._metric_names if name in self.which
        }

    def on_step_end(self, solver: "Solver", x: NDArray) -> None:
        if self.Op is not None:
            x = self.Op * x
        metric_funcs = {"mae": mae, "mse": mse, "snr": snr, "psnr": psnr}
        for name in self._metric_names:
            if name in self.which:
                self.metrics[name].append(metric_funcs[name](self.xtrue, x))
def _callback_stop(callbacks: Optional[Sequence["Callbacks"]]) -> bool:
    """Check if any callback has raised a stop flag

    Parameters
    ----------
    callbacks : :obj:`list`
        List of :obj:`pylops.optimization.callback.Callbacks` objects to
        evaluate, or ``None`` when the solver has no callbacks

    Returns
    -------
    stop : :obj:`bool`
        Whether to stop the solver or not (``True`` as soon as any callback
        has a truthy ``stop`` attribute)

    """
    if callbacks is None:
        return False
    # Callbacks that do not implement a stopping criterion have no ``stop``
    # attribute and therefore never request an early stop.
    return any(getattr(callback, "stop", False) for callback in callbacks)