
Accessing Lower-Level Algorithms (CG, Lanczos, Arnoldi, etc.)

[1]:
# autoreload
%load_ext autoreload
%autoreload 2

import time
import torch
import cola
import jax.numpy as jnp
import warnings

warnings.filterwarnings('ignore')

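Let’s construct a simple example matrix with a rapidly decaying spectrum: an RBF kernel matrix on N evenly spaced points in \([-1, 1]\). We then estimate its largest eigenvalue with power iteration and add \(10^{-4}\) times that value to the diagonal, so that the smallest eigenvalue is bounded away from zero.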

[2]:
from cola.linalg.eig.power_iteration import power_iteration
N = 3_000
x = torch.linspace(-1, 1, N)
C = cola.lazify(torch.exp(-(x[None] - x[:, None])**2 / 2))
_, eigmax, _ = power_iteration(C, tol=1e-2)
C = C + 1e-4 * eigmax * cola.ops.I_like(C)
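The construction above (estimate the top eigenvalue, then add a proportional jitter to the diagonal) can be sketched in plain NumPy. This is a minimal illustration and not CoLA's power_iteration. The helper power_iteration_sketch is hypothetical: it uses the Rayleigh quotient as the eigenvalue estimate and stops when that estimate's relative change falls below tol, which is an assumed stopping rule. A smaller N is used so that the dense eigensolver used for checking stays cheap.

```python
import numpy as np


def power_iteration_sketch(A, tol=1e-2, max_iters=500, seed=0):
    """Estimate the largest eigenvalue of a symmetric PSD matrix A (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    v = rng.standard_normal(A.shape[0])
    v /= np.linalg.norm(v)
    eig = 0.0
    for _ in range(max_iters):
        w = A @ v
        new_eig = v @ w                      # Rayleigh quotient (v has unit norm)
        v = w / np.linalg.norm(w)
        converged = abs(new_eig - eig) <= tol * abs(new_eig)
        eig = new_eig
        if converged:                        # assumed rule: small relative change
            break
    return v, eig


# Smaller RBF kernel than in the notebook, so a dense eigensolve stays cheap.
N = 200
x = np.linspace(-1, 1, N)
K = np.exp(-(x[None] - x[:, None]) ** 2 / 2)
_, eigmax = power_iteration_sketch(K)
K_reg = K + 1e-4 * eigmax * np.eye(N)       # jitter proportional to the top eigenvalue

exact_max = np.linalg.eigvalsh(K)[-1]       # exact largest eigenvalue of K
eigs = np.linalg.eigvalsh(K_reg)            # eigenvalues of K_reg, ascending
print(eigmax, exact_max)                    # power-iteration estimate vs. exact value
print(eigs[0] / eigmax)                     # smallest eigenvalue relative to the estimate
```

Because adding \(cI\) shifts every eigenvalue by \(c\), the smallest eigenvalue of the regularized matrix ends up close to \(10^{-4}\) times the largest one. This caps the condition number at roughly \(10^4\).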

We can run CG or GMRES to perform linear solves, or run Lanczos and Arnoldi to compute decompositions. These decompositions can in turn be used to compute \(f(A)v\) for a vector \(v\) or to estimate the extremal eigenvalues.

[3]:
y = torch.randn(C.shape[-1])
sol, info = cola.CG(tol=1e-4, pbar=True)(C, y)
sol1, info1 = cola.GMRES(tol=1e-4, pbar=True)(C, y)
Q1, T, info2 = cola.Lanczos(pbar=True, tol=1e-4, max_iters=1000)(C)
Q2, H, info3 = cola.Arnoldi(pbar=True, tol=1e-4, max_iters=1000)(C)
WARNING:root:Non keyed randn used. To be deprecated soon.
WARNING:root:Non keyed randn used. To be deprecated soon.
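To show how a Lanczos decomposition \(A \approx Q T Q^\top\) yields \(f(A)v\) and the extremal eigenvalues, here is a small NumPy sketch. The helpers lanczos_sketch and lanczos_matfunc are hypothetical and do not reproduce CoLA's implementation. The sketch uses full reorthogonalization for simplicity and approximates \(f(A)v \approx \|v\|\, Q f(T) e_1\). The extreme Ritz values (the eigenvalues of \(T\)) serve as estimates of the extreme eigenvalues of \(A\). The test matrix is symmetric with known eigenvalues, so the result can be checked exactly.

```python
import numpy as np


def lanczos_sketch(A, v, k):
    """k Lanczos steps on symmetric A, with full reorthogonalization (hypothetical helper)."""
    n = A.shape[0]
    Q = np.zeros((n, k))
    alpha = np.zeros(k)       # diagonal of the tridiagonal matrix T
    beta = np.zeros(k - 1)    # off-diagonal of T
    q = v / np.linalg.norm(v)
    for j in range(k):
        Q[:, j] = q
        w = A @ q
        alpha[j] = q @ w
        # Orthogonalize against all previous Lanczos vectors for numerical stability.
        w = w - Q[:, : j + 1] @ (Q[:, : j + 1].T @ w)
        if j < k - 1:
            beta[j] = np.linalg.norm(w)
            q = w / beta[j]
    T = np.diag(alpha) + np.diag(beta, 1) + np.diag(beta, -1)
    return Q, T


def lanczos_matfunc(A, v, f, k):
    """Approximate f(A) v by ||v|| Q f(T) e_1; also return the Ritz values."""
    Q, T = lanczos_sketch(A, v, k)
    theta, S = np.linalg.eigh(T)         # Ritz values and eigenvectors of T
    fT_e1 = S @ (f(theta) * S[0, :])     # f(T) e_1 via T's eigendecomposition
    return np.linalg.norm(v) * Q @ fT_e1, theta


# Symmetric test matrix with known eigenvalues evenly spread in [0.1, 2].
rng = np.random.default_rng(0)
n = 60
true_eigs = np.linspace(0.1, 2.0, n)
U, _ = np.linalg.qr(rng.standard_normal((n, n)))
A = U @ np.diag(true_eigs) @ U.T
v = rng.standard_normal(n)

approx, ritz = lanczos_matfunc(A, v, np.exp, k=30)
exact = U @ (np.exp(true_eigs) * (U.T @ v))   # exp(A) v from the known eigendecomposition
print(np.linalg.norm(approx - exact) / np.linalg.norm(exact))
print(ritz.min(), ritz.max())                 # estimates of the extremal eigenvalues 0.1 and 2.0
```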

Information on the convergence criterion, the number of iterations used, and the time per iteration is returned in the info dictionary. Setting pbar=True also tracks convergence with a tqdm progress bar.

[4]:
print(info)
{'iterations': 9, 'errors': array([2.55869064e+01, 1.55503988e+00, 7.40348101e-02, 4.77714449e-01,
       9.47988685e-03, 1.06564164e-01, 1.37272538e-04, 1.37272538e-04]), 'iteration_time': 0.004042996300591363}
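To make the fields of the info dictionary concrete, here is a plain NumPy sketch of conjugate gradients. It collects the same three keys printed above: iterations, errors (the residual norm after each step) and iteration_time. The helper cg_sketch is hypothetical and not CoLA's implementation. In particular, the relative-residual stopping rule \(\|b - Ax\| \le \text{tol}\,\|b\|\) is an assumption.

```python
import time

import numpy as np


def cg_sketch(A, b, tol=1e-4, max_iters=1000):
    """Textbook conjugate gradients for SPD A that also returns an info dict.
    Hypothetical helper; the stopping rule ||b - Ax|| <= tol * ||b|| is an assumption."""
    x = np.zeros_like(b)
    r = b - A @ x
    p = r.copy()
    rs = r @ r
    b_norm = np.linalg.norm(b)
    errors = []
    it = 0
    start = time.perf_counter()
    while it < max_iters:
        Ap = A @ p
        alpha = rs / (p @ Ap)          # step length along the search direction
        x = x + alpha * p
        r = r - alpha * Ap             # updated residual b - A x
        rs_new = r @ r
        it += 1
        errors.append(np.sqrt(rs_new))
        if errors[-1] <= tol * b_norm:
            break
        p = r + (rs_new / rs) * p      # new A-conjugate search direction
        rs = rs_new
    elapsed = time.perf_counter() - start
    info = {
        "iterations": it,
        "errors": np.array(errors),
        "iteration_time": elapsed / max(it, 1),
    }
    return x, info


rng = np.random.default_rng(1)
M = rng.standard_normal((50, 50))
A = M @ M.T + 50 * np.eye(50)          # well-conditioned SPD test matrix
b = rng.standard_normal(50)
x, info = cg_sketch(A, b, tol=1e-6)
print(info["iterations"], info["errors"][-1])
```

As in the printed output above, the CG residual norm need not decrease monotonically from one iteration to the next.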

Let’s plot the convergence criteria for these different algorithms:

[5]:
eigs = cola.eig(cola.SelfAdjoint(C), k=C.shape[0])[0]

import matplotlib.pyplot as plt
plt.rcParams['font.size'] = 20
algorithms = [('CG', info), ('Lanczos', info2), ('Arnoldi', info3)]
fig, axs = plt.subplots(1, 4, figsize=(20, 5))
# use a separate loop variable so the CG info dict is not overwritten
for i, (name, alg_info) in enumerate(algorithms):
    axs[i].plot(alg_info['errors'])
    axs[i].set_yscale('log')
    axs[i].set_title(f'{name} Convergence')
    axs[i].set_xlabel('Iteration')
    axs[i].set_ylabel('Stopping Criterion')

axs[3].plot(eigs)
axs[3].set_yscale('log')
axs[3].set_title('Spectrum of the Matrix')
axs[3].set_xlabel('Index')
axs[3].set_ylabel('Eigenvalue')

plt.tight_layout()
plt.show()
WARNING:root:Non keyed randn used. To be deprecated soon.
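[Figure: stopping criterion versus iteration for CG, Lanczos, and Arnoldi (log scale), and the eigenvalues of the matrix versus index (log scale).]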

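For CG, the stopping criterion is the residual \(\|Ax-b\|\). For Lanczos and Arnoldi, it is the latest subdiagonal entry of the tridiagonal and upper Hessenberg matrices, respectively. A small subdiagonal entry indicates that the Krylov subspace built so far is (nearly) invariant under \(A\).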
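The role of the subdiagonal entries can be seen in a small NumPy sketch of Arnoldi with modified Gram-Schmidt. The helper arnoldi_sketch and its tol threshold are hypothetical and not CoLA's code. The test matrix is non-symmetric but has only five distinct eigenvalues, so the Krylov subspace should become invariant after five steps. At that point the newest subdiagonal entry \(h_{j+1,j}\) should drop to roughly rounding level, the loop stops, and the eigenvalues of the square Hessenberg matrix \(H_k\) should reproduce the eigenvalues of \(A\).

```python
import numpy as np


def arnoldi_sketch(A, v, max_iters=50, tol=1e-6):
    """Arnoldi iteration that stops once the newest subdiagonal entry h[j+1, j]
    falls below tol (hypothetical helper). Returns Q_k, square H_k, and the history."""
    n = A.shape[0]
    Q = np.zeros((n, max_iters + 1))
    H = np.zeros((max_iters + 1, max_iters))
    errors = []
    Q[:, 0] = v / np.linalg.norm(v)
    k = max_iters
    for j in range(max_iters):
        w = A @ Q[:, j]
        for i in range(j + 1):                 # modified Gram-Schmidt
            H[i, j] = Q[:, i] @ w
            w = w - H[i, j] * Q[:, i]
        H[j + 1, j] = np.linalg.norm(w)
        errors.append(H[j + 1, j])             # stopping criterion for this step
        if H[j + 1, j] < tol:                  # Krylov subspace is (nearly) invariant
            k = j + 1
            break
        Q[:, j + 1] = w / H[j + 1, j]
    return Q[:, :k], H[:k, :k], np.array(errors)


# Non-symmetric matrix with eigenvalues 1, ..., 5, each repeated 6 times.
rng = np.random.default_rng(0)
n = 30
S = np.eye(n) + 0.3 * rng.standard_normal((n, n)) / np.sqrt(n)
A = S @ np.diag(np.repeat(np.arange(1.0, 6.0), 6)) @ np.linalg.inv(S)
Qk, Hk, errors = arnoldi_sketch(A, rng.standard_normal(n))
print(Hk.shape, errors)
print(np.sort(np.linalg.eigvals(Hk).real))
```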

One subtlety: if you jit the algorithm, or a function that contains it, the info dictionary is no longer populated. A jitted function must return arrays whose shapes are fixed at trace time, while the number of iterations is only known at run time. The progress bar is still updated, however.

[6]:
import jax

C_jax = cola.lazify(jnp.array(C.to_dense()))
y_jax = jnp.array(y)


def mycg(C, y):
    return cola.CG(tol=1e-4, pbar=True)(C, y)


solj, infoj = jax.jit(mycg)(C_jax, y_jax)
[7]:
print(infoj)
{}
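One jit-friendly pattern is shown here as a hypothetical NumPy sketch rather than CoLA's implementation. It preallocates the convergence history with a length fixed by max_iters, leaves unused slots as NaN, and returns the iteration count as a scalar. Every output shape then depends only on the input shapes and max_iters, not on how quickly the solver converges. Under jax.jit, the loop would typically be written with jax.lax.while_loop and the buffer updated with errors.at[i].set(...); that translation is not shown here.

```python
import numpy as np


def cg_fixed_shape(A, b, tol=1e-4, max_iters=50):
    """CG whose outputs have shapes fixed by max_iters (hypothetical helper)."""
    x = np.zeros_like(b)
    r = b.copy()                          # residual for the initial guess x = 0
    p = r.copy()
    rs = r @ r
    b_norm = np.linalg.norm(b)
    errors = np.full(max_iters, np.nan)   # preallocated history; unused slots stay NaN
    n_iters = 0
    while n_iters < max_iters and np.sqrt(rs) > tol * b_norm:
        Ap = A @ p
        alpha = rs / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        errors[n_iters] = np.sqrt(rs_new)
        p = r + (rs_new / rs) * p
        rs = rs_new
        n_iters += 1
    return x, errors, n_iters


rng = np.random.default_rng(2)
M = rng.standard_normal((40, 40))
A = M @ M.T + 40 * np.eye(40)
b = rng.standard_normal(40)
x, errors, n_iters = cg_fixed_shape(A, b, tol=1e-8)
print(errors.shape, n_iters)
print(errors[:n_iters])                   # only the first n_iters entries are meaningful
```

The trade-off is that the caller must slice the history with n_iters, and max_iters has to be chosen up front.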