{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "603d1a3f", "metadata": {}, "source": [ "# Implementing new Linear Operators and Dispatch Rules\n", "\n", "Implementing a new linear operator in CoLA requires specifying a `shape`, a `dtype`, and a `matmat` function.\n", "As with SciPy's `LinearOperator`, there are two ways of doing so." ] }, { "attachments": {}, "cell_type": "markdown", "id": "245a54a8", "metadata": {}, "source": [ "### Calling LinearOperator as a constructor\n", "\n", "For a one-off, a quick and dirty approach is to use the `LinearOperator` constructor directly. Let's assume we have a matrix-vector multiply that is problem-specific and not very generalizable.\n", "\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "f34717b1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[ 0., 0., 1., -2., 0., 0., 0., 0.],\n", " [ 0., 0., 1., 1., -3., 0., 0., 0.],\n", " [ 0., 0., 1., 1., 0., -3., 0., 0.],\n", " [ 0., 0., 1., 1., 0., 0., -3., 0.],\n", " [ 0., 0., 1., 1., 0., 0., 0., -3.]])\n" ] } ], "source": [ "import cola\n", "import torch\n", "\n", "def weird_matmat(x):\n", "    # x has shape (8, d): one row per column of the (5, 8) operator\n", "    return (x[2] + x[3]) * torch.ones(5, 1) - 3 * x[3:]\n", "\n", "shape = (5, 8)\n", "A = cola.ops.LinearOperator(torch.float32, shape, matmat=weird_matmat)\n", "print(A.to_dense())" ] }, { "attachments": {}, "cell_type": "markdown", "id": "2c411382", "metadata": {}, "source": [ "### Subclassing LinearOperator\n", "\n", "For a more extensible approach, and one that can leverage dispatch rules, we recommend subclassing `LinearOperator`: define an `__init__` that calls `super().__init__(dtype, shape)` and define a new `_matmat` method.\n", "\n", "For example, let's define a diagonal `LinearOperator` below:" ] }, { "cell_type": "code", "execution_count": 2, "id": "1586b1fe", "metadata": {}, "outputs": [], "source": [ "class MyDiagonal(cola.ops.LinearOperator):\n", "    \"\"\" Diagonal LinearOperator. O(n) time and space matmuls\"\"\"\n", "    def __init__(self, diag):\n", "        super().__init__(dtype=diag.dtype, shape=(len(diag), ) * 2)\n", "        self.diag = diag\n", "\n", "    def _matmat(self, X):\n", "        return self.diag[:, None] * X\n", "\n", "    def __str__(self):\n", "        return f\"MyDiagonal({self.diag})\"" ] }, { "cell_type": "code", "execution_count": 3, "id": "384a1883", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[1 0 0 0]\n", " [0 2 0 0]\n", " [0 0 3 0]\n", " [0 0 0 4]]\n" ] } ], "source": [ "import jax.numpy as jnp\n", "\n", "A = MyDiagonal(jnp.arange(1, 5))\n", "print(A.to_dense())" ] }, { "attachments": {}, "cell_type": "markdown", "id": "5c1f901b", "metadata": {}, "source": [ "# Defining New Dispatch Rules" ] }, { "attachments": {}, "cell_type": "markdown", "id": "85a0bb6f", "metadata": {}, "source": [ "Implementing new dispatch rules for existing functions is easy: wrap the new function with the `cola.dispatch` decorator and define the functionality for a given `LinearOperator` and a given `Algorithm`. If you don't have a specific algorithm to use (as below), simply ensure that `alg: Algorithm` is part of the function's signature.\n", "\n", "Here we extend `inv` for the `MyDiagonal` object."
] }, { "cell_type": "code", "execution_count": 4, "id": "87162025", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Called my inverse\n", "MyDiagonal(tensor([1.0000e+00, 5.0000e-01, 3.3333e-01, ..., 2.0000e-06, 2.0000e-06,\n", " 2.0000e-06]))\n" ] } ], "source": [ "from cola.linalg import inv\n", "from cola.linalg.algorithm_base import Algorithm\n", "\n", "@cola.dispatch\n", "def inv(A: MyDiagonal, alg: Algorithm):\n", "    print(\"Called my inverse\")\n", "    return MyDiagonal(1 / A.diag)\n", "\n", "\n", "A = MyDiagonal(torch.arange(1, 500000))\n", "invA = inv(A)\n", "print(invA)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "8dd45235", "metadata": {}, "source": [ "You can also override existing functionality.\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "33b938b8", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "I overrode the dense inverse\n", "tensor([[-2.0000, 1.0000],\n", " [ 1.5000, -0.5000]])\n" ] } ], "source": [ "@cola.dispatch\n", "def inv(A: cola.ops.Dense, alg: Algorithm):\n", "    print(\"I overrode the dense inverse\")\n", "    return cola.ops.Dense(torch.linalg.inv(A.to_dense()))\n", "\n", "\n", "A = cola.ops.Dense(torch.arange(1, 5).reshape(2, 2).float())\n", "invA = inv(A)\n", "print(invA.to_dense())" ] }, { "attachments": {}, "cell_type": "markdown", "id": "7d035e52", "metadata": {}, "source": [ "We can also implement entirely new linear algebra functions on existing objects; just make sure to provide a base case that dispatches on `LinearOperator`.\n", "\n", "For example, let's define a `rowsum` function that sums the entries in each row of a `LinearOperator`."
] }, { "cell_type": "code", "execution_count": 6, "id": "dc00f56e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "dispatched on MyDiagonal\n", "tensor([0, 1, 2, 3, 4])\n" ] } ], "source": [ "@cola.dispatch\n", "def rowsum(A: cola.ops.LinearOperator):\n", "    print(\"dispatched base case\")\n", "    # Multiply by a ones vector whose length is the number of columns\n", "    return A @ A.xnp.ones(A.shape[-1:], dtype=A.dtype, device=A.device)\n", "\n", "\n", "@cola.dispatch\n", "def rowsum(A: MyDiagonal):\n", "    print(\"dispatched on MyDiagonal\")\n", "    return A.diag\n", "\n", "\n", "A = MyDiagonal(torch.arange(5))\n", "print(rowsum(A))" ] }, { "cell_type": "code", "execution_count": 7, "id": "05572736", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "dispatched base case\n", "tensor([1, 5])\n" ] } ], "source": [ "print(rowsum(cola.ops.Dense(torch.arange(4).reshape(2, 2))))" ] }, { "attachments": {}, "cell_type": "markdown", "id": "1e55bbf6", "metadata": {}, "source": [ "TODO: Add an example of parametric dispatch for the Woodbury formula" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }