
jit, vmap, grad, and pytrees

All LinearOperator objects are native jax and pytorch pytrees.

This means that we can vmap over them, jit functions that take them as arguments, and apply other function transformations to them.

Example: Tree Map
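A pytree is a container of array leaves plus static structure. That means `jax.tree_util.tree_map` can transform every array stored inside an operator and rebuild an object of the same type. CoLA's own pytree registration is not shown here. The sketch below therefore uses a hypothetical `ToyDiagonal` class as a stand-in for a LinearOperator, registered with `jax.tree_util.register_pytree_node_class`:

```python
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_leaves, tree_map


@register_pytree_node_class
class ToyDiagonal:
    """Hypothetical diagonal operator registered as a pytree (not CoLA's class)."""

    def __init__(self, diag):
        self.diag = diag                      # array data: the pytree leaf
        self.shape = (diag.shape[-1],) * 2    # static metadata

    def __matmul__(self, v):
        return self.diag * v

    def tree_flatten(self):
        # (children, aux_data): arrays are children, static metadata is aux_data
        return (self.diag,), self.shape

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild without calling __init__, since JAX may pass non-array leaves
        obj = object.__new__(cls)
        (obj.diag,) = children
        obj.shape = aux_data
        return obj


op = ToyDiagonal(jnp.array([3.0, 0.2, 1.0]))
print(len(tree_leaves(op)))                 # one leaf: the diagonal array
doubled = tree_map(lambda x: 2 * x, op)     # maps over leaves, keeps the type
print(type(doubled).__name__, doubled.diag, doubled.shape)
print(doubled @ jnp.ones(3))
```

Keeping `shape` in `aux_data` rather than as a leaf means transformations only ever see the array data. This is the same reason a batched operator can report an unbatched shape.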

Jit example (in jax)

As an example, let’s jit a function involving matrix square roots.

[1]:
from jax import jit
import numpy as np
import jax.numpy as jnp
import jax
# force JAX to run on the CPU
jax.config.update('jax_platform_name', 'cpu')
import cola

# construct a linear operator K = kron(B, D), with B = A^T A + 1e-4 I and D both positive definite
A = jnp.array(np.random.randn(2, 2))
B = cola.SelfAdjoint(cola.lazify(A.T @ A + 1e-4 * jnp.eye(2)))
D = cola.SelfAdjoint(cola.ops.Diagonal(jnp.array([3., 0.2, 1.])))
K = cola.kron(B, D)  # 6x6 Kronecker product

Let’s verify that CoLA computes the square root of this matrix correctly. Applying cola.sqrt(K) twice to a random vector v should reproduce K@v.

[4]:
from cola import Auto
v = jnp.array(np.random.randn(6))
K_half_v = cola.sqrt(K)@v
Kv = cola.sqrt(K)@K_half_v
print("error:",jnp.linalg.norm(Kv - K@v))
error: 6.3414245e-07
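For reference, the same check can be reproduced densely in NumPy. The sketch computes the PSD square root from a symmetric eigendecomposition and confirms that applying it twice gives K@v. This is only a dense cross-check, not a description of how CoLA computes cola.sqrt. It also checks the identity sqrt(B ⊗ D) = sqrt(B) ⊗ sqrt(D) for positive semidefinite factors. That identity follows from the mixed-product property of the Kronecker product, and it is one way a structure-aware implementation can avoid forming K densely. The helper `psd_sqrt` is hypothetical:

```python
import numpy as np


def psd_sqrt(M):
    """Hypothetical helper: dense PSD square root via a symmetric eigendecomposition."""
    evals, evecs = np.linalg.eigh(M)
    evals = np.clip(evals, 0.0, None)  # guard against tiny negative round-off
    return (evecs * np.sqrt(evals)) @ evecs.T


rng = np.random.default_rng(0)
A = rng.standard_normal((2, 2))
B = A.T @ A + 1e-4 * np.eye(2)
D = np.diag([3.0, 0.2, 1.0])
K = np.kron(B, D)  # dense 6x6 matrix

v = rng.standard_normal(6)
K_half = psd_sqrt(K)
print("error:", np.linalg.norm(K_half @ (K_half @ v) - K @ v))

# For PSD factors, sqrt(kron(B, D)) = kron(sqrt(B), sqrt(D)),
# so the 6x6 root can be assembled from a 2x2 root and a 3x3 root.
print("kron identity error:", np.linalg.norm(K_half - np.kron(psd_sqrt(B), psd_sqrt(D))))
```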

Now let’s jit a function that takes a LinearOperator as an argument. Since sqrt(4K) = 2 sqrt(K), both printed vectors should agree up to the solver tolerance.

[5]:
@jit
def sqrt_mvm(K, v):
    return cola.sqrt(K, Auto(tol=1e-4))@v

print(sqrt_mvm(K,v))
print(sqrt_mvm(4*K,v)/2)
[ 0.13759953+0.j -0.94336665+0.j  0.3768798 +0.j -0.16725242+0.j
 -0.26825097+0.j  0.15196525+0.j]
[ 0.13759911+0.j -0.94336635+0.j  0.3768798 +0.j -0.16725181+0.j
 -0.26825133+0.j  0.15196525+0.j]
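Passing a pytree to a jitted function works because jax.jit flattens it into array leaves. It then caches the compiled code by pytree structure plus leaf shapes and dtypes. Below is a minimal pure-JAX sketch of this behavior. It uses a plain dict as a hypothetical stand-in for a diagonal operator, and a Python counter that increments only when JAX traces the function rather than reusing it:

```python
import jax
import jax.numpy as jnp

trace_count = 0  # Python side effects only run while JAX traces the function


@jax.jit
def sqrt_mvm(op, v):
    global trace_count
    trace_count += 1
    # op is a hypothetical diagonal "operator": a dict pytree with one array leaf
    return jnp.sqrt(op["diag"]) * v


op = {"diag": jnp.array([3.0, 0.2, 1.0])}
v = jnp.ones(3)
print(sqrt_mvm(op, v))
print(sqrt_mvm({"diag": 4 * op["diag"]}, v) / 2)  # same structure, shapes, dtypes
print("traces:", trace_count)
```

The second call reuses the cached trace because only the leaf values changed. Passing a dict with different keys, or leaves with a different shape or dtype, would trigger a retrace.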

Batched LinearOperator operations using vmap (in pytorch)

Let’s consider a function that constructs some linear operators, and a separate function that applies some transformations.

[7]:
import cola
import torch
import numpy as np

def construct_complicated_linops(X):
    X = cola.lazify(X)
    Y = X@X.T
    Y = cola.PSD(Y+cola.ops.I_like(Y))
    D = cola.PSD(cola.ops.Diagonal(torch.linspace(0.1,1,2)))
    W = cola.ops.BlockDiag(Y,D, multiplicities=[2, 1])
    diag_W = cola.diag(W)
    return W, cola.PSD(cola.ops.Diagonal(diag_W))

W,diag_W = construct_complicated_linops(torch.randn(3,3))
print(W[:5,:5].to_dense())
print(diag_W[:5,:5].to_dense())
tensor([[4.3080, 2.7321, 1.8498, 0.0000, 0.0000],
        [2.7321, 5.7340, 3.7580, 0.0000, 0.0000],
        [1.8498, 3.7580, 6.7027, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 4.3080, 2.7321],
        [0.0000, 0.0000, 0.0000, 2.7321, 5.7340]])
tensor([[4.3080, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 5.7340, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 6.7027, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 4.3080, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 5.7340]])
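A dense NumPy analogue of construct_complicated_linops may help make the structure concrete. It rests on two assumptions. First, as the printed blocks suggest, multiplicities=[2, 1] places two copies of Y followed by one copy of D along the diagonal. Second, cola.diag returns the main diagonal. `block_diag` below is a small hypothetical helper, not a CoLA function:

```python
import numpy as np


def block_diag(*blocks):
    """Hypothetical helper: place square blocks along the diagonal of a zero matrix."""
    n = sum(b.shape[0] for b in blocks)
    out = np.zeros((n, n))
    i = 0
    for b in blocks:
        k = b.shape[0]
        out[i:i + k, i:i + k] = b
        i += k
    return out


rng = np.random.default_rng(0)
X = rng.standard_normal((3, 3))
Y = X @ X.T + np.eye(3)            # dense analogue of cola.PSD(Y + I_like(Y))
Dm = np.diag(np.linspace(0.1, 1, 2))

# Assumption (from the printed output): multiplicities=[2, 1] repeats Y twice, D once
W = block_diag(Y, Y, Dm)
diag_W = np.diag(np.diag(W))       # dense analogue of Diagonal(cola.diag(W))
print(W.shape)
print(W[:5, :5].round(4))
```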

For this example, consider a function that uses the diagonal of W as a symmetric (Jacobi) preconditioner. The preconditioner is applied explicitly rather than supplied as an argument to cola.inv.

[13]:
def perform_operations(W,D,v):
    P = cola.pow(D, -0.5) # D^{-1/2}
    y = P@cola.inv(P@W@P,Auto(tol=1e-4))@P@v
    return y
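Algebraically the preconditioning cancels. With P = D^{-1/2}, we have P(PWP)^{-1}Pv = P P^{-1} W^{-1} P^{-1} P v = W^{-1}v, so y is just W^{-1}v. The benefit is that the solver inside cola.inv works on the rescaled matrix PWP, which can be much better conditioned than W. Here is a small dense NumPy check of both facts, using a hypothetical badly scaled 2x2 matrix:

```python
import numpy as np

# Hypothetical badly scaled SPD matrix: W = S C S with a well-conditioned core C
C = np.array([[1.0, 0.1],
              [0.1, 1.0]])
S = np.diag([1.0, 100.0])
W = S @ C @ S

# Jacobi (diagonal) scaling, mirroring P = D^{-1/2} in perform_operations
D = np.diag(np.diag(W))
P = np.diag(1.0 / np.sqrt(np.diag(D)))

v = np.array([1.0, 2.0])
y = P @ np.linalg.solve(P @ W @ P, P @ v)  # solve with the rescaled matrix

print("matches W^{-1} v:", np.allclose(y, np.linalg.solve(W, v)))
print("cond(W)     =", np.linalg.cond(W))
print("cond(P W P) =", np.linalg.cond(P @ W @ P))
```

Here PWP recovers the core C almost exactly. Its condition number is about 1.2, compared with roughly 10^4 for W.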

Now suppose that we want to perform this operation over a batch of LinearOperators, each with different data.

First, we vmap the function that constructs the LinearOperators:

[14]:
from torch.func import vmap
bW, bD = vmap(construct_complicated_linops)(torch.randn(3,5, 5))

Notice that the batched objects have the same types as unbatched ones, and their reported shapes do not include the batch dimension:

[15]:
print(bW.shape, type(bW))
print(bD.shape, type(bD))
(12, 12) <class 'cola.ops.operators.BlockDiag[cola.ops.operators.Sum[cola.ops.operators.Product[cola.ops.operators.Dense, cola.ops.operators.Dense], cola.ops.operators.Identity], cola.ops.operators.Diagonal]'>
(12, 12) <class 'cola.ops.operators.Diagonal'>

However, the data that makes up these objects now has a batch dimension:

[16]:
bD.diag.shape
[16]:
torch.Size([3, 12])

In general, these batched objects should only be passed to a function that is itself vmapped over its LinearOperator inputs, as shown below with perform_operations.

[17]:
all_outs = torch.func.vmap(perform_operations)(bW, bD, torch.randn(3,bW.shape[0]))
print(all_outs.shape)
torch.Size([3, 12])
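The same batching pattern works for any pytree that torch.func.vmap understands. The sketch below stores an operator as a dict of tensors, a hypothetical stand-in for a LinearOperator. Vmapping the constructor returns a dict in which every leaf has a leading batch dimension. A second vmapped function then consumes that batched dict, mirroring bW, bD and perform_operations above. A dense torch.linalg.solve stands in for cola.inv:

```python
import torch
from torch.func import vmap


def construct(X):
    # Hypothetical "operator" stored as a dict of tensors (a pytree)
    Y = X @ X.T + torch.eye(X.shape[0])
    return {"dense": Y, "diag": torch.diagonal(Y)}


def apply_inverse(op, v):
    # Jacobi-scaled solve as in perform_operations, but with dense tensors
    p = op["diag"].rsqrt()                          # D^{-1/2} as a vector
    scaled = p[:, None] * op["dense"] * p[None, :]  # P W P
    return p * torch.linalg.solve(scaled, p * v)


Xs = torch.randn(3, 5, 5)
V = torch.randn(3, 5)
batched = vmap(construct)(Xs)
print({k: tuple(t.shape) for k, t in batched.items()})  # each leaf gains a batch dim of 3

outs = vmap(apply_inverse)(batched, V)
print(outs.shape)
```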

🚧 Note: Not all LinearOperators support vmap with the pytorch backend 🚧

For example, the Kronecker product:

[18]:
def get_entries(M):
    return M[:5,:5].to_dense()

try:
    vmap(get_entries)(vmap(cola.kron)(bW, bD))
except RuntimeError as e:
    print("raised exception:", e)
raised exception: Batching rule not implemented for aten::moveaxis.int; the fallback path doesn't work on out= or view ops.