# Source code for pylops.medical.mri

# Public names exported by a star-import of this module.
__all__ = ["MRI2D"]

import warnings

import numpy as np

from pylops import LinearOperator
from pylops.basicoperators import Diagonal, Restriction
from pylops.signalprocessing import FFT2D, Bilinear
from pylops.utils.backend import get_module
from pylops.utils.typing import (
    DTypeLike,
    InputDimsLike,
    NDArray,
    Tfftengine_nsm,
    Tmriengine,
    Tmrimask,
)


class MRI2D(LinearOperator):
    r"""2D Magnetic Resonance Imaging

    Apply 2D Magnetic Resonance Imaging operator to obtain a k-space data
    (i.e., undersampled Fourier representation of the model).

    Parameters
    ----------
    dims : :obj:`list` or :obj:`int`
        Number of samples for each dimension. Must be 2-dimensional and of
        size :math:`n_y \times n_x`
    mask : :obj:`str` or :obj:`numpy.ndarray`
        Mask to be applied in the Fourier domain:

        - :obj:`numpy.ndarray`: a 2-dimensional array of size
          :math:`n_y \times n_x` with 1 in the selected locations;
        - ``vertical-reg``: mask with vertical lines (regularly sampled
          around the second dimension);
        - ``vertical-uni``: mask with vertical lines (irregularly sampled
          around the second dimension, with lines drawn from a uniform
          distribution);
        - ``radial-reg``: mask with radial lines (regularly sampled around
          the :math:`-\pi/\pi` angles);
        - ``radial-uni``: mask with radial lines (irregularly sampled around
          the :math:`-\pi/\pi` angles, with angles drawn from a uniform
          distribution).

    nlines : :obj:`int`, optional
        Number of lines in the k-space (must be positive). Not required if
        ``mask`` is passed as array.
    perc_center : :obj:`float`, optional
        Percentage of total lines to retain in the center (only used by the
        ``vertical-*`` masks, and must be 0 for ``vertical-reg``). Not
        required if ``mask`` is passed as array.
    engine : :obj:`str`, optional
        Engine used for computation (``numpy`` or ``jax``).
    fft_engine : :obj:`str`, optional
        Engine used for fft computation (``numpy`` or ``scipy`` or
        ``mkl_fft``). Forced to ``numpy`` when ``engine="jax"``.
    dtype : :obj:`str`, optional
        Type of elements in input array.
    name : :obj:`str`, optional
        Name of operator (to be used by :func:`pylops.utils.describe.describe`)
    **kwargs_fft
        Arbitrary keyword arguments to be passed to the selected fft method

    Attributes
    ----------
    mask : :obj:`numpy.ndarray`
        Mask applied in the Fourier domain.
    ROp : :obj:`pylops.Restriction` or :obj:`pylops.Diagonal` or :obj:`pylops.signalprocessing.Bilinear`
        Operator that applies the mask in the Fourier domain.
    dims : :obj:`tuple`
        Shape of the array after the adjoint, but before flattening.

        For example, ``x_reshaped = (Op.H * y.ravel()).reshape(Op.dims)``.
    dimsd : :obj:`tuple`
        Shape of the array after the forward, but before flattening.

        For example, ``y_reshaped = (Op * x.ravel()).reshape(Op.dimsd)``.
    shape : :obj:`tuple`
        Operator shape
    explicit : :obj:`bool`
        Operator contains a matrix that can be solved explicitly (``True``) or
        not (``False``)

    Raises
    ------
    ValueError
        If ``mask`` is not one of the accepted strings or a numpy array.
    ValueError
        If ``fft_engine`` is neither ``numpy``, ``scipy``, nor ``mkl_fft``.
    ValueError
        If ``nlines`` or ``perc_center`` are not specified when providing
        ``mask`` as string.
    ValueError
        If ``nlines`` is not positive when providing ``mask`` as string.
    ValueError
        If ``perc_center`` is greater than 0 when using ``vertical-reg`` mask.
    ValueError
        If, for ``vertical-*`` masks, ``nlines`` plus the number of central
        lines exceeds the number of lines of the k-space.

    Notes
    -----
    The MRI2D operator applies 2-dimensional Fourier transform to the model,
    followed by a subsampling with a given ``mask``:

    .. math::
        \mathbf{d} = \mathbf{R} \mathbf{F}_{k} \mathbf{m}

    where :math:`\mathbf{F}_{k}` is the 2-dimensional Fourier transform and
    :math:`\mathbf{R}` is the mask.

    """

    def __init__(
        self,
        dims: InputDimsLike,
        mask: Tmrimask | NDArray,
        nlines: int | None = None,
        perc_center: float | None = 0.1,
        engine: Tmriengine = "numpy",
        fft_engine: Tfftengine_nsm = "numpy",
        dtype: DTypeLike = "complex128",
        name: str = "M",
        **kwargs_fft,
    ) -> None:
        # "mask" identifies a user-provided array; otherwise the mask string
        self._mask_type = mask if isinstance(mask, str) else "mask"
        self.engine = engine
        self.fft_engine = fft_engine

        # Validate inputs
        if engine == "jax" and fft_engine != "numpy":
            warnings.warn(
                "When engine='jax', fft_engine is forced to 'numpy'", stacklevel=2
            )
            self.fft_engine = "numpy"
        if isinstance(mask, str) and mask not in (
            "vertical-reg",
            "vertical-uni",
            "radial-reg",
            "radial-uni",
        ):
            msg = (
                "`mask` must be a numpy array, 'vertical-reg', 'vertical-uni', "
                f"'radial-reg', or 'radial-uni', got {mask}"
            )
            raise ValueError(msg)
        if self.fft_engine not in ["numpy", "scipy", "mkl_fft"]:
            msg = "`fft_engine` must be 'numpy', 'scipy', or 'mkl_fft'"
            raise ValueError(msg)
        if isinstance(mask, str) and (nlines is None or perc_center is None):
            msg = (
                "`nlines` and `perc_center` must be specified providing mask as string"
            )
            raise ValueError(msg)
        # Zero/negative nlines would otherwise fail deep inside the mask
        # builders (division by zero, or concatenation of no lines)
        if isinstance(mask, str) and nlines <= 0:
            msg = f"`nlines` must be a positive integer, got {nlines}"
            raise ValueError(msg)
        if isinstance(mask, str) and mask == "vertical-reg" and perc_center > 0.0:
            msg = "`perc_center` must be 0.0 when using `mask=vertical-reg`"
            raise ValueError(msg)

        # Create mask ("reg" masks are regular, "uni" masks are random)
        self.mask: NDArray
        if self._mask_type == "mask":
            self.mask = mask
        elif "vertical" in self._mask_type:
            self.mask = self._vertical_mask(
                dims,
                nlines,
                perc_center,
                uniform="reg" in self._mask_type,
            )
        elif "radial" in self._mask_type:
            self.mask = self._radial_mask(
                dims, nlines, uniform="reg" in self._mask_type
            )

        # Convert mask to appropriate backend
        ncp = get_module(self.engine)
        self.mask = ncp.asarray(self.mask)

        # Create operator
        self.ROp, Op = self._calc_op(
            dims=dims,
            mask_type=self._mask_type,
            mask=self.mask,
            fft_engine=self.fft_engine,
            dtype=dtype,
            **kwargs_fft,
        )
        super().__init__(Op=Op, name=name)

    @staticmethod
    def _vertical_mask(
        dims: InputDimsLike, nlines: int, perc_center: float, uniform: bool = True
    ) -> NDArray:
        """Create vertical mask

        Select the indices along the second axis (``dims[1]``) of the k-space
        lines to retain.

        Parameters
        ----------
        dims : :obj:`list` or :obj:`tuple`
            Size of the k-space (only ``dims[1]`` is used).
        nlines : :obj:`int`
            Number of lines to select in addition to the central ones.
        perc_center : :obj:`float`
            Fraction of ``dims[1]`` taken as contiguous lines around the
            center of the k-space.
        uniform : :obj:`bool`, optional
            If ``True`` lines are regularly spaced, otherwise they are drawn
            at random (without replacement).

        Returns
        -------
        mask : :obj:`numpy.ndarray`
            Sorted 1D integer array of selected line indices.

        Raises
        ------
        ValueError
            If ``nlines`` plus the central lines exceed ``dims[1]``.

        """
        nlines_center = int(perc_center * dims[1])
        if (nlines + nlines_center) > dims[1]:
            msg = (
                "`nlines` and `perc_center` produce a number of lines "
                "greater than the total number of lines of the k-space"
                f"({nlines + nlines_center}>{dims[1]})"
            )
            raise ValueError(msg)

        if nlines_center == 0:
            # No lines from the center
            if uniform:
                step = dims[1] // nlines
                mask = np.arange(0, dims[1], step)[:nlines]
            else:
                rng = np.random.default_rng()
                # sorted for consistency with the other branches
                mask = np.sort(rng.choice(np.arange(dims[1]), nlines, replace=False))
        else:
            # Lines taken from the center
            istart_center = dims[1] // 2 - nlines_center // 2
            iend_center = dims[1] // 2 + nlines_center // 2 + (nlines_center % 2)
            ilines_center = np.arange(istart_center, iend_center)

            # Other lines
            if uniform:
                # left side gets the extra line when nlines is odd
                nlines_left = nlines // 2 + nlines % 2
                step_left = istart_center // nlines_left
                ilines_left = np.arange(0, istart_center, step_left)[:nlines_left]

                # right side is spread with its own line count (nlines_right
                # may be 0 when nlines == 1)
                nlines_right = nlines // 2
                if nlines_right > 0:
                    step_right = (dims[1] - iend_center) // nlines_right
                    ilines_right = np.arange(iend_center, dims[1], step_right)[
                        :nlines_right
                    ]
                else:
                    ilines_right = np.array([], dtype=ilines_left.dtype)
                mask = np.sort(np.hstack((ilines_left, ilines_center, ilines_right)))
            else:
                rng = np.random.default_rng()
                ilines_other = np.hstack(
                    (np.arange(0, istart_center), np.arange(iend_center, dims[1]))
                )
                ilines_other = rng.choice(ilines_other, nlines, replace=False)
                mask = np.sort(np.hstack((ilines_center, ilines_other)))
        return mask

    @staticmethod
    def _radial_mask(dims: InputDimsLike, nlines: int, uniform: bool = True) -> NDArray:
        """Create radial mask

        Build ``nlines`` straight lines passing through the center of the
        k-space and return the (fractional) coordinates of their points.

        Parameters
        ----------
        dims : :obj:`list` or :obj:`tuple`
            Size of the k-space.
        nlines : :obj:`int`
            Number of radial lines.
        uniform : :obj:`bool`, optional
            If ``True`` angles are regularly spaced in :math:`[0, \\pi)`,
            otherwise drawn uniformly from :math:`[-\\pi, \\pi)`.

        Returns
        -------
        mask : :obj:`numpy.ndarray`
            Array of shape ``(2, npoints)`` with unique coordinates; the
            first row is offset by ``dims[0] // 2`` and the second by
            ``dims[1] // 2``.

        """
        npoints_per_line = dims[1] - 1

        # Define angles
        if uniform:
            thetas = np.linspace(0, np.pi, nlines, endpoint=False)
        else:
            rng = np.random.default_rng()
            thetas = rng.uniform(-np.pi, np.pi, nlines)

        # Create lines
        # NOTE(review): the bounds below mix dims[0] and dims[1]; for
        # non-square dims the lines may not span the whole k-space - confirm.
        lines = []
        for theta in thetas:
            tan_theta = np.tan(theta)
            # Tolerance-based check: linspace-generated angles may differ from
            # pi/2 by round-off, which an exact equality would miss
            if np.isclose(np.abs(theta), np.pi / 2, rtol=0.0, atol=1e-12):
                # Create vertical line (constant first coordinate)
                xline = np.zeros(npoints_per_line)
                yline = np.linspace(
                    -dims[1] // 2 + 1,
                    dims[1] // 2 - 1,
                    npoints_per_line,
                    endpoint=True,
                )
            elif tan_theta >= 0:
                # Create lines for positive angles; for tan == 0 the second
                # bound is infinite, so skip it instead of dividing by zero
                xmax = dims[1] // 2
                if tan_theta > 0:
                    xmax = min(xmax, (dims[0] // 2) / tan_theta)
                xline = np.linspace(
                    -xmax,
                    min(xmax, dims[0] // 2 - 1 - (dims[0] + 1) % 2),
                    npoints_per_line,
                    endpoint=True,
                )
                yline = tan_theta * xline
            else:
                # Create lines for negative angles
                xmin = max(-dims[1] // 2 + 1, (dims[0] // 2) / tan_theta)
                xline = np.linspace(
                    xmin, min(-xmin, dims[0] // 2 - 1), npoints_per_line, endpoint=True
                )
                yline = tan_theta * xline
            xline, yline = xline + dims[0] // 2, yline + dims[1] // 2
            lines.append(np.vstack((xline, yline)))
        mask = np.concatenate(lines, axis=1)

        # Snap round-off negatives (e.g. -1e-15) onto the domain edge, then
        # remove points outside the domain allowed by the Bilinear operator
        # (both lower and upper bounds) and duplicate points
        mask[(mask < 0) & (mask > -1e-10)] = 0.0
        mask = mask[:, (mask[0] >= 0) & (mask[0] < dims[0] - 1)]
        mask = mask[:, (mask[1] >= 0) & (mask[1] < dims[1] - 1)]
        mask = np.unique(mask, axis=1)
        return mask

    def _matvec(self, x: NDArray) -> NDArray:
        return super()._matvec(x)

    def _rmatvec(self, x: NDArray) -> NDArray:
        return super()._rmatvec(x)

    @staticmethod
    def _calc_op(
        dims: InputDimsLike,
        mask_type: str,
        mask: NDArray,
        fft_engine: Tfftengine_nsm,
        dtype: DTypeLike,
        **kwargs_fft,
    ):
        """Calculate MRI operator

        Returns the masking operator (``Diagonal`` for array masks,
        ``Restriction`` along the last axis for vertical masks, ``Bilinear``
        for radial masks) and its composition with a centered
        (``fftshift_after=True``) 2D FFT.
        """
        fop = FFT2D(
            dims,
            nffts=dims,
            fftshift_after=True,
            engine=fft_engine,
            dtype=dtype,
            **kwargs_fft,
        )
        if mask_type == "mask":
            rop = Diagonal(mask, dtype=dtype)
        elif "vertical" in mask_type:
            rop = Restriction(dims, mask, axis=-1, forceflat=True, dtype=dtype)
        elif "radial" in mask_type:
            rop = Bilinear(mask, dims, dtype=dtype)
        return rop, rop @ fop