__all__ = ["DWTND"]
from math import ceil, log
import numpy as np
from pylops import LinearOperator
from pylops.basicoperators import Pad
from pylops.utils import deps
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray
from .dwt import _adjointwavelet, _checkwavelet
# ``pywt`` is an optional dependency. A non-None ``pywt_message`` is later
# raised as the ``ModuleNotFoundError`` message when DWTND is instantiated,
# so the module itself stays importable without ``pywt``.
pywt_message = deps.pywt_import("the dwtnd module")
if pywt_message is None:
    import pywt
def _nextpow2(n: int) -> int:
    """Return the smallest power of 2 that is greater than or equal to ``n``.

    ``n`` must be >= 1. Integer bit arithmetic is used instead of
    ``2 ** ceil(log(n, 2))``: the floating-point logarithm is subject to
    rounding and may land slightly above the integer exponent for an exact
    power of two, in which case ``ceil`` would double the result needlessly.
    """
    return 1 << (int(n) - 1).bit_length()


class DWTND(LinearOperator):
    """N-dimensional Wavelet operator.

    Apply ND-Wavelet transform along N ``axes`` of a
    multi-dimensional array of size ``dims``.

    Note that the Wavelet operator is an overload of the ``pywt``
    implementation of the wavelet transform. Refer to
    https://pywavelets.readthedocs.io for a detailed description of the
    input parameters.

    Defaults to a 3D wavelet transform along the last three dimensions
    of the input array.

    Parameters
    ----------
    dims : :obj:`tuple`
        Number of samples for each dimension
    axes : :obj:`tuple`, optional
        Axes along which DWTND is applied
    wavelet : :obj:`str`, optional
        Name of wavelet type. Use :func:`pywt.wavelist(kind='discrete')` for
        a list of available wavelets.
    level : :obj:`int`, optional
        Number of scaling levels (must be >=0).
    dtype : :obj:`str`, optional
        Type of elements in input array.
    name : :obj:`str`, optional
        Name of operator (to be used by :func:`pylops.utils.describe.describe`)

    Attributes
    ----------
    pad : :obj:`pylops.basicoperators.Pad`
        Padding operator that pads the end of each transformed axis up to
        the next power of 2 length (and to at least ``2**level`` samples).
    waveletadj : :obj:`str`
        Name of the adjoint wavelet type.
    sl : :obj:`list`
        List of slices to reconstruct the wavelet coefficients from the
        raveled array.
    dims : :obj:`tuple`
        Shape of the array after the adjoint, but before flattening.

        For example, ``x_reshaped = (Op.H * y.ravel()).reshape(Op.dims)``.
    dimsd : :obj:`tuple`
        Shape of the array after the forward, but before flattening.

        For example, ``y_reshaped = (Op * x.ravel()).reshape(Op.dimsd)``.
    shape : :obj:`tuple`
        Operator shape.

    Raises
    ------
    ModuleNotFoundError
        If ``pywt`` is not installed
    ValueError
        If ``wavelet`` does not belong to ``pywt.families``, or if any of
        the transformed axes has fewer than one sample.

    Notes
    -----
    The Wavelet operator applies the N-dimensional multilevel Discrete
    Wavelet Transform (DWTN) in forward mode and the N-dimensional multilevel
    Inverse Discrete Wavelet Transform (IDWTN) in adjoint mode.

    In forward mode the input is first padded by ``pad``. In adjoint mode
    the inverse transform is computed with the adjoint wavelet
    ``waveletadj``, and the result is brought back to ``dims`` through the
    adjoint of ``pad``.
    """

    def __init__(
        self,
        dims: InputDimsLike,
        axes: InputDimsLike = (-3, -2, -1),
        wavelet: str = "haar",
        level: int = 1,
        dtype: DTypeLike = "float64",
        name: str = "D",
    ) -> None:
        if pywt_message is not None:
            raise ModuleNotFoundError(pywt_message)
        _checkwavelet(wavelet)

        # A power-of-2 length is computed per transformed axis; the
        # logarithm-based formula previously used here rejected lengths < 1
        # with an opaque "math domain error", so validate explicitly.
        for ax in axes:
            if dims[ax] < 1:
                raise ValueError(
                    f"dims[{ax}]={dims[ax]} must be >= 1 for the axes "
                    "transformed by DWTND"
                )

        # define padding for length to be power of 2, and at least
        # 2**level (presumably so that ``level`` decompositions fit)
        ndimpow2 = [max(_nextpow2(dims[ax]), 2**level) for ax in axes]
        pad = [(0, 0)] * len(dims)
        dimsd = list(dims)
        for ax, npow2 in zip(axes, ndimpow2):
            pad[ax] = (0, npow2 - dims[ax])
            dimsd[ax] = npow2
        self.pad = Pad(dims, pad)
        self.axes = axes
        super().__init__(dtype=np.dtype(dtype), dims=dims, dimsd=dimsd, name=name)

        self.wavelet = wavelet
        self.waveletadj = _adjointwavelet(wavelet)
        self.level = level
        # apply the transform once to a dummy array of the padded shape to
        # find out the slices needed to unpack coefficients in the adjoint
        _, self.sl = self._decompose(np.ones(self.dimsd))

    def _decompose(self, x: NDArray):
        """Multilevel forward DWT of ``x`` packed into a single array.

        Returns the ``(array, slices)`` pair produced by
        ``pywt.coeffs_to_array`` for a periodized ``wavedecn`` along
        ``self.axes``.
        """
        coeffs = pywt.wavedecn(
            x,
            wavelet=self.wavelet,
            level=self.level,
            mode="periodization",
            axes=self.axes,
        )
        return pywt.coeffs_to_array(coeffs, axes=self.axes)

    def _matvec(self, x: NDArray) -> NDArray:
        x = self.pad.matvec(x)
        x = np.reshape(x, self.dimsd)
        y, _ = self._decompose(x)
        return y.ravel()

    def _rmatvec(self, x: NDArray) -> NDArray:
        x = np.reshape(x, self.dimsd)
        # unpack the flat coefficient array using the slices cached at init
        x = pywt.array_to_coeffs(x, self.sl, output_format="wavedecn")
        y = pywt.waverecn(
            x, wavelet=self.waveletadj, mode="periodization", axes=self.axes
        )
        y = self.pad.rmatvec(y.ravel())
        return y