GPU / TPU Support
Overview
From v1.12.0, PyLops supports computations on GPUs powered by
CuPy (cupy-cudaXX>=13.0.0).
This library must be installed before PyLops is installed.
From v2.3.0, PyLops also supports computations on GPUs/TPUs powered by
JAX.
This library must be installed before PyLops is installed.
Note
Set the environment variables CUPY_PYLOPS=0 and/or JAX_PYLOPS=0 to force PyLops to ignore
the cupy and jax backends. This is also useful when an older version of cupy
or jax is installed on your system, which would otherwise raise an error when importing PyLops.
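The note above can also be applied from within Python rather than from the shell. The following is a minimal sketch, assuming the two variables are read when PyLops is first imported, so they have to be set before that import. The try/except guard is only there so the snippet also runs where PyLops is not installed.

```python
import os

# Disable the optional backends *before* PyLops is imported
# (assumption: the variables are read at import time).
os.environ["CUPY_PYLOPS"] = "0"
os.environ["JAX_PYLOPS"] = "0"

try:
    import pylops  # imported with the cupy/jax backends ignored
except ImportError:
    pylops = None  # PyLops is not installed in this environment
```

Setting the variables in the shell before launching Python should have the same effect and avoids depending on import order inside the script.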
Apart from a few exceptions, all operators and solvers in PyLops can
seamlessly work with numpy arrays on CPU as well as with cupy/jax arrays
on GPU. For CuPy, users simply need to consistently create operators and
provide data vectors to the solvers, e.g., when using
pylops.MatrixMult the input matrix must be a
cupy array if the data provided to a solver is also a cupy array.
For JAX, apart from following the same procedure described for CuPy, the PyLops operator must
also be wrapped into a pylops.JaxOperator.
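The consistency requirement described above can be checked before an operator is created or a solver is called. The sketch below is not part of PyLops: same_backend is a hypothetical helper that only compares the top-level module of each array type, and it is shown here with numpy arrays so that it runs without a GPU.

```python
import numpy as np


def same_backend(*arrays):
    """Return True if all inputs come from the same array library.

    Hypothetical helper, not a PyLops function: it compares the top-level
    module name of each array type (e.g. "numpy" for numpy arrays).
    """
    modules = {type(a).__module__.split(".")[0] for a in arrays}
    return len(modules) == 1


G = np.ones((4, 3), dtype=np.float32)  # matrix that would be passed to MatrixMult
y = np.ones(4, dtype=np.float32)       # data that would be passed to a solver

print(same_backend(G, y))
```

A check like this is cheap and can catch a numpy matrix being paired with a cupy or jax data vector before the mismatch surfaces inside a solver.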
In the following, we provide a list of methods in pylops.LinearOperator with their current status (available on CPU,
GPU with CuPy, and GPU with JAX):
| Operator/method | CPU | GPU with CuPy | GPU/TPU with JAX |
|---|---|---|---|
| | ✅ | 🔴 | 🔴 |
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | ✅ |
| | ✅ | 🔴 | 🔴 |
| | ✅ | ✅ | ✅ |
| | ✅ | 🔴 | 🔴 |
| | ✅ | 🔴 | 🔴 |
Similarly, we provide a list of operators with their current status.
Basic operators:
| Operator/method | CPU | GPU with CuPy | GPU/TPU with JAX |
|---|---|---|---|
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | ✅ |
| | ✅ | 🔴 | 🔴 |
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | ✅ |
Smoothing and derivatives:
| Operator/method | CPU | GPU with CuPy | GPU/TPU with JAX |
|---|---|---|---|
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | ✅ |
Signal processing:
| Operator/method | CPU | GPU with CuPy | GPU/TPU with JAX |
|---|---|---|---|
| | ✅ | ✅ | ⚠️ |
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | 🔴 |
| | ✅ | ✅ | 🔴 |
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | 🔴 |
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | ✅ |
| | ✅ | 🔴 | 🔴 |
| | ✅ | 🔴 | 🔴 |
| | ✅ | 🔴 | 🔴 |
| | ✅ | 🔴 | 🔴 |
| | ✅ | 🔴 | 🔴 |
| | ✅ | 🔴 | 🔴 |
| | ✅ | ✅ | 🔴 |
| | ✅ | ✅ | 🔴 |
| | ✅ | ✅ | 🔴 |
| | ✅ | ✅ | 🔴 |
| | ✅ | ✅ | 🔴 |
| | ✅ | ✅ | 🔴 |
| | ✅ | ✅ | 🔴 |
| | ✅ | ✅ | ✅ |
Wave-Equation processing:
| Operator/method | CPU | GPU with CuPy | GPU/TPU with JAX |
|---|---|---|---|
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | ✅ |
| | ✅ | 🔴 | 🔴 |
| | ✅ | 🔴 | 🔴 |
Geophysical subsurface characterization:
| Operator/method | CPU | GPU with CuPy | GPU/TPU with JAX |
|---|---|---|---|
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | ✅ |
| | ✅ | ✅ | ⚠️ |
| | ✅ | ✅ | ⚠️ |
Warning
1. The JAX backend of the pylops.signalprocessing.Convolve1D operator
currently works only with 1d-arrays due to a different behaviour of
scipy.signal.convolve and jax.scipy.signal.convolve with
nd-arrays.
2. The JAX backend of the pylops.avo.prestack.PrestackLinearModelling
operator currently works only with explicit=True due to the same issue as
in point 1 for the pylops.signalprocessing.Convolve1D operator employed
when explicit=False.
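To make the 1d versus nd distinction in point 1 concrete, here is a plain NumPy sketch of what convolving an nd-array along one axis amounts to: every 1d trace along that axis is convolved independently with the same filter. This is neither PyLops nor JAX code; convolve_along_axis is a hypothetical helper, and the choice of mode="same" is an assumption made only for the illustration.

```python
import numpy as np


def convolve_along_axis(x, h, axis=-1):
    """Convolve each 1d trace of x along `axis` with the 1d filter h.

    Hypothetical helper for illustration only.
    """
    return np.apply_along_axis(
        lambda trace: np.convolve(trace, h, mode="same"), axis, x
    )


# Three identical traces, each holding a unit spike at sample 2
x = np.zeros((3, 8))
x[:, 2] = 1.0
h = np.array([1.0, 2.0, 1.0])  # short symmetric filter

y = convolve_along_axis(x, h, axis=1)
print(y.shape)
```

Processing one 1d trace at a time in this way stays within the 1d case that the warning describes as supported, at the cost of an explicit loop over traces.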
Example
Finally, let's briefly look at an example. First we write a code snippet using
numpy arrays which PyLops will run on your CPU:
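import numpy as np
from pylops import MatrixMult
ny, nx = 400, 400
# Operator matrix and model vector as numpy arrays (CPU)
G = np.random.normal(0, 1, (ny, nx)).astype(np.float32)
x = np.ones(nx, dtype=np.float32)
Gop = MatrixMult(G, dtype='float32')
y = Gop * x     # forward modelling
xest = Gop / y  # inversion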
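For readers who want to see what the two operator expressions in the CPU snippet compute, the following NumPy-only sketch reproduces them without PyLops. It assumes that `Gop * x` is a matrix-vector product and that `Gop / y` returns a least-squares solution of the system; np.linalg.lstsq is used here as a stand-in and is not claimed to be the solver PyLops calls. A smaller, seeded problem in double precision is used so that the check is fast and reproducible.

```python
import numpy as np

rng = np.random.default_rng(0)
ny, nx = 50, 50  # smaller than the 400 x 400 problem, to keep the check fast
G = rng.normal(0, 1, (ny, nx))
x = np.ones(nx)

# Forward modelling: the plain-matrix counterpart of `Gop * x`
y = G @ x

# Inversion: assumed counterpart of `Gop / y` (least-squares solve)
xest = np.linalg.lstsq(G, y, rcond=None)[0]

print(np.allclose(xest, x, atol=1e-6))
```

The same comparison between xest and x can be used as a quick sanity check after running the PyLops versions of the snippet on any backend.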
Now we write a code snippet using cupy arrays which PyLops will run on
your GPU:
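import cupy as cp
import numpy as np
from pylops import MatrixMult
ny, nx = 400, 400
# Operator matrix and model vector as cupy arrays (GPU)
G = cp.random.normal(0, 1, (ny, nx)).astype(np.float32)
x = cp.ones(nx, dtype=np.float32)
Gop = MatrixMult(G, dtype='float32')
y = Gop * x     # forward modelling
xest = Gop / y  # inversion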
The code is almost unchanged apart from the fact that we now use cupy arrays;
PyLops will figure this out on its own.
Similarly, we write a code snippet using jax arrays which PyLops will run on
your GPU/TPU:
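import jax.numpy as jnp
import numpy as np
from pylops import JaxOperator, MatrixMult
ny, nx = 400, 400
# Operator matrix and model vector as jax arrays (GPU/TPU)
G = jnp.array(np.random.normal(0, 1, (ny, nx)).astype(np.float32))
x = jnp.ones(nx, dtype=np.float32)
# The PyLops operator must be wrapped into a JaxOperator
Gop = JaxOperator(MatrixMult(G, dtype='float32'))
y = Gop * x     # forward modelling
xest = Gop / y  # inversion
# Adjoint via AD
xadj = Gop.rmatvecad(x, y)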
Again, the code is almost unchanged apart from the fact that we now use jax arrays,