
Implementing new Linear Operators and Dispatch Rules

Implementing new linear operators in CoLA requires specifying a shape, a dtype, and a matmat function. Like with SciPy's LinearOperator, there are two ways of doing so.

Calling LinearOperator as a constructor

For a one-off, a quick and dirty approach is to use the LinearOperator constructor directly. Let's assume we have some matrix-vector multiply which is problem-specific and not very generalizable.

[1]:
import cola
import torch

def weird_matmat(x):
    # x has shape (8, d); the result has shape (5, d)
    return (x[2] + x[3]) * torch.ones(5, 1) - 3 * x[3:]

shape = (5, 8)
A = cola.ops.LinearOperator(torch.float32, shape, matmat=weird_matmat)
print(A.to_dense())
tensor([[ 0.,  0.,  1., -2.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  1.,  1., -3.,  0.,  0.,  0.],
        [ 0.,  0.,  1.,  1.,  0., -3.,  0.,  0.],
        [ 0.,  0.,  1.,  1.,  0.,  0., -3.,  0.],
        [ 0.,  0.,  1.,  1.,  0.,  0.,  0., -3.]])
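Once wrapped, the operator participates in CoLA's lazy operator algebra like any other LinearOperator. As a minimal sketch (the printed values depend on the random vector):

v = torch.randn(8)
print(A @ v)        # matrix-vector product via our custom matmat
print((2 * A) @ v)  # scalar multiples compose lazily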

Subclassing LinearOperator

For a more extensible approach, and one that can leverage dispatch rules, we recommend subclassing LinearOperator: define an __init__ that calls super().__init__(dtype, shape) and implement a new _matmat method.

For example, let's define a diagonal LinearOperator below:

[2]:
class MyDiagonal(cola.ops.LinearOperator):
    """ Diagonal LinearOperator. O(n) time and space matmuls"""
    def __init__(self, diag):
        super().__init__(dtype=diag.dtype, shape=(len(diag), ) * 2)
        self.diag = diag

    def _matmat(self, X):
        return self.diag[:, None] * X

    def __str__(self):
        return f"MyDiagonal({self.diag})"
[3]:
import jax.numpy as jnp

A = MyDiagonal(jnp.arange(1,5))
print(A.to_dense())
[[1 0 0 0]
 [0 2 0 0]
 [0 0 3 0]
 [0 0 0 4]]
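Because MyDiagonal is a full-fledged LinearOperator, it composes lazily with other operators. A minimal sketch:

D1 = MyDiagonal(jnp.arange(1., 5.))  # diagonal [1., 2., 3., 4.]
D2 = MyDiagonal(jnp.ones(4))         # identity-like diagonal
C = D1 + D2                          # lazy sum of the two operators
print(C @ jnp.ones(4))               # expected [2., 3., 4., 5.]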

Defining New Dispatch Rules

Implementing new dispatch rules for existing functions is easy: simply wrap the new function with the cola.dispatch decorator and define the functionality for a given LinearOperator and a given Algorithm. If you don't have a specific algorithm to use (as below), simply ensure that alg: Algorithm is part of the function's signature.

Here we will extend inv for the MyDiagonal object.

[4]:
from cola.linalg import inv
from cola.linalg.algorithm_base import Algorithm

@cola.dispatch
def inv(A: MyDiagonal, alg: Algorithm):
    print("Called my inverse")
    return MyDiagonal(1 / A.diag)


A = MyDiagonal(torch.arange(1, 500000))
invA = inv(A)
print(invA)
Called my inverse
MyDiagonal(tensor([1.0000e+00, 5.0000e-01, 3.3333e-01,  ..., 2.0000e-06, 2.0000e-06,
        2.0000e-06]))
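As a quick sanity check (a sketch), the lazy inverse solves a diagonal system in O(n):

D = MyDiagonal(torch.arange(1., 5.))
x = inv(D) @ torch.ones(4)  # dispatches to our rule, then applies 1/diag
print(x)                    # expected [1.0000, 0.5000, 0.3333, 0.2500]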

You can also override existing functionality.

[5]:
@cola.dispatch
def inv(A: cola.ops.Dense, alg: Algorithm):
    print("I overrided the dense inverse")
    return cola.ops.Dense(torch.linalg.inv(A.to_dense()))


A = cola.ops.Dense(torch.arange(1, 5).reshape(2, 2).float())
invA = inv(A)
print(invA.to_dense())
I overrode the dense inverse
tensor([[-2.0000,  1.0000],
        [ 1.5000, -0.5000]])

We can also implement entirely new linear algebra functions on existing objects; just make sure to define a base case.

For example, let's define a rowsum function that sums the rows of a LinearOperator.

[6]:
@cola.dispatch
def rowsum(A: cola.ops.LinearOperator):
    print("dispatched base case")
    # multiplying by a ones vector of length A.shape[-1] sums each row
    return A @ A.xnp.ones(A.shape[-1:], dtype=A.dtype, device=A.device)


@cola.dispatch
def rowsum(A: MyDiagonal):
    print("dispatched on MyDiagonal")
    return A.diag


A = MyDiagonal(torch.arange(5))
print(rowsum(A))
dispatched on MyDiagonal
tensor([0, 1, 2, 3, 4])
[7]:
print(rowsum(cola.ops.Dense(torch.arange(4).reshape(2, 2))))
dispatched base case
tensor([1, 5])
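Dispatch picks the most specific rule available, so further specializations slot in alongside the base case. For instance, a hypothetical Dense rule could bypass the matvec entirely:

@cola.dispatch
def rowsum(A: cola.ops.Dense):
    print("dispatched on Dense")
    return A.to_dense().sum(-1)  # direct row reduction on the dense array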

TODO: Add an example of parametric dispatch for the Woodbury formula.