__all__ = ["Spread"]
import logging
from typing import Callable, Optional
import numpy as np
from pylops import LinearOperator
from pylops.utils import deps
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray
# Message describing why numba cannot be used. This module treats ``None`` as
# "numba is usable": the numba kernels are imported only in that case, and a
# non-``None`` value is what gets logged when the numba engine is requested.
jit_message = deps.numba_import("the spread module")

if jit_message is None:
    # Optional dependency: these names exist only when the check above passed.
    from numba import jit

    from ._spread_numba import (
        _matvec_numba_onthefly,
        _matvec_numba_table,
        _rmatvec_numba_onthefly,
        _rmatvec_numba_table,
    )

# NOTE(review): this configures the root logger at import time (message format
# and WARNING level) -- confirm this module-level side effect is intended.
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.WARNING)
class Spread(LinearOperator):
    r"""Spread operator.

    Spread values from the input model vector arranged as a 2-dimensional
    array of size :math:`[n_{x_0} \times n_{t_0}]` into the data vector of size
    :math:`[n_x \times n_t]`. Note that the value at each single pair
    :math:`(x_0, t_0)` in the input is spread over the entire :math:`x` axis
    in the output.

    Spreading is performed along parametric curves provided as look-up table
    of pre-computed indices (``table``) or computed on-the-fly using a
    function handle (``fh``).

    In adjoint mode, values from the data vector are instead stacked
    along the same parametric curves.

    Parameters
    ----------
    dims : :obj:`tuple`
        Dimensions of model vector (vector will be reshaped internally into
        a two-dimensional array of size :math:`[n_{x_0} \times n_{t_0}]`,
        where the first dimension is the spreading direction)
    dimsd : :obj:`tuple`
        Dimensions of data vector (vector will be reshaped internally into
        a two-dimensional array of size :math:`[n_x \times n_t]`,
        where the first dimension is the stacking direction)
    table : :obj:`np.ndarray`, optional
        Look-up table of indices of size
        :math:`[n_{x_0} \times n_{t_0} \times n_x]` (if ``None`` use function
        handle ``fh``). When ``dtable`` is not provided, the ``data`` will be
        created as follows

        .. code-block:: python

            data[ix, table[ix0, it0, ix]] += model[ix0, it0]

        .. note::

            When using ``table`` without ``dtable``, its elements must be
            between 0 and :math:`n_t - 1` (or ``numpy.nan``).

    dtable : :obj:`np.ndarray`, optional
        Look-up table of decimal remainders for linear interpolation of size
        :math:`[n_{x_0} \times n_{t_0} \times n_x]` (only used together with
        ``table``). When provided, the ``data`` will be created as follows

        .. code-block:: python

            data[ix, table[ix0, it0, ix]] += (1 - dtable[ix0, it0, ix]) * model[ix0, it0]
            data[ix, table[ix0, it0, ix] + 1] += dtable[ix0, it0, ix] * model[ix0, it0]

        .. note::

            When using ``table`` and ``dtable``, the elements of ``table``
            must be between 0 and :math:`n_t - 2` (or ``numpy.nan``), since
            the sample at ``index + 1`` is also accessed.

    fh : :obj:`callable`, optional
        If ``None`` will use look-up table ``table``. When provided, should be a
        function which takes indices ``ix0`` and ``it0`` and returns
        an array of size :math:`n_x` containing each respective time index.
        Alternatively, if linear interpolation is required, it should output in
        addition to the time indices, a weight for linear interpolation, to be
        used as follows

        .. code-block:: python

            data[ix, index] += (1 - dindices[ix]) * model[ix0, it0]
            data[ix, index + 1] += dindices[ix] * model[ix0, it0]

        where ``index`` refers to a time index in the first array returned by
        ``fh`` and ``dindices`` refers to the weight in the second array
        returned by ``fh``.

        .. note::

            When using ``fh`` with one output (time indices), the time indices
            must be between 0 and :math:`n_t - 1` (or ``numpy.nan``). When
            using ``fh`` with two outputs (time indices and weights), they
            must be between 0 and :math:`n_t - 2` (or ``numpy.nan``).

    interp : :obj:`bool`, optional
        Only used when ``engine='numba'`` and ``fh`` is provided. Apply linear
        interpolation (``True``) or nearest interpolation (``False``) during
        stacking/spreading along parametric curve.
        In all other cases it is inferred directly from the output of ``fh``
        or the presence of ``dtable`` (a warning is logged if the inferred
        value differs from the one provided).
    engine : :obj:`str`, optional
        Engine used for spread computation (``numpy`` or ``numba``). If
        ``numba`` is requested but cannot be imported, a warning is logged
        and the ``numpy`` engine is used instead. When ``numba`` is used with
        ``fh``, the function handle is passed to the compiled kernel
        (presumably it must be a numba-compiled function).
    dtype : :obj:`str`, optional
        Type of elements in input array.
    name : :obj:`str`, optional
        .. versionadded:: 2.0.0

        Name of operator (to be used by :func:`pylops.utils.describe.describe`)

    Attributes
    ----------
    shape : :obj:`tuple`
        Operator shape
    explicit : :obj:`bool`
        Operator contains a matrix that can be solved explicitly (``True``) or
        not (``False``)

    Raises
    ------
    KeyError
        If ``engine`` is neither ``numpy`` nor ``numba``
    NotImplementedError
        If both ``table`` and ``fh`` are not provided
    ValueError
        If both ``table`` and ``fh`` are provided, if ``table`` or ``dtable``
        have shape different from
        :math:`[n_{x_0} \times n_{t_0} \times n_x]`, or if ``table`` contains
        indices that fall outside of the time axis of the data

    Notes
    -----
    The Spread operator applies the following linear transform in forward mode
    to the model vector after reshaping it into a 2-dimensional array of size
    :math:`[n_{x_0} \times n_{t_0}]`:

    .. math::
        m(x_0, t_0) \rightarrow d(x, t=f(x_0, x, t_0)) \quad \forall x

    where :math:`f(x_0, x, t_0)` is a mapping function that returns a value
    :math:`t` given values :math:`x_0`, :math:`x`, and :math:`t_0`. Note that
    for each :math:`(x_0, t_0)` pair, spreading is done over the entire
    :math:`x` axis in the data domain.

    In adjoint mode, the model is reconstructed by means of the following
    stacking operation:

    .. math::
        m(x_0, t_0) = \int{d(x, t=f(x_0, x, t_0))} \,\mathrm{d}x

    Note that ``table`` (or ``fh``) must return integer numbers
    representing indices in the axis :math:`t`. However it is also possible to
    perform linear interpolation as part of the spreading/stacking process by
    providing the decimal part of the mapping function (:math:`t - \lfloor
    t \rfloor`) either in ``dtable`` input parameter or as second value in
    the return of ``fh`` function.

    """

    def __init__(
        self,
        dims: InputDimsLike,
        dimsd: InputDimsLike,
        table: Optional[NDArray] = None,
        dtable: Optional[NDArray] = None,
        fh: Optional[Callable] = None,
        interp: Optional[bool] = None,
        engine: str = "numpy",
        dtype: DTypeLike = "float64",
        name: str = "S",
    ) -> None:
        super().__init__(dtype=np.dtype(dtype), dims=dims, dimsd=dimsd, name=name)

        if engine not in ["numpy", "numba"]:
            raise KeyError("engine must be numpy or numba")
        if engine == "numba" and jit_message is None:
            self.engine = "numba"
        else:
            if engine == "numba":
                # numba was requested but is not usable: ``jit_message`` is
                # necessarily not None here. Do not reference ``jit``, which
                # is only defined when the numba import succeeded.
                logging.warning(jit_message)
            self.engine = "numpy"

        # axes
        self.nx0, self.nt0 = self.dims[0], self.dims[1]
        self.nx, self.nt = self.dimsd[0], self.dimsd[1]

        self.table = table
        self.dtable = dtable
        self.fh = fh

        # find out if mapping is in table or function handle
        if self.table is None and fh is None:
            raise NotImplementedError("provide either table or fh.")
        elif self.table is not None:
            if fh is not None:
                raise ValueError("provide only one of table or fh.")
            if self.table.shape != (self.nx0, self.nt0, self.nx):
                raise ValueError("table must have shape [nx0 x nt0 x nx]")
            self.usetable = True
            # valid indices are 0..nt-1 (comparisons with NaN are False, so
            # NaN entries are never rejected here)
            if np.any(self.table >= self.nt):
                raise ValueError("values in table must be smaller than nt")
        else:
            self.usetable = False

        # find out if linear interpolation has to be carried out
        self.interp = False
        if self.usetable:
            if self.dtable is not None:
                if self.dtable.shape != (self.nx0, self.nt0, self.nx):
                    raise ValueError("dtable must have shape [nx0 x nt0 x nx]")
                # interpolation also touches ``index + 1``
                if np.any(self.table >= self.nt - 1):
                    raise ValueError(
                        "values in table must be smaller than nt - 1 "
                        "when dtable is provided"
                    )
                self.interp = True
        else:
            if self.engine == "numba":
                # cannot probe ``fh`` here; rely on the user-provided flag
                # (``None`` is treated as no interpolation)
                self.interp = bool(interp)
            else:
                # two outputs (indices, weights) means interpolation; check
                # that the first item is itself an array so that a single
                # array of indices with nx == 2 is not mistaken for a pair
                probe = fh(0, 0)
                if len(probe) == 2 and np.ndim(probe[0]) > 0:
                    self.interp = True
        if interp is not None and self.interp != interp:
            logging.warning("interp has been overridden to %r.", self.interp)

    def _curve_indices(self, ix0: int, it: int):
        """Return time indices (and weights) of the curve for one model sample.

        Parameters
        ----------
        ix0 : :obj:`int`
            Index along the first axis of the model
        it : :obj:`int`
            Index along the second axis of the model

        Returns
        -------
        indices : :obj:`np.ndarray`
            Time indices for every position along the first axis of the data
            (``numpy.nan`` where the curve is not defined)
        dindices : :obj:`np.ndarray` or :obj:`None`
            Interpolation weights, ``None`` when ``self.interp`` is ``False``

        """
        dindices = None
        if self.usetable:
            indices = self.table[ix0, it]
            if self.interp:
                dindices = self.dtable[ix0, it]
        elif self.interp:
            indices, dindices = self.fh(ix0, it)
            indices, dindices = np.asarray(indices), np.asarray(dindices)
        else:
            indices = np.asarray(self.fh(ix0, it))
        return indices, dindices

    def _matvec_numpy(self, x: NDArray) -> NDArray:
        """Forward mode (spreading) with the numpy engine.

        ``x`` is indexed as a 2-dimensional array of shape ``self.dims``; an
        array of shape ``self.dimsd`` is returned.
        """
        y = np.zeros(self.dimsd, dtype=self.dtype)
        for it in range(self.dims[1]):
            for ix0 in range(self.dims[0]):
                indices, dindices = self._curve_indices(ix0, it)
                # positions along the data first axis where the curve is
                # defined; argwhere returns a column of shape (k, 1) and
                # ``indices[mask]`` has the same shape, so they pair up
                # element-wise in the indexing below
                mask = np.argwhere(~np.isnan(indices))
                if mask.size > 0:
                    indices = (indices[mask]).astype(int)
                    if not self.interp:
                        y[mask, indices] += x[ix0, it]
                    else:
                        y[mask, indices] += (1 - dindices[mask]) * x[ix0, it]
                        y[mask, indices + 1] += dindices[mask] * x[ix0, it]
        return y

    def _rmatvec_numpy(self, x: NDArray) -> NDArray:
        """Adjoint mode (stacking) with the numpy engine.

        ``x`` is indexed as a 2-dimensional array of shape ``self.dimsd``; an
        array of shape ``self.dims`` is returned.
        """
        y = np.zeros(self.dims, dtype=self.dtype)
        for it in range(self.dims[1]):
            for ix0 in range(self.dims[0]):
                indices, dindices = self._curve_indices(ix0, it)
                # same masking convention as in ``_matvec_numpy``
                mask = np.argwhere(~np.isnan(indices))
                if mask.size > 0:
                    indices = (indices[mask]).astype(int)
                    if not self.interp:
                        y[ix0, it] = np.sum(x[mask, indices])
                    else:
                        y[ix0, it] = np.sum(
                            x[mask, indices] * (1 - dindices[mask])
                        ) + np.sum(x[mask, indices + 1] * dindices[mask])
        return y

    @reshaped
    def _matvec(self, x: NDArray) -> NDArray:
        """Apply the forward operator, dispatching on ``self.engine``."""
        if self.engine == "numba":
            y = np.zeros(self.dimsd, dtype=self.dtype)
            if self.usetable:
                y = _matvec_numba_table(
                    x,
                    y,
                    self.dims,
                    self.interp,
                    self.table,
                    # when there is no dtable, ``table`` is passed in its
                    # place as a stand-in argument (``self.interp`` is False
                    # in that case)
                    self.table if self.dtable is None else self.dtable,
                )
            else:
                y = _matvec_numba_onthefly(x, y, self.dims, self.interp, self.fh)
        else:
            y = self._matvec_numpy(x)
        return y

    @reshaped
    def _rmatvec(self, x: NDArray) -> NDArray:
        """Apply the adjoint operator, dispatching on ``self.engine``."""
        if self.engine == "numba":
            y = np.zeros(self.dims, dtype=self.dtype)
            if self.usetable:
                y = _rmatvec_numba_table(
                    x,
                    y,
                    self.dims,
                    self.dimsd,
                    self.interp,
                    self.table,
                    # stand-in argument when no dtable exists (see _matvec)
                    self.table if self.dtable is None else self.dtable,
                )
            else:
                y = _rmatvec_numba_onthefly(
                    x, y, self.dims, self.dimsd, self.interp, self.fh
                )
        else:
            y = self._rmatvec_numpy(x)
        return y