# Source code for pylops.signalprocessing.DWT2D

import logging
from math import log, ceil

import numpy as np
from pylops import LinearOperator
from pylops.basicoperators import Pad
from .DWT import _checkwavelet, _adjointwavelet

    import pywt
except ModuleNotFoundError:
    pywt = None

# Configure the root logger when this module is imported: messages are
# formatted as "LEVEL: message" and the threshold is WARNING. Note that
# basicConfig does nothing if the root logger already has handlers.
logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.WARNING)

class DWT2D(LinearOperator):
    """Two dimensional Wavelet operator.

    Apply 2D-Wavelet Transform along two directions ``dirs`` of a
    multi-dimensional array of size ``dims``.

    Note that the Wavelet operator is an overload of the ``pywt``
    implementation of the wavelet transform. Refer to
    https://pywavelets.readthedocs.io/en/latest/ for a detailed
    description of the input parameters.

    Parameters
    ----------
    dims : :obj:`tuple`
        Number of samples for each dimension
    dirs : :obj:`tuple`, optional
        Direction along which DWT2D is applied.
    wavelet : :obj:`str`, optional
        Name of wavelet type. Use :func:`pywt.wavelist(kind='discrete')` for
        a list of available wavelets.
    level : :obj:`int`, optional
        Number of scaling levels (must be >=0).
    dtype : :obj:`str`, optional
        Type of elements in input array.

    Attributes
    ----------
    shape : :obj:`tuple`
        Operator shape
    explicit : :obj:`bool`
        Operator contains a matrix that can be solved explicitly
        (True) or not (False)

    Raises
    ------
    ModuleNotFoundError
        If ``pywt`` is not installed
    ValueError
        If ``wavelet`` does not belong to ``pywt.families``

    Notes
    -----
    The Wavelet operator applies the 2-dimensional multilevel Discrete
    Wavelet Transform (DWT2) in forward mode and the 2-dimensional multilevel
    Inverse Discrete Wavelet Transform (IDWT2) in adjoint mode.

    Before the transform, every axis listed in ``dirs`` is padded at its end
    (through :class:`pylops.basicoperators.Pad`) to the next power of two, and
    to at least ``2 ** level`` samples. The adjoint undoes this padding via
    the adjoint of the same ``Pad`` operator.

    """
    def __init__(self, dims, dirs=(0, 1), wavelet='haar',
                 level=1, dtype='float64'):
        if pywt is None:
            raise ModuleNotFoundError('The wavelet operator requires '
                                      'the pywt package to be installed. '
                                      'Run "pip install PyWavelets" or '
                                      '"conda install pywavelets".')
        _checkwavelet(wavelet)

        # Padded length of each transformed axis: the next power of two, but
        # never less than 2 ** level. The exponent is computed with integer
        # arithmetic ((n - 1).bit_length() == ceil(log2(n)) for n >= 1)
        # because ceil(log(n, 2)) can overshoot by one when n is an exact
        # power of two, due to floating-point rounding.
        ndimpow2 = [max(1 << (int(dims[ax]) - 1).bit_length(), 2 ** level)
                    for ax in dirs]

        # Pad only at the end of the transformed axes; other axes untouched.
        pad = [(0, 0)] * len(dims)
        self.dimsd = list(dims)
        for ax, npow2 in zip(dirs, ndimpow2):
            pad[ax] = (0, npow2 - dims[ax])
            self.dimsd[ax] = npow2
        self.pad = Pad(dims, pad)
        self.dims = dims
        self.dirs = dirs

        # Run the transform once on a dummy array of the padded size to
        # obtain the coefficient slices; they are required by
        # pywt.array_to_coeffs in the adjoint to unpack the flat array.
        _, self.sl = \
            pywt.coeffs_to_array(pywt.wavedec2(np.ones(self.dimsd),
                                               wavelet=wavelet,
                                               level=level,
                                               mode='periodization',
                                               axes=self.dirs),
                                 axes=self.dirs)
        self.wavelet = wavelet
        self.waveletadj = _adjointwavelet(wavelet)
        self.level = level
        # Rows: size of the padded (transformed) array;
        # columns: size of the original model array.
        self.shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims)))
        self.dtype = np.dtype(dtype)
        self.explicit = False

    def _matvec(self, x):
        """Forward: pad, apply the multilevel 2D DWT and flatten."""
        x = self.pad.matvec(x)
        x = np.reshape(x, self.dimsd)
        coeffs = pywt.wavedec2(x, wavelet=self.wavelet, level=self.level,
                               mode='periodization', axes=self.dirs)
        # coeffs_to_array returns (array, slices); only the array is needed
        # here since the slices were stored at construction time.
        y = pywt.coeffs_to_array(coeffs, axes=self.dirs)[0]
        return y.ravel()

    def _rmatvec(self, x):
        """Adjoint: unpack coefficients, apply the inverse 2D DWT and
        undo the padding."""
        x = np.reshape(x, self.dimsd)
        coeffs = pywt.array_to_coeffs(x, self.sl, output_format='wavedec2')
        y = pywt.waverec2(coeffs, wavelet=self.waveletadj,
                          mode='periodization', axes=self.dirs)
        y = self.pad.rmatvec(y.ravel())
        return y