GPU / TPU Support

Overview

From v1.12.0, PyLops supports computations on GPUs powered by CuPy (cupy-cudaXX>=13.0.0). This library must be installed before PyLops is installed.

From v2.3.0, PyLops also supports computations on GPUs/TPUs powered by JAX. This library must be installed before PyLops is installed.

Note

Set the environment variables CUPY_PYLOPS=0 and/or JAX_PYLOPS=0 to force PyLops to ignore the cupy and jax backends. This can also be used if an older version of cupy or jax is installed on your system, in which case you would otherwise get an error when importing PyLops.
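As a minimal sketch, the two variables can also be set from Python rather than from the shell. This assumes they are read when PyLops is first imported, so they must be set before that import; the try/except guard is only there so the snippet also runs where PyLops is not installed.

```python
import os

# Must happen before the first `import pylops` in the process
os.environ["CUPY_PYLOPS"] = "0"  # ignore the CuPy backend
os.environ["JAX_PYLOPS"] = "0"   # ignore the JAX backend

try:
    import pylops  # imported with both optional backends disabled
except ImportError:
    pylops = None  # PyLops is not installed in this environment
```

Setting the variables in the shell (for example `export CUPY_PYLOPS=0`) before launching Python has the same effect and avoids any dependence on import order.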

Apart from a few exceptions, all operators and solvers in PyLops can seamlessly work with numpy arrays on CPU as well as with cupy/jax arrays on GPU. For CuPy, users simply need to create operators and provide data vectors to the solvers consistently: for example, when using pylops.MatrixMult, the input matrix must be a cupy array if the data provided to a solver is also a cupy array. For JAX, in addition to following the same procedure described for CuPy, the PyLops operator must also be wrapped in a pylops.JaxOperator.
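The consistency requirement can be checked up front. The sketch below is not part of PyLops: `array_backend` and `check_same_backend` are hypothetical helpers that compare the package defining each array type, shown here with numpy arrays only.

```python
import numpy as np


def array_backend(a):
    """Top-level package that defines the type of `a` (hypothetical helper)."""
    return type(a).__module__.split(".")[0]


def check_same_backend(**named_arrays):
    """Raise if the given arrays do not all come from the same array library."""
    backends = {name: array_backend(a) for name, a in named_arrays.items()}
    if len(set(backends.values())) > 1:
        raise TypeError(f"Mixed array backends: {backends}")
    return next(iter(backends.values()))


G = np.ones((4, 3), dtype=np.float32)   # matrix handed to the operator
y = np.ones(4, dtype=np.float32)        # data handed to the solver
backend = check_same_backend(G=G, y=y)  # both are numpy arrays here
```

The reported name is whichever package implements the array type, which need not match the import name, so the helper is only meant to detect mixing and not to identify a backend.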

In the following, we provide a list of methods in pylops.LinearOperator with their current status (available on CPU, GPU with CuPy, and GPU with JAX):

| Operator/method | CPU | GPU with CuPy | GPU/TPU with JAX |
| --- | --- | --- | --- |
| pylops.LinearOperator.cond | ✅ | 🔴 | 🔴 |
| pylops.LinearOperator.conj | ✅ | ✅ | ✅ |
| pylops.LinearOperator.div | ✅ | ✅ | ✅ |
| pylops.LinearOperator.eigs | ✅ | 🔴 | 🔴 |
| pylops.LinearOperator.todense | ✅ | ✅ | ✅ |
| pylops.LinearOperator.tosparse | ✅ | 🔴 | 🔴 |
| pylops.LinearOperator.trace | ✅ | 🔴 | 🔴 |

Similarly, we provide a list of operators with their current status.

Basic operators:

| Operator/method | CPU | GPU with CuPy | GPU/TPU with JAX |
| --- | --- | --- | --- |
| pylops.basicoperators.MatrixMult | ✅ | ✅ | ✅ |
| pylops.basicoperators.Identity | ✅ | ✅ | ✅ |
| pylops.basicoperators.Zero | ✅ | ✅ | ✅ |
| pylops.basicoperators.Diagonal | ✅ | ✅ | ✅ |
| pylops.basicoperators.Transpose | ✅ | ✅ | ✅ |
| pylops.basicoperators.Flip | ✅ | ✅ | ✅ |
| pylops.basicoperators.Roll | ✅ | ✅ | ✅ |
| pylops.basicoperators.Pad | ✅ | ✅ | ✅ |
| pylops.basicoperators.Sum | ✅ | ✅ | ✅ |
| pylops.basicoperators.Symmetrize | ✅ | ✅ | ✅ |
| pylops.basicoperators.Restriction | ✅ | ✅ | ✅ |
| pylops.basicoperators.Regression | ✅ | ✅ | ✅ |
| pylops.basicoperators.LinearRegression | ✅ | ✅ | ✅ |
| pylops.basicoperators.CausalIntegration | ✅ | ✅ | ✅ |
| pylops.basicoperators.Spread | ✅ | 🔴 | 🔴 |
| pylops.basicoperators.VStack | ✅ | ✅ | ✅ |
| pylops.basicoperators.HStack | ✅ | ✅ | ✅ |
| pylops.basicoperators.Block | ✅ | ✅ | ✅ |
| pylops.basicoperators.BlockDiag | ✅ | ✅ | ✅ |

Smoothing and derivatives:

| Operator/method | CPU | GPU with CuPy | GPU/TPU with JAX |
| --- | --- | --- | --- |
| pylops.basicoperators.FirstDerivative | ✅ | ✅ | ✅ |
| pylops.basicoperators.SecondDerivative | ✅ | ✅ | ✅ |
| pylops.basicoperators.Laplacian | ✅ | ✅ | ✅ |
| pylops.basicoperators.Gradient | ✅ | ✅ | ✅ |
| pylops.basicoperators.FirstDirectionalDerivative | ✅ | ✅ | ✅ |
| pylops.basicoperators.SecondDirectionalDerivative | ✅ | ✅ | ✅ |

Signal processing:

| Operator/method | CPU | GPU with CuPy | GPU/TPU with JAX |
| --- | --- | --- | --- |
| pylops.signalprocessing.Convolve1D | ✅ | ✅ | ⚠️ |
| pylops.signalprocessing.Convolve2D | ✅ | ✅ | ✅ |
| pylops.signalprocessing.ConvolveND | ✅ | ✅ | ✅ |
| pylops.signalprocessing.NonStationaryConvolve1D | ✅ | ✅ | ✅ |
| pylops.signalprocessing.NonStationaryFilters1D | ✅ | ✅ | ✅ |
| pylops.signalprocessing.NonStationaryConvolve2D | ✅ | ✅ | 🔴 |
| pylops.signalprocessing.NonStationaryFilters2D | ✅ | ✅ | 🔴 |
| pylops.signalprocessing.Interp | ✅ | ✅ | ✅ |
| pylops.signalprocessing.Bilinear | ✅ | ✅ | 🔴 |
| pylops.signalprocessing.FFT | ✅ | ✅ | ✅ |
| pylops.signalprocessing.FFT2D | ✅ | ✅ | ✅ |
| pylops.signalprocessing.FFTND | ✅ | ✅ | ✅ |
| pylops.signalprocessing.Shift | ✅ | ✅ | ✅ |
| pylops.signalprocessing.DWT | ✅ | 🔴 | 🔴 |
| pylops.signalprocessing.DWT2D | ✅ | 🔴 | 🔴 |
| pylops.signalprocessing.DCT | ✅ | 🔴 | 🔴 |
| pylops.signalprocessing.Seislet | ✅ | 🔴 | 🔴 |
| pylops.signalprocessing.Radon2D | ✅ | 🔴 | 🔴 |
| pylops.signalprocessing.Radon3D | ✅ | 🔴 | 🔴 |
| pylops.signalprocessing.ChirpRadon2D | ✅ | ✅ | 🔴 |
| pylops.signalprocessing.ChirpRadon3D | ✅ | ✅ | 🔴 |
| pylops.signalprocessing.Sliding1D | ✅ | ✅ | 🔴 |
| pylops.signalprocessing.Sliding2D | ✅ | ✅ | 🔴 |
| pylops.signalprocessing.Sliding3D | ✅ | ✅ | 🔴 |
| pylops.signalprocessing.Patch2D | ✅ | ✅ | 🔴 |
| pylops.signalprocessing.Patch3D | ✅ | ✅ | 🔴 |
| pylops.signalprocessing.Fredholm1 | ✅ | ✅ | ✅ |

Wave-Equation processing:

| Operator/method | CPU | GPU with CuPy | GPU/TPU with JAX |
| --- | --- | --- | --- |
| pylops.waveeqprocessing.PressureToVelocity | ✅ | ✅ | ✅ |
| pylops.waveeqprocessing.UpDownComposition2D | ✅ | ✅ | ✅ |
| pylops.waveeqprocessing.UpDownComposition3D | ✅ | ✅ | ✅ |
| pylops.waveeqprocessing.BlendingContinuous | ✅ | ✅ | ✅ |
| pylops.waveeqprocessing.BlendingGroup | ✅ | ✅ | ✅ |
| pylops.waveeqprocessing.BlendingHalf | ✅ | ✅ | ✅ |
| pylops.waveeqprocessing.MDC | ✅ | ✅ | ✅ |
| pylops.waveeqprocessing.Kirchhoff | ✅ | 🔴 | 🔴 |
| pylops.waveeqprocessing.AcousticWave2D | ✅ | 🔴 | 🔴 |

Geophysical subsurface characterization:

| Operator/method | CPU | GPU with CuPy | GPU/TPU with JAX |
| --- | --- | --- | --- |
| pylops.avo.avo.AVOLinearModelling | ✅ | ✅ | ✅ |
| pylops.avo.poststack.PoststackLinearModelling | ✅ | ✅ | ✅ |
| pylops.avo.prestack.PrestackLinearModelling | ✅ | ✅ | ⚠️ |
| pylops.avo.prestack.PrestackWaveletModelling | ✅ | ✅ | ⚠️ |

Warning

1. The JAX backend of the pylops.signalprocessing.Convolve1D operator currently works only with 1d-arrays due to a different behaviour of scipy.signal.convolve and jax.scipy.signal.convolve with nd-arrays.

2. The JAX backend of the pylops.avo.prestack.PrestackLinearModelling operator currently works only with explicit=True due to the same issue as in point 1 for the pylops.signalprocessing.Convolve1D operator employed when explicit=False.

Example

Finally, let's briefly look at an example. First we write a code snippet using numpy arrays which PyLops will run on your CPU:

import numpy as np
from pylops import MatrixMult

ny, nx = 400, 400
G = np.random.normal(0, 1, (ny, nx)).astype(np.float32)
x = np.ones(nx, dtype=np.float32)

Gop = MatrixMult(G, dtype='float32')
y = Gop * x     # forward modelling
xest = Gop / y  # inversion

Now we write a code snippet using cupy arrays which PyLops will run on your GPU:

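import cupy as cp
import numpy as np
from pylops import MatrixMult

ny, nx = 400, 400
G = cp.random.normal(0, 1, (ny, nx)).astype(np.float32)
x = cp.ones(nx, dtype=np.float32)

Gop = MatrixMult(G, dtype='float32')
y = Gop * x     # forward modelling on the GPU
xest = Gop / y  # inversion on the GPU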

The code is almost unchanged apart from the fact that we now use cupy arrays; PyLops will figure this out.
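Because only the array library changes, the inputs of the two snippets can be built from a single code path. The sketch below is one way to do this and is not part of PyLops: `pick_array_module` and `build_problem` are hypothetical helpers, and the fallback to numpy assumes that `cupy.is_available()` reports whether a CUDA device can be used.

```python
import numpy as np


def pick_array_module():
    """Return cupy if it is installed and a CUDA device is usable, else numpy."""
    try:
        import cupy as cp
        if cp.is_available():
            return cp
    except ImportError:
        pass
    return np


def build_problem(xp, ny=400, nx=400, seed=0):
    """Create the matrix G and the model x of the example with array module xp."""
    rng = np.random.default_rng(seed)
    # Draw on the host, then move to whichever device xp targets
    G = xp.asarray(rng.normal(0, 1, (ny, nx)).astype(np.float32))
    x = xp.ones(nx, dtype=xp.float32)
    return G, x


xp = pick_array_module()
G, x = build_problem(xp)
```

G and x can then be passed to MatrixMult exactly as in the snippets above, and the same script runs on either CPU or GPU.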

Similarly, we write a code snippet using jax arrays which PyLops will run on your GPU/TPU:

import jax.numpy as jnp
import numpy as np
from pylops import JaxOperator, MatrixMult

ny, nx = 400, 400
G = jnp.array(np.random.normal(0, 1, (ny, nx)).astype(np.float32))
x = jnp.ones(nx, dtype=np.float32)

Gop = JaxOperator(MatrixMult(G, dtype='float32'))
y = Gop * x
xest = Gop / y

# Adjoint via AD
xadj = Gop.rmatvecad(x, y)

Again, the code is almost unchanged apart from the fact that we now use jax arrays and wrap the operator in a JaxOperator.

Note

More examples for the CuPy and JAX backends can be found here and here.