{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# jit, vmap, grad, and pytrees" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "All `LinearOperator` objects are native JAX and PyTorch pytrees.\n", "\n", "This means that we can `vmap` over them, `jit` functions that take them as arguments, and use them with other transformations that operate on pytrees." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Example: Tree Map\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Jit example (in JAX)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As an example, let's jit a function that involves matrix square roots." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from jax import jit\n", "import numpy as np\n", "import jax.numpy as jnp\n", "import jax\n", "\n", "# run on the CPU\n", "jax.config.update('jax_platform_name', 'cpu')\n", "import cola\n", "\n", "# K = kron(B, D): a 6x6 positive-definite operator built from a 2x2 B and a 3x3 diagonal D\n", "A = jnp.array(np.random.randn(2, 2))\n", "B = cola.SelfAdjoint(cola.lazify(A.T @ A + 1e-4 * jnp.eye(2)))\n", "D = cola.SelfAdjoint(cola.ops.Diagonal(jnp.array([3., 0.2, 1.])))\n", "K = cola.kron(B, D)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's verify that CoLA computes the square root of this operator correctly: applying $K^{1/2}$ twice to a vector $v$ should reproduce $Kv$." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "error: 6.3414245e-07\n" ] } ], "source": [ "from cola import Auto\n", "\n", "v = jnp.array(np.random.randn(6))\n", "K_half_v = cola.sqrt(K) @ v  # K^{1/2} v\n", "Kv = cola.sqrt(K) @ K_half_v  # K^{1/2} (K^{1/2} v), should equal K v\n", "print(\"error:\", jnp.linalg.norm(Kv - K @ v))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's jit a function that takes a `LinearOperator` as an argument. Since $\\sqrt{4K} = 2\\sqrt{K}$, the two printed vectors should agree up to the solver tolerance:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[ 0.13759953+0.j -0.94336665+0.j 0.3768798 +0.j -0.16725242+0.j\n", " -0.26825097+0.j 0.15196525+0.j]\n", "[ 0.13759911+0.j -0.94336635+0.j 0.3768798 +0.j -0.16725181+0.j\n", " -0.26825133+0.j 0.15196525+0.j]\n" ] } ], "source": [ "@jit\n", "def sqrt_mvm(K, v):\n", "    return cola.sqrt(K, Auto(tol=1e-4)) @ v\n", "\n", "print(sqrt_mvm(K, v))\n", "print(sqrt_mvm(4 * K, v) / 2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Batched LinearOperator operations using vmap (in PyTorch)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's consider a function that constructs some linear operators, and a separate function that applies some transformations to them." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name":
"stdout", "output_type": "stream", "text": [ "tensor([[4.3080, 2.7321, 1.8498, 0.0000, 0.0000],\n", " [2.7321, 5.7340, 3.7580, 0.0000, 0.0000],\n", " [1.8498, 3.7580, 6.7027, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 4.3080, 2.7321],\n", " [0.0000, 0.0000, 0.0000, 2.7321, 5.7340]])\n", "tensor([[4.3080, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 5.7340, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 6.7027, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 4.3080, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 5.7340]])\n" ] } ], "source": [ "import cola\n", "import torch\n", "import numpy as np\n", "\n", "def construct_complicated_linops(X):\n", "    X = cola.lazify(X)\n", "    Y = X @ X.T\n", "    # adding the identity makes Y positive definite\n", "    Y = cola.PSD(Y + cola.ops.I_like(Y))\n", "    D = cola.PSD(cola.ops.Diagonal(torch.linspace(0.1, 1, 2)))\n", "    # block-diagonal operator diag(Y, Y, D)\n", "    W = cola.ops.BlockDiag(Y, D, multiplicities=[2, 1])\n", "    diag_W = cola.diag(W)\n", "    # return W along with its diagonal, wrapped as a Diagonal operator\n", "    return W, cola.PSD(cola.ops.Diagonal(diag_W))\n", "\n", "W, diag_W = construct_complicated_linops(torch.randn(3, 3))\n", "print(W[:5, :5].to_dense())\n", "print(diag_W[:5, :5].to_dense())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For this example, consider a function that uses the diagonal of `W` as a symmetric (Jacobi) preconditioner, applied explicitly rather than passed as an argument to `cola.inv`: writing $D$ for the diagonal of `W` and $P = D^{-1/2}$, it computes $W^{-1}v = P(PWP)^{-1}Pv$."
] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "from cola import Auto\n", "\n", "def perform_operations(W, D, v):\n", "    P = cola.pow(D, -0.5)  # P = D^{-1/2}\n", "    # W^{-1} v = P (P W P)^{-1} P v\n", "    y = P @ cola.inv(P @ W @ P, Auto(tol=1e-4)) @ P @ v\n", "    return y" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now suppose that we want to perform this operation over a batch of LinearOperators, each with different data.\n", "\n", "First, we can vmap over the function that constructs the LinearOperators:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "from torch.func import vmap\n", "\n", "# a batch of 3 random 5x5 inputs\n", "bW, bD = vmap(construct_complicated_linops)(torch.randn(3, 5, 5))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Notice that the batched objects have the same types and shapes as the operators returned by a single call on one 5x5 input:" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(12, 12) \n", "(12, 12) \n" ] } ], "source": [ "print(bW.shape, type(bW))\n", "print(bD.shape, type(bD))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "However, the data that makes up these objects now has a leading batch dimension of size 3:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([3, 12])" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "bD.diag.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In general, these batched objects should only be used as inputs to a function that is itself vmapped over its `LinearOperator` arguments, as with `perform_operations` below."
] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([3, 12])\n" ] } ], "source": [ "# one random right-hand side per batch element\n", "bv = torch.randn(3, bW.shape[0])\n", "all_outs = vmap(perform_operations)(bW, bD, bv)\n", "print(all_outs.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "🚧 Note: not all LinearOperators support vmap with the PyTorch backend. 🚧\n", "\n", "For example, the Kronecker product (`cola.kron`) does not:" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "raised exception: Batching rule not implemented for aten::moveaxis.int; the fallback path doesn't work on out= or view ops.\n" ] } ], "source": [ "def get_entries(M):\n", "    return M[:5, :5].to_dense()\n", "\n", "try:\n", "    vmap(get_entries)(vmap(cola.kron)(bW, bD))\n", "except RuntimeError as e:\n", "    print(\"raised exception:\", e)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" } }, "nbformat": 4, "nbformat_minor": 4 }