21. JAX Operator
This tutorial introduces the pylops.JaxOperator operator, which represents the entry point to the JAX backend of PyLops.
More specifically, by wrapping any PyLops operator in a pylops.JaxOperator one can:
apply the forward and adjoint, and use any of PyLops’ solvers, with JAX arrays;
enable automatic differentiation;
enable automatic vectorization.
Moreover, both the forward and the adjoint are internally just-in-time compiled, so that any further optimization provided by JAX can be exploited.
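To make the idea of a just-in-time compiled forward/adjoint pair concrete, here is a minimal plain-JAX sketch for a small dense matrix. It is not the JaxOperator implementation: the matrix G and the function names matvec and rmatvec are assumptions chosen to echo the usual LinearOperator terminology.

```python
import jax
import jax.numpy as jnp

# Hypothetical small real matrix standing in for the operator (not PyLops code)
G = jnp.array([[1.0, 2.0],
               [3.0, 4.0]], dtype=jnp.float32)

@jax.jit
def matvec(x):
    # forward: y = G x
    return G @ x

@jax.jit
def rmatvec(y):
    # adjoint: x = G^H y (conjugate transpose; equals G.T for a real G)
    return jnp.conj(G).T @ y

x = jnp.ones(2, dtype=jnp.float32)
y = matvec(x)       # first call traces and compiles; later calls reuse the compiled code
xadj = rmatvec(y)
print("Forward:", y, "Adjoint:", xadj)
```

Here G is captured by closure, so jax.jit treats it as a compile-time constant; an operator whose matrix changes between calls would rather pass it as an argument.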
In this example we will consider a pylops.MatrixMult operator and showcase how to use it in conjunction with pylops.JaxOperator to enable the different JAX functionalities mentioned above.
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pylops
plt.close("all")
np.random.seed(10)
Let’s start by creating a pylops.MatrixMult operator. We will then perform the dot-test as well as apply the forward and adjoint operations to JAX arrays.
n = 4
G = np.random.normal(0, 1, (n, n)).astype("float32")
Gopjax = pylops.JaxOperator(pylops.MatrixMult(jnp.array(G), dtype="float32"))
# dottest
pylops.utils.dottest(Gopjax, n, n, backend="jax", verb=True, atol=1e-3)
# forward
xjnp = jnp.ones(n, dtype="float32")
yjnp = Gopjax @ xjnp
# adjoint
xadjjnp = Gopjax.H @ yjnp
Dot test passed, v^H(Opu)=10.667926788330078 - u^H(Op^Hv)=10.667926788330078
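The dot-test checks that forward and adjoint are consistent: for a real operator, v^T(Op u) must equal u^T(Op^T v) for random vectors u and v. A minimal plain-JAX sketch of this check follows; the helper dot_test, the random matrix and the tolerances are assumptions for illustration, not pylops.utils.dottest.

```python
import jax
import jax.numpy as jnp

def dot_test(matvec, rmatvec, nr, nc, key, rtol=1e-4, atol=1e-4):
    """Hypothetical dot-test helper for a real operator of shape (nr, nc)."""
    ku, kv = jax.random.split(key)
    u = jax.random.normal(ku, (nc,))   # random model-space vector
    v = jax.random.normal(kv, (nr,))   # random data-space vector
    lhs = jnp.vdot(v, matvec(u))       # v^T (Op u)
    rhs = jnp.vdot(u, rmatvec(v))      # u^T (Op^T v)
    return lhs, rhs, bool(jnp.allclose(lhs, rhs, rtol=rtol, atol=atol))

kG, kt = jax.random.split(jax.random.PRNGKey(10))
G = jax.random.normal(kG, (4, 4))
lhs, rhs, passed = dot_test(lambda x: G @ x, lambda y: G.T @ y, 4, 4, kt)
print("v^T(Op u) =", float(lhs), " u^T(Op^T v) =", float(rhs), " passed:", passed)
```

The tolerances are loose on purpose, since the default JAX dtype is float32; this mirrors the atol=1e-3 passed to the dot-test above.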
We can now use one of PyLops’ solvers to invert the operator:
xcgls = pylops.optimization.basic.cgls(
Gopjax, yjnp, x0=jnp.zeros(n), niter=100, tol=1e-10, show=True
)[0]
print("Inverse: ", xcgls)
CGLS
-----------------------------------------------------------------
The Operator Op has 4 rows and 4 cols
damp = 0.000000e+00 tol = 1.000000e-10 niter = 100
-----------------------------------------------------------------
Itn x[0] r1norm r2norm
1 6.9697e-02 4.5138e-01 4.5138e-01
2 2.3785e-01 2.4584e-01 2.4584e-01
3 2.6222e-01 2.0971e-01 2.0971e-01
4 1.0000e+00 5.5010e-06 5.5010e-06
5 1.0000e+00 1.8921e-06 1.8921e-06
Iterations = 5 Total time (s) = 0.20
-----------------------------------------------------------------
Inverse: [1. 0.9999997 1.0000001 0.9999986]
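For reference, the iterations reported above follow the classic CGLS recurrence (conjugate gradients applied to the normal equations). The sketch below is a minimal, undamped version written directly with jax.numpy; the function cgls here is a hypothetical stand-in that omits the damp, show and other options of pylops.optimization.basic.cgls, and the well-conditioned test matrix is an assumption.

```python
import jax.numpy as jnp

def cgls(matvec, rmatvec, y, x0, niter=100, tol=1e-5):
    """Minimal undamped CGLS for min_x ||y - A x||_2 (illustrative sketch)."""
    x = x0
    r = y - matvec(x)              # data residual
    s = rmatvec(r)                 # normal-equation residual A^H r
    p = s                          # first search direction
    gamma = jnp.vdot(s, s).real
    for _ in range(niter):
        if jnp.sqrt(gamma) < tol:  # stop when ||A^H r|| is small enough
            break
        q = matvec(p)
        qq = jnp.vdot(q, q).real
        if qq == 0:                # guard against a degenerate direction
            break
        alpha = gamma / qq         # step length along p
        x = x + alpha * p
        r = r - alpha * q
        s = rmatvec(r)
        gamma_new = jnp.vdot(s, s).real
        p = s + (gamma_new / gamma) * p   # conjugate the new direction
        gamma = gamma_new
    return x

# Assumed symmetric positive-definite test matrix
G = jnp.array([[4.0, 1.0, 0.0, 0.0],
               [1.0, 3.0, 1.0, 0.0],
               [0.0, 1.0, 2.0, 1.0],
               [0.0, 0.0, 1.0, 3.0]])
xtrue = jnp.ones(4)
y = G @ xtrue
xinv = cgls(lambda x: G @ x, lambda r: G.T @ r, y, jnp.zeros(4))
print("Inverse:", xinv)
```

Because the loop is plain Python over concrete arrays, the stopping checks work directly; a fully jit-compiled solver would instead need jax.lax.while_loop.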
Let’s see how we can leverage the automatic differentiation capabilities of JAX to obtain the adjoint of our operator without having to implement it. Although in PyLops the adjoint of every operator is hand-written (and optimized), in some cases it may be useful to quickly implement only the forward pass of a new operator and get the adjoint for free. This can be extremely beneficial during the prototyping stage of an operator, before embarking on the implementation of an efficient hand-written adjoint.
xadjjnpad = Gopjax.rmatvecad(xjnp, yjnp)
print("Hand-written Adjoint: ", xadjjnp)
print("AD Adjoint: ", xadjjnpad)
Hand-written Adjoint: [0.1227006 0.65633595 0.1142952 2.1171641 ]
AD Adjoint: [0.1227006 0.65633595 0.1142952 2.1171641 ]
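The same principle can be reproduced in plain JAX: for a linear forward function, the vector-Jacobian product returned by jax.vjp is exactly the adjoint. The sketch below assumes a real 2x2 matrix and uses a hypothetical adjoint_ad helper whose (x, y) argument order mirrors the rmatvecad call above; it is not the PyLops implementation. For complex-valued operators, JAX’s vjp computes a transpose without conjugation, so extra conj calls may be needed to obtain the Hermitian adjoint.

```python
import jax
import jax.numpy as jnp

# Assumed real matrix; only the forward pass is written by hand
G = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])

def forward(x):
    return G @ x

def adjoint_ad(x, y):
    # jax.vjp linearizes forward around x; since forward is linear,
    # the returned vjp function applies G^T to any cotangent y
    _, vjp_fun = jax.vjp(forward, x)
    return vjp_fun(y)[0]

x = jnp.ones(2)
y = jnp.array([3.0, 7.0])
xadj = adjoint_ad(x, y)   # equals G.T @ y
print("AD Adjoint:", xadj)
```

The point x only serves as the linearization point: for a linear forward the result does not depend on it.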
More generally, any of JAX’s native operations can be combined with a PyLops operator, for example to compute the gradient of a function that involves it:
Grad: [ 0.9921482 0.8488673 -0.6182323 1.7483397]
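As an illustration of mixing a linear forward with native JAX operations, the sketch below differentiates a least-squares misfit with jax.grad and checks it against the analytic gradient G^T(Gx - y). The matrix, the data yobs and the choice of loss are assumptions for illustration and are not necessarily those used to produce the gradient printed above.

```python
import jax
import jax.numpy as jnp

G = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])
yobs = jnp.array([1.0, 0.0])   # hypothetical observed data

def loss(x):
    # least-squares misfit: linear forward followed by native JAX ops
    return 0.5 * jnp.sum((G @ x - yobs) ** 2)

x = jnp.ones(2)
g = jax.grad(loss)(x)               # gradient via automatic differentiation
g_ref = G.T @ (G @ x - yobs)        # analytic gradient for comparison
print("Grad:", g, "Analytic:", g_ref)
```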
We now turn our attention to automatic vectorization, which is very useful when we want to apply the same operator to multiple vectors. In PyLops we can easily do so by using the matmat and rmatmat methods; however, under the hood these methods simply run a for-loop and call the corresponding matvec / rmatvec methods multiple times. JAX, on the other hand, is able to automatically vectorize the operator by adding a batch axis at the beginning of its input. Moreover, this can be seamlessly combined with jax.jit to further improve performance.
Original output: [0.49308136 0.27531052 1.4657547 0.73697615]
AV Output 1: [0.49308136 0.27531052 1.4657547 0.73697615]
AV Output 1: [0.49308136 0.27531052 1.4657547 0.73697615]
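A minimal sketch of automatic vectorization combined with jit, using jax.vmap and jax.jit on a plain matrix-vector product. The function matvec is a stand-in for an operator’s forward, and the matrix and batch are assumptions for illustration.

```python
import jax
import jax.numpy as jnp

G = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])

def matvec(x):
    # forward for a single vector of shape (2,)
    return G @ x

# vmap adds a leading batch axis: (nbatch, 2) -> (nbatch, 2); jit compiles it
batched_matvec = jax.jit(jax.vmap(matvec))

xs = jnp.stack([jnp.ones(2), jnp.array([1.0, -1.0])])
ys = batched_matvec(xs)
# reference: explicit Python loop over the batch, as a matmat-style loop would do
ys_loop = jnp.stack([matvec(x) for x in xs])
print("Batched:", ys)
```

jax.vmap maps over axis 0 by default (in_axes=0), which is why the vectors are stacked along the first dimension.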