"""Source code for pylops.utils.decorators."""

# Public names of this module: these are the decorators exported by
# ``from pylops.utils.decorators import *``.
__all__ = [
    "disable_ndarray_multiplication",
    "add_ndarray_support_to_solver",
    "reshaped",
    "count",
]

from functools import wraps
from typing import Callable, Optional

from pylops.config import disabled_ndarray_multiplication


def disable_ndarray_multiplication(func: Callable) -> Callable:
    """Decorator which disables ndarray multiplication.

    The decorated function is executed inside the
    ``disabled_ndarray_multiplication`` context manager imported from
    ``pylops.config``; its return value is passed through unchanged.

    Parameters
    ----------
    func : :obj:`callable`
        Generic function

    Returns
    -------
    wrapper : :obj:`callable`
        Decorated function

    """

    @wraps(func)
    def wrapper(*args, **kwargs):  # SciPy-type signature
        # Run the wrapped call with ndarray multiplication switched off and
        # hand back whatever it produced once the context has been left.
        with disabled_ndarray_multiplication():
            out = func(*args, **kwargs)
        return out

    return wrapper
def add_ndarray_support_to_solver(func: Callable) -> Callable:
    """Decorator which converts a solver-type function that only supports
    a 1d-array into one that supports one (dimsd-shaped) ndarray.

    The wrapper flattens ``b`` (and the ``x0`` keyword argument, when given
    and not ``None``) before calling ``func``, and reshapes the first element
    of the returned tuple to ``A.dims`` (or ``(A.shape[1],)`` if ``A`` has no
    ``dims`` attribute). The reshape is skipped when ``x0`` was passed as a
    1d array or when ``A.forceflat`` is truthy.

    Parameters
    ----------
    func : :obj:`callable`
        Solver type function. Its signature must be
        ``func(A, b, *args, **kwargs)``. Its output must be a result-type
        tuple: ``(xinv, ...)``.

    Returns
    -------
    wrapper : :obj:`callable`
        Decorated function

    """

    @wraps(func)
    def wrapper(A, b, *args, **kwargs):  # SciPy-type signature
        # ``x0`` is only looked up as a keyword argument; remember whether
        # the caller already supplied it flattened before it gets ravelled.
        x0 = kwargs.get("x0")
        x0flat = False
        if x0 is not None:
            x0flat = x0.ndim == 1
            kwargs["x0"] = x0.ravel()

        with disabled_ndarray_multiplication():
            res = list(func(A, b.ravel(), *args, **kwargs))
            # reshape if x0 was provided unflattened
            # (unless the operator has forceflat=True)
            if not x0flat and not getattr(A, "forceflat", None):
                res[0] = res[0].reshape(getattr(A, "dims", (A.shape[1],)))
        return tuple(res)

    return wrapper
def reshaped(
    func: Optional[Callable] = None,
    forward: Optional[bool] = None,
    swapaxis: bool = False,
) -> Callable:
    """Decorator for the common reshape/flatten pattern used in many operators.

    The input array is reshaped to ``self.dims`` or ``self.dimsd`` before the
    decorated method runs, and the method's output is flattened with
    ``ravel`` before being returned.

    Parameters
    ----------
    func : :obj:`callable`, optional
        Function to be decorated when no arguments are provided
    forward : :obj:`bool`, optional
        Reshape to ``dims`` if True, otherwise to ``dimsd``. If not provided,
        the decorated function's name will be inspected to infer the mode.
        Any operator having a name with 'rmat' as substring or whose name is
        'div' or '__truediv__' will reshape to ``dimsd``.
    swapaxis : :obj:`bool`, optional
        If True, swaps the last axis of the input array of the decorated
        function with ``self.axis`` (and swaps it back on the output). Only
        use if the decorated LinearOperator has ``axis`` attribute.

    Returns
    -------
    decorated : :obj:`callable`
        The decorated function when ``func`` is provided, otherwise a
        decorator to be applied to the function.

    Notes
    -----
    A ``_matvec`` (forward) function can be simplified from

    .. code-block:: python

        def _matvec(self, x):
            x = x.reshape(self.dims)
            x = x.swapaxes(self.axis, -1)
            y = do_things_to_reshaped_swapped(x)
            y = y.swapaxes(self.axis, -1)
            y = y.ravel()
            return y

    to

    .. code-block:: python

        @reshaped(swapaxis=True)
        def _matvec(self, x):
            y = do_things_to_reshaped_swapped(x)
            return y

    When the decorator has no arguments, it can be called without parenthesis,
    e.g.:

    .. code-block:: python

        @reshaped
        def _matvec(self, x):
            y = do_things_to_reshaped(x)
            return y

    """

    def decorator(f):
        # Resolve the direction once, at decoration time.
        if forward is None:
            fwd = "rmat" not in f.__name__ and f.__name__ not in (
                "div",
                "__truediv__",
            )
        else:
            fwd = forward
        inp_dims = "dims" if fwd else "dimsd"

        @wraps(f)
        def wrapper(self, x):
            x = x.reshape(getattr(self, inp_dims))
            if swapaxis:
                x = x.swapaxes(self.axis, -1)
            y = f(self, x)
            if swapaxis:
                # undo the swap so the output axes are back in operator order
                y = y.swapaxes(self.axis, -1)
            return y.ravel()

        return wrapper

    # Support both ``@reshaped`` and ``@reshaped(...)`` usages.
    if func is not None:
        return decorator(func)
    return decorator
def count(
    func: Optional[Callable] = None,
    forward: Optional[bool] = None,
    matmat: Optional[bool] = False,
) -> Callable:
    """Decorator used to count the number of forward and adjoint performed
    by an operator.

    After each call of the decorated method, one of the counters
    ``matvec_count``, ``rmatvec_count``, ``matmat_count`` or
    ``rmatmat_count`` of the instance is updated.

    Parameters
    ----------
    func : :obj:`callable`, optional
        Function to be decorated when no arguments are provided
    forward : :obj:`bool`, optional
        Whether to count the forward (``True``) or adjoint (``False``).
        If not provided, the decorated function's name will be inspected to
        infer the mode. Any operator having a name with 'rmat' as substring
        will be defaulted to False
    matmat : :obj:`bool`, optional
        Whether to count the matmat (``True``) or matvec (``False``).
        Defaults to ``False``. If explicitly set to ``None``, the decorated
        function's name will be inspected to infer the mode. Any operator
        having a name with 'matvec' as substring will be defaulted to False

    Returns
    -------
    decorated : :obj:`callable`
        The decorated function when ``func`` is provided, otherwise a
        decorator to be applied to the function.

    """

    def decorator(f):
        # Resolve both modes once, at decoration time.
        fwd = "rmat" not in f.__name__ if forward is None else forward
        mat = "matvec" not in f.__name__ if matmat is None else matmat

        @wraps(f)
        def wrapper(self, x):
            # perform operation
            y = f(self, x)

            # increase count of the associated operation
            if fwd:
                if mat:
                    self.matmat_count += 1
                    # NOTE(review): presumably offsets per-column matvec
                    # calls made while evaluating the matmat - confirm
                    # against the operator implementation.
                    self.matvec_count -= x.shape[-1]
                else:
                    self.matvec_count += 1
            else:
                if mat:
                    self.rmatmat_count += 1
                    self.rmatvec_count -= x.shape[-1]
                else:
                    self.rmatvec_count += 1
            return y

        return wrapper

    # Support both ``@count`` and ``@count(...)`` usages.
    if func is not None:
        return decorator(func)
    return decorator