# Source code for pylops.signalprocessing.dwt

__all__ = ["DWT"]

from math import ceil, log

import numpy as np

from pylops import LinearOperator
from pylops.basicoperators import Pad
from pylops.utils import deps
from pylops.utils._internal import _value_or_sized_to_tuple
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray

# Optional dependency on PyWavelets. ``pywt_message`` is expected to be None
# when ``pywt`` can be imported and an error message otherwise (inferred from
# how it is used here -- confirm against ``pylops.utils.deps``). ``pywt`` is
# only imported when available; ``DWT.__init__`` raises ModuleNotFoundError
# with this message when it is not.
pywt_message = deps.pywt_import("the dwt module")

if pywt_message is None:
    import pywt


def _checkwavelet(wavelet: str) -> None:
    """Validate ``wavelet`` against the discrete wavelets known to ``pywt``.

    Raises
    ------
    ValueError
        If ``wavelet`` is not listed by ``pywt.wavelist(kind="discrete")``.
    """
    available = pywt.wavelist(kind="discrete")
    if wavelet in available:
        return
    raise ValueError(f"'{wavelet}' not in family set = {available}")


def _adjointwavelet(wavelet: str) -> str:
    """Define adjoint wavelet"""
    waveletadj = wavelet
    if "rbio" in wavelet:
        waveletadj = "bior" + wavelet[-3:]
    elif "bior" in wavelet:
        waveletadj = "rbio" + wavelet[-3:]
    return waveletadj


class DWT(LinearOperator):
    """One dimensional Wavelet operator.

    Apply 1D-Wavelet Transform along an ``axis`` of a
    multi-dimensional array of size ``dims``.

    Note that the Wavelet operator is an overload of the ``pywt``
    implementation of the wavelet transform. Refer to
    https://pywavelets.readthedocs.io for a detailed description of the
    input parameters.

    Parameters
    ----------
    dims : :obj:`int` or :obj:`tuple`
        Number of samples for each dimension
    axis : :obj:`int`, optional
        .. versionadded:: 2.0.0

        Axis along which DWT is applied
    wavelet : :obj:`str`, optional
        Name of wavelet type. Use :func:`pywt.wavelist(kind='discrete')` for
        a list of available wavelets.
    level : :obj:`int`, optional
        Number of scaling levels (must be >=0).
    dtype : :obj:`str`, optional
        Type of elements in input array.
    name : :obj:`str`, optional
        .. versionadded:: 2.0.0

        Name of operator (to be used by :func:`pylops.utils.describe.describe`)

    Attributes
    ----------
    pad : :obj:`pylops.basicoperators.Pad`
        Padding operator used to pad the input signal at the end of ``axis``
        to the next power of 2 length (and to at least ``2**level`` samples).
    axis : :obj:`int`
        Axis along which DWT is applied.
    wavelet : :obj:`str`
        Name of the wavelet type used in forward mode.
    waveletadj : :obj:`str`
        Name of the adjoint wavelet type.
    level : :obj:`int`
        Number of scaling levels.
    sl : :obj:`list`
        List of slices to reconstruct the wavelet coefficients from the
        raveled array.
    dims : :obj:`tuple`
        Shape of the array after the adjoint, but before flattening.

        For example, ``x_reshaped = (Op.H * y.ravel()).reshape(Op.dims)``.
    dimsd : :obj:`tuple`
        Shape of the array after the forward, but before flattening.

        For example, ``y_reshaped = (Op * x.ravel()).reshape(Op.dimsd)``.
    shape : :obj:`tuple`
        Operator shape.

    Raises
    ------
    ModuleNotFoundError
        If ``pywt`` is not installed
    ValueError
        If ``wavelet`` does not belong to ``pywt.wavelist(kind='discrete')``,
        or if ``dims[axis]`` is smaller than 1

    Notes
    -----
    The Wavelet operator applies the multilevel Discrete Wavelet Transform
    (DWT) in forward mode and the multilevel Inverse Discrete Wavelet
    Transform (IDWT) in adjoint mode.

    Wavelet transforms can be used to compress signals and present
    a key advantage over Fourier transforms in that they capture both
    frequency and location information in time. Consider using this operator
    as sparsifying transform when using L1 solvers.

    """

    def __init__(
        self,
        dims: int | InputDimsLike,
        axis: int = -1,
        wavelet: str = "haar",
        level: int = 1,
        dtype: DTypeLike = "float64",
        name: str = "D",
    ) -> None:
        if pywt_message is not None:
            raise ModuleNotFoundError(pywt_message)
        _checkwavelet(wavelet)

        dims = _value_or_sized_to_tuple(dims)
        nsamples = int(dims[axis])
        if nsamples < 1:
            msg = f"dims[axis] must be at least 1, got {nsamples}"
            raise ValueError(msg)

        # Padded length along ``axis``: the next power of 2, and at least
        # ``2**level``. The power of 2 is computed with exact integer
        # arithmetic: float ``ceil(log(n, 2))`` can land just above an exact
        # integer and double the padding.
        ndimpow2 = max(1 << (nsamples - 1).bit_length(), 2**level)
        pad = [(0, 0)] * len(dims)
        pad[axis] = (0, ndimpow2 - nsamples)
        self.pad = Pad(dims, pad)
        self.axis = axis
        dimsd = list(dims)
        dimsd[self.axis] = ndimpow2
        super().__init__(dtype=np.dtype(dtype), dims=dims, dimsd=dimsd, name=name)

        # Run the transform once on a dummy array of the padded shape to get
        # the slices that split the packed coefficient array back into
        # per-level coefficients in the adjoint.
        _, self.sl = pywt.coeffs_to_array(
            pywt.wavedecn(
                np.ones(self.dimsd),
                wavelet=wavelet,
                level=level,
                mode="periodization",
                axes=(self.axis,),
            ),
            axes=(self.axis,),
        )

        self.wavelet = wavelet
        self.waveletadj = _adjointwavelet(wavelet)
        self.level = level

    def _matvec(self, x: NDArray) -> NDArray:
        # Forward: pad along ``axis``, apply the multilevel DWT and pack the
        # coefficients of all levels into a single array of shape ``dimsd``.
        x = self.pad.matvec(x)
        x = np.reshape(x, self.dimsd)
        y = pywt.coeffs_to_array(
            pywt.wavedecn(
                x,
                wavelet=self.wavelet,
                level=self.level,
                mode="periodization",
                axes=(self.axis,),
            ),
            axes=(self.axis,),
        )[0]
        return y.ravel()

    def _rmatvec(self, x: NDArray) -> NDArray:
        # Adjoint: unpack the coefficients with the slices computed in
        # ``__init__``, reconstruct with the adjoint wavelet, then apply the
        # adjoint of the padding operator to return to ``dims``.
        x = np.reshape(x, self.dimsd)
        x = pywt.array_to_coeffs(x, self.sl, output_format="wavedecn")
        y = pywt.waverecn(
            x, wavelet=self.waveletadj, mode="periodization", axes=(self.axis,)
        )
        y = self.pad.rmatvec(y.ravel())
        return y