{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "3ea9b954", "metadata": {}, "source": [ "# Accessing Lower-Level Algorithms (CG, Lanczos, Arnoldi, etc.)" ] }, { "cell_type": "code", "execution_count": 1, "id": "fee5fa3a", "metadata": {}, "outputs": [], "source": [ "# autoreload\n", "%load_ext autoreload\n", "%autoreload 2\n", "\n", "import time\n", "import torch\n", "import cola\n", "import jax.numpy as jnp\n", "import warnings\n", "\n", "warnings.filterwarnings('ignore')" ] }, { "cell_type": "markdown", "id": "4da71ef4", "metadata": {}, "source": [ "Let's construct a simple example matrix with a rapidly decaying spectrum: an RBF kernel matrix, shifted by a small multiple of the identity so that it is safely positive definite." ] }, { "cell_type": "code", "execution_count": 2, "id": "09771153", "metadata": {}, "outputs": [], "source": [ "from cola.linalg.eig.power_iteration import power_iteration\n", "\n", "# RBF kernel matrix on N evenly spaced points in [-1, 1]\n", "N = 3_000\n", "x = torch.linspace(-1, 1, N)\n", "C = cola.lazify(torch.exp(-(x[None] - x[:, None])**2 / 2))\n", "# Rough estimate of the largest eigenvalue (a loose tolerance suffices for scaling)\n", "_, eigmax, _ = power_iteration(C, tol=1e-2)\n", "# A diagonal shift of 1e-4 * eigmax keeps the condition number on the order of 1e4\n", "C = C + 1e-4 * eigmax * cola.ops.I_like(C)" ] }, { "cell_type": "markdown", "id": "29fbec4a", "metadata": {}, "source": [ "We can run CG or GMRES to perform linear solves, or Lanczos and Arnoldi to compute decompositions, which can then be used to evaluate $f(A)v$ for a vector $v$ or to estimate the extremal eigenvalues."
] }, { "cell_type": "code", "execution_count": 3, "id": "8d85a96d", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "658fbbe35b444254befbfc8c4eb1b4b0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Running body_fun: 0%| | 0/100 [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "eigs = cola.eig(cola.SelfAdjoint(C), k=C.shape[0])[0]\n", "\n", "import matplotlib.pyplot as plt\n", "plt.rcParams['font.size'] = 20\n", "algorithms = [('CG', info), ('Lanczos', info2), ('Arnoldi', info3)]\n", "fig, axs = plt.subplots(1, 4, figsize=(20, 5))\n", "# Convergence history (stopping criterion per iteration) of each algorithm\n", "for i, (name, alg_info) in enumerate(algorithms):\n", "    axs[i].plot(alg_info['errors'])\n", "    axs[i].set_yscale('log')\n", "    axs[i].set_title(f'{name} Convergence')\n", "    axs[i].set_xlabel('Iteration')\n", "    axs[i].set_ylabel('Stopping Criterion')\n", "\n", "# Full spectrum of C for reference\n", "axs[3].plot(eigs)\n", "axs[3].set_yscale('log')\n", "axs[3].set_title('Spectrum of the Matrix')\n", "axs[3].set_xlabel('Index')\n", "axs[3].set_ylabel('Eigenvalue')\n", "\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "46452352", "metadata": {}, "source": [ "For CG, the stopping criterion is the residual norm $\\|Ax-b\\|$, whereas for Lanczos and Arnoldi it is the subdiagonal entries of the tridiagonal and upper Hessenberg matrices, respectively." ] }, { "cell_type": "markdown", "id": "e445d244", "metadata": {}, "source": [ "One caveat: if you `jit` the algorithm, or a function that calls it, the info dict will no longer be populated, because jitted functions must always return arrays of a fixed shape (the progress bar is still updated, however)." ] }, { "cell_type": "code", "execution_count": 6, "id": "9c74d2a9", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c4d513287e19435fa117a0320fcda7c4", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Running body_fun: 0%| | 0/100 [00:00