Source code for pylops.basicoperators.Spread

import logging
import numpy as np
from pylops import LinearOperator

# Optional numba support: try to import numba and the numba kernels used by
# Spread. On any failure ``jit`` is set to None and ``jit_message`` holds the
# text that Spread.__init__ logs when a user asks for engine='numba'.
try:
    from numba import jit
    from ._Spread_numba import _matvec_numba_table, _rmatvec_numba_table, \
        _matvec_numba_onthefly, _rmatvec_numba_onthefly
except ModuleNotFoundError:
    # a required module could not be found: fall back to numpy
    jit = None
    jit_message = 'Numba not available, reverting to numpy.'
except Exception as e:
    # the import itself raised some other error: report it and fall back
    jit = None
    jit_message = 'Failed to import numba (error:%s), use numpy.' % e

# Configure logging at import time: "LEVEL: message" format, WARNING level.
logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.WARNING)


class Spread(LinearOperator):
    r"""Spread operator.

    Spread values from the input model vector arranged as a 2-dimensional
    array of size :math:`[n_{x0} \times n_{t0}]` into the data vector of size
    :math:`[n_x \times n_t]`. Spreading is performed along parametric curves
    provided as look-up table of pre-computed indices (``table``)
    or computed on-the-fly using a function handle (``fh``).

    In adjoint mode, values from the data vector are instead stacked
    along the same parametric curves.

    Parameters
    ----------
    dims : :obj:`tuple`
        Dimensions of model vector (vector will be reshaped internally into
        a two-dimensional array of size :math:`[n_{x0} \times n_{t0}]`,
        where the first dimension is the spreading/stacking direction)
    dimsd : :obj:`tuple`
        Dimensions of data vector (vector will be reshaped internally into
        a two-dimensional array of size :math:`[n_x \times n_t]`)
    table : :obj:`np.ndarray`, optional
        Look-up table of indices of size
        :math:`[n_{x0} \times n_{t0} \times n_x]` (if ``None`` use function
        handle ``fh``). Entries equal to ``NaN`` are skipped by the
        ``numpy`` engine.
    dtable : :obj:`np.ndarray`, optional
        Look-up table of decimals remainders for linear interpolation of size
        :math:`[n_{x0} \times n_{t0} \times n_x]` (if ``None`` no linear
        interpolation is applied when using ``table``)
    fh : :obj:`callable`, optional
        Function handle called as ``fh(ix0, it)`` that returns an index (and
        a fractional value in case of ``interp=True``) to be used for
        spreading/stacking given indices in :math:`x0` and :math:`t` axes
        (if ``None`` use look-up table ``table``)
    interp : :obj:`bool`, optional
        Apply linear interpolation (``True``) or nearest interpolation
        (``False``) during stacking/spreading along parametric curve. Only
        used when ``fh`` is provided and ``engine='numba'``; for
        ``engine='numpy'`` it is inferred from the number of outputs of
        ``fh(0, 0)``, and for ``table`` from the presence of ``dtable``
    engine : :obj:`str`, optional
        Engine used for computation (``numpy`` or ``numba``). If ``numba``
        is requested but could not be imported, a warning is logged and
        ``numpy`` is used instead
    dtype : :obj:`str`, optional
        Type of elements in input array.

    Attributes
    ----------
    shape : :obj:`tuple`
        Operator shape
    explicit : :obj:`bool`
        Operator contains a matrix that can be solved explicitly
        (``True``) or not (``False``)

    Raises
    ------
    KeyError
        If ``engine`` is neither ``numpy`` nor ``numba``
    NotImplementedError
        If both ``table`` and ``fh`` are not provided
    ValueError
        If ``table`` (or ``dtable``) has shape different from
        :math:`[n_{x0} \times n_{t0} \times n_x]`, or if ``table`` contains
        values that are not smaller than :math:`n_t`

    Notes
    -----
    The Spread operator applies the following linear transform in forward mode
    to the model vector after reshaping it into a 2-dimensional array of size
    :math:`[n_{x0} \times n_{t0}]`:

    .. math::
        m(x0, t_0) \rightarrow d(x, t=f(x0, x, t_0))

    where :math:`f(x0, x, t)` is a mapping function that returns a value t
    given values :math:`x0`, :math:`x`, and  :math:`t_0`.

    In adjoint mode, the model is reconstructed by means of the
    following stacking operation:

    .. math::
        m(x0, t_0) = \int{d(x, t=f(x0, x, t_0))} dx

    Note that ``table`` (or ``fh``)  must return integer numbers
    representing indices in the axis :math:`t`. However it is also possible
    to perform linear interpolation as part of the spreading/stacking process
    by providing the decimal part of the mapping function
    (:math:`t - \lfloor t \rfloor`) either in ``dtable`` input parameter or
    as second value in the return of ``fh`` function.

    """
    def __init__(self, dims, dimsd, table=None, dtable=None,
                 fh=None, interp=False, engine='numpy', dtype='float64'):
        if engine not in ('numpy', 'numba'):
            raise KeyError('engine must be numpy or numba')
        if engine == 'numba' and jit is not None:
            self.engine = 'numba'
        else:
            # numba was requested but is unusable: warn and fall back
            if engine == 'numba' and jit is None:
                logging.warning(jit_message)
            self.engine = 'numpy'

        # axes
        self.dims, self.dimsd = dims, dimsd
        self.nx0, self.nt0 = self.dims[0], self.dims[1]
        self.nx, self.nt = self.dimsd[0], self.dimsd[1]
        self.table = table
        self.dtable = dtable
        self.fh = fh

        # find out if mapping is in table or function handle
        if table is None and fh is None:
            raise NotImplementedError('provide either table or fh...')
        elif table is not None:
            if self.table.shape != (self.nx0, self.nt0, self.nx):
                raise ValueError('table must have shape [nx0 x nt0 x nx]')
            self.usetable = True
            # valid indices along t are 0..nt-1, so nt itself is already
            # out of bounds (NaN entries compare False and are accepted)
            if np.any(self.table >= self.nt):
                raise ValueError('values in table must be smaller than nt')
        else:
            self.usetable = False

        # find out if linear interpolation has to be carried out
        self.interp = False
        if self.usetable:
            if dtable is not None:
                if self.dtable.shape != (self.nx0, self.nt0, self.nx):
                    raise ValueError('dtable must have shape '
                                     '[nx0 x nt0 x nx]')
                self.interp = True
        else:
            if self.engine == 'numba':
                # cannot probe fh here, rely on the user-provided flag
                self.interp = interp
            else:
                # fh returning a pair (indices, dindices) means interpolation
                if len(fh(0, 0)) == 2:
                    self.interp = True
        self.shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims)))
        self.dtype = np.dtype(dtype)
        self.explicit = False

    def _spread_indices(self, ix0, it):
        """Return ``(indices, dindices)`` of the curve for model sample
        ``(ix0, it)``, read from the tables or computed with ``fh``;
        ``dindices`` is ``None`` when no interpolation is applied.
        """
        dindices = None
        if self.usetable:
            indices = self.table[ix0, it]
            if self.interp:
                dindices = self.dtable[ix0, it]
        elif self.interp:
            indices, dindices = self.fh(ix0, it)
        else:
            indices = self.fh(ix0, it)
        return indices, dindices

    def _matvec_numpy(self, x):
        """Forward (spreading) pass with numpy: add every model sample
        to the data along its curve and return the flattened data.
        """
        x = x.reshape(self.dims)
        y = np.zeros(self.dimsd, dtype=self.dtype)
        for it in range(self.dims[1]):
            for ix0 in range(self.dims[0]):
                indices, dindices = self._spread_indices(ix0, it)
                # positions along x whose index is not NaN
                mask = np.argwhere(~np.isnan(indices))
                if mask.size > 0:
                    # builtin int: the np.int alias no longer exists in
                    # recent NumPy releases
                    indices = (indices[mask]).astype(int)
                    if not self.interp:
                        y[mask, indices] += x[ix0, it]
                    else:
                        # split the value between the two neighbouring
                        # samples with linear weights
                        y[mask, indices] += (1 - dindices[mask]) * x[ix0, it]
                        y[mask, indices + 1] += dindices[mask] * x[ix0, it]
        return y.ravel()

    def _rmatvec_numpy(self, x):
        """Adjoint (stacking) pass with numpy: sum the data along the curve
        of every model sample and return the flattened model.
        """
        x = x.reshape(self.dimsd)
        y = np.zeros(self.dims, dtype=self.dtype)
        for it in range(self.dims[1]):
            for ix0 in range(self.dims[0]):
                indices, dindices = self._spread_indices(ix0, it)
                # positions along x whose index is not NaN
                mask = np.argwhere(~np.isnan(indices))
                if mask.size > 0:
                    indices = (indices[mask]).astype(int)
                    if not self.interp:
                        y[ix0, it] = np.sum(x[mask, indices])
                    else:
                        y[ix0, it] = \
                            np.sum(x[mask, indices] * (1 - dindices[mask])) + \
                            np.sum(x[mask, indices + 1] * dindices[mask])
        return y.ravel()

    def _matvec(self, x):
        """Apply the forward operator with the selected engine."""
        if self.engine == 'numba':
            y = np.zeros(self.dimsd, dtype=self.dtype)
            if self.usetable:
                # when dtable is missing, table is passed in its place as
                # the last argument
                y = _matvec_numba_table(x, y, self.dims, self.interp,
                                        self.table,
                                        self.table if self.dtable is None
                                        else self.dtable)
            else:
                y = _matvec_numba_onthefly(x, y, self.dims, self.interp,
                                           self.fh)
        else:
            y = self._matvec_numpy(x)
        return y

    def _rmatvec(self, x):
        """Apply the adjoint operator with the selected engine."""
        if self.engine == 'numba':
            y = np.zeros(self.dims, dtype=self.dtype)
            if self.usetable:
                # when dtable is missing, table is passed in its place as
                # the last argument
                y = _rmatvec_numba_table(x, y, self.dims, self.dimsd,
                                         self.interp, self.table,
                                         self.table if self.dtable is None
                                         else self.dtable)
            else:
                y = _rmatvec_numba_onthefly(x, y, self.dims, self.dimsd,
                                            self.interp, self.fh)
        else:
            y = self._rmatvec_numpy(x)
        return y