21. JAX Operator

This tutorial introduces the pylops.JaxOperator operator, which represents the entry point to the JAX backend of PyLops.

More specifically, by wrapping any of PyLops’ operators into a pylops.JaxOperator one can:

  • apply the forward and adjoint, and use any of PyLops’ solvers, with JAX arrays;

  • enable automatic differentiation;

  • enable automatic vectorization.

Moreover, both the forward and adjoint passes are internally just-in-time compiled, enabling any further optimization provided by JAX.

In this example we will consider a pylops.MatrixMult operator and showcase how to use it in conjunction with pylops.JaxOperator to enable the different JAX functionalities mentioned above.

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

import pylops

plt.close("all")
np.random.seed(10)

Let’s start by creating a pylops.MatrixMult operator. We will then perform the dot-test as well as apply the forward and adjoint operations to JAX arrays.

n = 4
G = np.random.normal(0, 1, (n, n)).astype("float32")
Gopjax = pylops.JaxOperator(pylops.MatrixMult(jnp.array(G), dtype="float32"))

# dottest
pylops.utils.dottest(Gopjax, n, n, backend="jax", verb=True, atol=1e-3)

# forward
xjnp = jnp.ones(n, dtype="float32")
yjnp = Gopjax @ xjnp

# adjoint
xadjjnp = Gopjax.H @ yjnp
Dot test passed, v^H(Opu)=10.667926788330078 - u^H(Op^Hv)=10.667926788330078
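
Under the hood, the dot test draws random vectors u and v and verifies that the forward and adjoint are consistent, i.e. that v^H(Op u) ≈ u^H(Op^H v). As a purely illustrative sketch (with freshly drawn vectors, not the ones used internally by pylops.utils.dottest), the same identity can be checked by hand:

# manual check of the dot-test identity (real-valued, so ^H reduces to ^T)
u = jnp.array(np.random.normal(0, 1, n).astype("float32"))
v = jnp.array(np.random.normal(0, 1, n).astype("float32"))
print("v^T (Op u): ", jnp.dot(v, Gopjax @ u))
print("u^T (Op^H v): ", jnp.dot(u, Gopjax.H @ v))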

We can now use one of PyLops’ solvers to invert the operator.

xcgls = pylops.optimization.basic.cgls(
    Gopjax, yjnp, x0=jnp.zeros(n), niter=100, tol=1e-10, show=True
)[0]
print("Inverse: ", xcgls)
CGLS
-----------------------------------------------------------------
The Operator Op has 4 rows and 4 cols
damp = 0.000000e+00     tol = 1.000000e-10      niter = 100
-----------------------------------------------------------------

    Itn          x[0]              r1norm         r2norm
     1        6.9697e-02         4.5138e-01     4.5138e-01
     2        2.3785e-01         2.4584e-01     2.4584e-01
     3        2.6222e-01         2.0971e-01     2.0971e-01
     4        1.0000e+00         5.5010e-06     5.5010e-06
     5        1.0000e+00         1.8921e-06     1.8921e-06

Iterations = 5        Total time (s) = 0.20
-----------------------------------------------------------------

Inverse:  [1.        0.9999997 1.0000001 0.9999986]
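
Since G is small and square, we can sanity-check the CGLS solution against a direct dense solve (this check is illustrative and not part of the original script):

# direct dense solve for comparison (feasible here since G is only 4 x 4)
xdirect = jnp.linalg.solve(jnp.array(G), yjnp)
print("Direct inverse: ", xdirect)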

Let’s see how we can leverage the automatic differentiation capabilities of JAX to obtain the adjoint of our operator without having to implement it. Although in PyLops the adjoint of every operator is hand-written (and optimized), it may be useful in some cases to quickly implement the forward pass of a new operator and get the adjoint for free. This can be extremely beneficial during the prototyping stage of an operator, before embarking on an efficient hand-written adjoint.

xadjjnpad = Gopjax.rmatvecad(xjnp, yjnp)

print("Hand-written Adjoint: ", xadjjnp)
print("AD Adjoint: ", xadjjnpad)
Hand-written Adjoint:  [0.1227006  0.65633595 0.1142952  2.1171641 ]
AD Adjoint:  [0.1227006  0.65633595 0.1142952  2.1171641 ]
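
For intuition, rmatvecad relies on JAX’s automatic differentiation machinery: the adjoint of a linear operator emerges as a vector-Jacobian product (VJP) of the forward pass. A minimal, conceptually equivalent sketch using jax.vjp:

# the adjoint obtained as the VJP of the forward pass, evaluated at y
_, vjp_fun = jax.vjp(lambda x: Gopjax(x), xjnp)
print("VJP Adjoint: ", vjp_fun(yjnp)[0])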

More generally, let’s see how we can combine any of JAX’s native operations with a PyLops operator.

def fun(x):
    y = Gopjax(x)
    loss = jnp.sum(y)
    return loss


xgrad = jax.grad(fun)(xjnp)
print("Grad: ", xgrad)
Grad:  [ 0.9921482  0.8488673 -0.6182323  1.7483397]
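
Because fun is linear in x, the gradient also has a closed form: the derivative of sum(Gx) with respect to x is G^H applied to a vector of ones. A quick illustrative check:

# analytic gradient of sum(G x): G^H applied to a vector of ones
print("Analytic grad: ", Gopjax.H @ jnp.ones(n, dtype="float32"))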

We now turn our attention to automatic vectorization, which is very useful if we want to apply the same operator to multiple vectors. In PyLops we can easily do so by using the matmat and rmatmat methods; however, under the hood these methods simply run a for loop and call the corresponding matvec / rmatvec methods multiple times. On the other hand, JAX is able to automatically add a batch axis at the beginning of the operator’s input. Moreover, this can be seamlessly combined with jax.jit to further improve performance.

auto_batch_matvec = jax.jit(jax.vmap(Gopjax._matvec))
xs = jnp.stack([xjnp, xjnp])
ys = auto_batch_matvec(xs)

print("Original output: ", yjnp)
print("AV Output 1: ", ys[0])
print("AV Output 1: ", ys[1])
Original output:  [0.49308136 0.27531052 1.4657547  0.73697615]
AV Output 1:  [0.49308136 0.27531052 1.4657547  0.73697615]
AV Output 2:  [0.49308136 0.27531052 1.4657547  0.73697615]
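
The same vectorization pattern applies to the adjoint; as a minimal sketch (assuming _rmatvec is JAX-traceable, like _matvec above), one can vmap it in exactly the same way:

# batch the adjoint over a leading axis, mirroring the forward case
auto_batch_rmatvec = jax.jit(jax.vmap(Gopjax._rmatvec))
xadjs = auto_batch_rmatvec(jnp.stack([yjnp, yjnp]))
print("AV Adjoint 1: ", xadjs[0])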
