__all__ = ["MatrixMult"]
import logging
import warnings
import numpy as np
import scipy as sp
from scipy.sparse.linalg import inv
from pylops import LinearOperator
from pylops.utils._internal import _value_or_sized_to_array
from pylops.utils.backend import get_array_module
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray
logger = logging.getLogger(__name__)
[docs]
class MatrixMult(LinearOperator):
r"""Matrix multiplication.
Simple wrapper to :py:func:`numpy.dot` and :py:func:`numpy.vdot` for
an input matrix :math:`\mathbf{A}`.
Parameters
----------
A : :obj:`numpy.ndarray` or :obj:`scipy.sparse` matrix
Matrix.
otherdims : :obj:`tuple`, optional
Number of samples for each other dimension of model
(model/data will be reshaped and ``A`` applied multiple times
to each column of the model/data).
forceflat : :obj:`bool`, optional
.. versionadded:: 2.2.0
Force an array to be flattened after matvec and rmatvec. Note that this is only
required when `otherdims=None`, otherwise pylops will detect whether to
return a 1d or nd array.
dtype : :obj:`str`, optional
Type of elements in input array.
name : :obj:`str`, optional
.. versionadded:: 2.0.0
Name of operator (to be used by :func:`pylops.utils.describe.describe`).
Attributes
----------
dimsflatten : :obj:`tuple`
Same as ``dims`` but with first dimension flattened (i.e.,
defined as the product of ``otherdims``).
dimsdflatten : :obj:`tuple`
Same as ``dimsd`` but with first dimension flattened (i.e.,
defined as the product of ``otherdims``).
reshape : :obj:`bool`
Whether to reshape the input prior to applying the matrix ``A``
(when ``otherdims`` is provided) or not (when ``otherdims=None``).
complex : :obj:`bool`
Matrix has complex numbers (``True``) or not (``False``).
dims : :obj:`tuple`
Shape of the array after the adjoint, but before flattening.
For example, ``x_reshaped = (Op.H * y.ravel()).reshape(Op.dims)``.
dimsd : :obj:`tuple`
Shape of the array after the forward, but before flattening.
For example, ``y_reshaped = (Op * x.ravel()).reshape(Op.dimsd)``.
shape : :obj:`tuple`
Operator shape.
explicit : :obj:`bool`
Operator contains a matrix that can be solved explicitly
(``True``) or not (``False``).
"""
def __init__(
self,
A: NDArray,
otherdims: int | InputDimsLike | None = None,
forceflat: bool | None = None,
dtype: DTypeLike = "float64",
name: str = "M",
) -> None:
ncp = get_array_module(A)
self.A = A
if isinstance(A, ncp.ndarray):
self.complex = np.iscomplexobj(A)
else:
self.complex = np.iscomplexobj(A.data)
if otherdims is None:
dims, dimsd = (A.shape[1],), (A.shape[0],)
self.reshape = False
explicit = True
else:
otherdims = _value_or_sized_to_array(otherdims)
self.otherdims = np.array(otherdims, dtype=int)
dims, dimsd = (
np.insert(self.otherdims, 0, self.A.shape[1]),
np.insert(self.otherdims, 0, self.A.shape[0]),
)
self.dimsflatten, self.dimsdflatten = (
np.insert([np.prod(self.otherdims)], 0, self.A.shape[1]),
np.insert([np.prod(self.otherdims)], 0, self.A.shape[0]),
)
self.reshape = True
explicit = False
# Check if forceflat is needed and set it back to None otherwise
if otherdims is not None and forceflat is not None:
logger.warning(
"Setting forceflat=None since otherdims!=None. "
"PyLops will automatically detect whether to return "
"a 1d or nd array based on the shape of the input "
"array."
)
forceflat = None
# Check dtype for correctness (upcast to complex when A is complex)
if np.iscomplexobj(A) and not np.iscomplexobj(np.ones(1, dtype=dtype)):
dtype = A.dtype
warnings.warn(
"Matrix A is a complex object, dtype cast to %s" % dtype, stacklevel=2
)
super().__init__(
dtype=np.dtype(dtype),
dims=dims,
dimsd=dimsd,
explicit=explicit,
forceflat=forceflat,
name=name,
)
def _matvec(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
if self.reshape:
x = ncp.reshape(x, self.dimsflatten)
y = self.A.dot(x)
if self.reshape:
return y.ravel()
else:
return y
def _rmatvec(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
if self.reshape:
x = ncp.reshape(x, self.dimsdflatten)
if self.complex:
y = (self.A.T.dot(x.conj())).conj()
else:
y = self.A.T.dot(x)
if self.reshape:
return y.ravel()
else:
return y
def inv(self) -> NDArray:
r"""Return the inverse of :math:`\mathbf{A}`.
Returns
-------
Ainv : :obj:`numpy.ndarray`
Inverse matrix.
"""
if sp.sparse.issparse(self.A):
Ainv = inv(self.A)
else:
ncp = get_array_module(self.A)
Ainv = ncp.linalg.inv(self.A)
return Ainv