import logging
from math import log, ceil
import numpy as np
from pylops import LinearOperator
from pylops.basicoperators import Pad
from .DWT import _checkwavelet, _adjointwavelet
try:
import pywt
except ModuleNotFoundError:
pywt = None
# Configure the root logger at import time: records are formatted as
# "<LEVELNAME>: <message>" and the threshold is set to WARNING.
logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.WARNING)
class DWT2D(LinearOperator):
    """Two dimensional Wavelet operator.

    Apply 2D-Wavelet Transform along two directions ``dirs`` of a
    multi-dimensional array of size ``dims``.

    Note that the Wavelet operator is an overload of the ``pywt``
    implementation of the wavelet transform. Refer to
    https://pywavelets.readthedocs.io for a detailed description of the
    input parameters.

    Parameters
    ----------
    dims : :obj:`tuple`
        Number of samples for each dimension
    dirs : :obj:`tuple`, optional
        Directions along which DWT2D is applied.
    wavelet : :obj:`str`, optional
        Name of wavelet type. Use :func:`pywt.wavelist(kind='discrete')` for
        a list of available wavelets.
    level : :obj:`int`, optional
        Number of scaling levels (must be >=0).
    dtype : :obj:`str`, optional
        Type of elements in input array.

    Attributes
    ----------
    shape : :obj:`tuple`
        Operator shape
    explicit : :obj:`bool`
        Operator contains a matrix that can be solved explicitly
        (True) or not (False)

    Raises
    ------
    ModuleNotFoundError
        If ``pywt`` is not installed
    ValueError
        If ``wavelet`` does not belong to ``pywt.families``, or if any
        dimension listed in ``dirs`` is not strictly positive

    Notes
    -----
    The Wavelet operator applies the 2-dimensional multilevel Discrete
    Wavelet Transform (DWT2) in forward mode and the 2-dimensional multilevel
    Inverse Discrete Wavelet Transform (IDWT2) in adjoint mode.

    Before the transform is applied, every axis in ``dirs`` is zero-padded
    at its end up to the larger of the next power of two of its length and
    ``2 ** level``.

    """
    def __init__(self, dims, dirs=(0, 1), wavelet='haar',
                 level=1, dtype='float64'):
        if pywt is None:
            raise ModuleNotFoundError('The wavelet operator requires '
                                      'the pywt package to be installed. '
                                      'Run "pip install PyWavelets" or '
                                      '"conda install pywavelets".')
        _checkwavelet(wavelet)

        # padded length along each transformed axis: next power of two of
        # the axis length, but never smaller than 2 ** level
        ndimpow2 = [max(self._nextpow2(dims[axis]), 2 ** level)
                    for axis in dirs]

        # pad only at the end of the transformed axes
        pad = [(0, 0)] * len(dims)
        for axis, npow2 in zip(dirs, ndimpow2):
            pad[axis] = (0, npow2 - dims[axis])
        self.pad = Pad(dims, pad)

        self.dims = dims
        self.dirs = dirs
        # dimensions of the padded (and transformed) array
        self.dimsd = list(dims)
        for axis, npow2 in zip(dirs, ndimpow2):
            self.dimsd[axis] = npow2

        # run the transform once on a dummy array to obtain the slices
        # needed by the adjoint to unpack a coefficient array
        _, self.sl = \
            pywt.coeffs_to_array(pywt.wavedec2(np.ones(self.dimsd),
                                               wavelet=wavelet,
                                               level=level,
                                               mode='periodization',
                                               axes=self.dirs),
                                 axes=self.dirs)
        self.wavelet = wavelet
        self.waveletadj = _adjointwavelet(wavelet)
        self.level = level
        self.shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims)))
        self.dtype = np.dtype(dtype)
        self.explicit = False

    @staticmethod
    def _nextpow2(n):
        """Return the smallest power of two greater than or equal to ``n``.

        Integer arithmetic is used instead of ``ceil(log(n, 2))`` because
        the floating-point logarithm of an exact power of two can come out
        slightly above the true exponent, which would make ``ceil`` round
        up to the next exponent and double the padded size.

        Raises
        ------
        ValueError
            If ``n`` is not strictly positive

        """
        n = int(n)
        if n < 1:
            raise ValueError('dims must contain strictly positive '
                             'values along dirs')
        return 1 << (n - 1).bit_length()

    def _matvec(self, x):
        """Forward mode: pad, reshape to ``dimsd``, apply the multilevel
        2D DWT along ``dirs`` and return the flattened coefficient array.
        """
        x = self.pad.matvec(x)
        x = np.reshape(x, self.dimsd)
        coeffs = pywt.wavedec2(x, wavelet=self.wavelet,
                               level=self.level,
                               mode='periodization',
                               axes=self.dirs)
        # coeffs_to_array returns (array, slices); only the array is needed
        y = pywt.coeffs_to_array(coeffs, axes=self.dirs)[0]
        return y.ravel()

    def _rmatvec(self, x):
        """Adjoint mode: unpack the coefficient array with the slices
        stored at construction, apply the multilevel 2D inverse DWT with
        the adjoint wavelet and undo the padding.
        """
        x = np.reshape(x, self.dimsd)
        coeffs = pywt.array_to_coeffs(x, self.sl, output_format='wavedec2')
        y = pywt.waverec2(coeffs, wavelet=self.waveletadj,
                          mode='periodization', axes=self.dirs)
        y = self.pad.rmatvec(y.ravel())
        return y