Implementing new Linear Operators and Dispatch Rules
Implementing a new linear operator in CoLA requires specifying a shape, a dtype, and a matmat function. As with SciPy's LinearOperator, there are two ways of doing so.
Calling LinearOperator as a constructor
For a one-off, a quick-and-dirty approach is to use the LinearOperator constructor directly. Let's assume we have a matrix-vector multiply that is problem-specific and not very generalizable.
[1]:
import cola
import torch

def weird_matmat(x):
    # x has shape (8, d): one row per column of the (5, 8) operator
    return (x[2] + x[3]) * torch.ones(5, 1) - 3 * x[3:]

shape = (5, 8)
A = cola.ops.LinearOperator(torch.float32, shape, matmat=weird_matmat)
print(A.to_dense())
tensor([[ 0., 0., 1., -2., 0., 0., 0., 0.],
[ 0., 0., 1., 1., -3., 0., 0., 0.],
[ 0., 0., 1., 1., 0., -3., 0., 0.],
[ 0., 0., 1., 1., 0., 0., -3., 0.],
[ 0., 0., 1., 1., 0., 0., 0., -3.]])
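The dense matrix printed above can be checked by hand, since applying a matmat to the identity matrix returns the operator's dense form. The sketch below redoes the computation with plain Python lists so that it runs without CoLA or PyTorch. The names `weird_matmat_lists` and `dense_from_matmat` are hypothetical helpers introduced for illustration; this is not a claim about how `to_dense` is implemented in CoLA.

```python
def weird_matmat_lists(X):
    """List-based analogue of weird_matmat; X is a list of 8 rows of length d."""
    d = len(X[0])
    # (x[2] + x[3]) repeated over the 5 output rows, minus 3 * x[3:]
    return [[X[2][j] + X[3][j] - 3 * X[3 + i][j] for j in range(d)]
            for i in range(5)]


def dense_from_matmat(matmat, n_cols):
    """Recover the dense matrix by applying matmat to the identity (assumed helper)."""
    identity = [[1.0 if i == j else 0.0 for j in range(n_cols)]
                for i in range(n_cols)]
    return matmat(identity)


dense = dense_from_matmat(weird_matmat_lists, 8)
for row in dense:
    print(row)
```

Each printed row should agree with the corresponding row of the tensor shown above.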
Subclassing LinearOperator
For a more extensible approach, and one that can leverage dispatch rules, we recommend subclassing LinearOperator: define an __init__ that calls super().__init__(dtype, shape), and define a new _matmat method.
For example, let's define a Diagonal LinearOperator below:
[2]:
class MyDiagonal(cola.ops.LinearOperator):
    """ Diagonal LinearOperator. O(n) time and space matmuls"""
    def __init__(self, diag):
        super().__init__(dtype=diag.dtype, shape=(len(diag), ) * 2)
        self.diag = diag

    def _matmat(self, X):
        return self.diag[:, None] * X

    def __str__(self):
        return f"MyDiagonal({self.diag})"
[3]:
import jax.numpy as jnp
A = MyDiagonal(jnp.arange(1,5))
print(A.to_dense())
[[1 0 0 0]
[0 2 0 0]
[0 0 3 0]
[0 0 0 4]]
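To see why overriding a single `_matmat` method is enough, here is a minimal, library-free sketch of the subclassing pattern. `TinyLinearOperator` and `TinyDiagonal` are hypothetical stand-ins written for this illustration (they are not CoLA classes), and matrices are plain lists of rows.

```python
class TinyLinearOperator:
    """Hypothetical stand-in for a linear-operator base class (not CoLA's)."""

    def __init__(self, dtype, shape):
        self.dtype = dtype
        self.shape = shape

    def _matmat(self, X):
        raise NotImplementedError("subclasses must define _matmat")

    def to_dense(self):
        # Multiplying by the identity recovers every entry of the operator.
        n = self.shape[1]
        identity = [[self.dtype(i == j) for j in range(n)] for i in range(n)]
        return self._matmat(identity)

    def __matmul__(self, X):
        return self._matmat(X)


class TinyDiagonal(TinyLinearOperator):
    """Diagonal operator: stores n numbers instead of n * n."""

    def __init__(self, diag):
        super().__init__(dtype=type(diag[0]), shape=(len(diag), len(diag)))
        self.diag = diag

    def _matmat(self, X):
        # Scale row i of X by diag[i], mirroring self.diag[:, None] * X.
        return [[d * x for x in row] for d, row in zip(self.diag, X)]


A = TinyDiagonal([1, 2, 3, 4])
print(A.to_dense())  # [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]]
print(A @ [[1, 1], [1, 1], [1, 1], [1, 1]])  # [[1, 1], [2, 2], [3, 3], [4, 4]]
```

The base class derives everything else from `_matmat`, so a subclass only has to supply its own fast multiply.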
Defining New Dispatch Rules
Implementing a new dispatch rule for an existing function is straightforward: wrap the new function with the cola.dispatch decorator and define the functionality for a given LinearOperator and a given Algorithm. If you don't have a specific algorithm to use (as below), simply ensure that alg: Algorithm is part of the function's signature.
Here we will extend inv (the inverse) for the MyDiagonal object.
[4]:
from cola.linalg import inv
from cola.linalg.algorithm_base import Algorithm

@cola.dispatch
def inv(A: MyDiagonal, alg: Algorithm):
    print("Called my inverse")
    return MyDiagonal(1 / A.diag)

A = MyDiagonal(torch.arange(1, 500000))
invA = inv(A)
print(invA)
Called my inverse
MyDiagonal(tensor([1.0000e+00, 5.0000e-01, 3.3333e-01, ..., 2.0000e-06, 2.0000e-06,
2.0000e-06]))
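The rule above relies on the fact that the inverse of a diagonal matrix is the diagonal matrix of reciprocals, which costs O(n) rather than the O(n^3) of a general dense inverse. A small sketch with exact fractions, assuming all diagonal entries are nonzero; `diagonal_inverse` is a hypothetical helper, not part of CoLA.

```python
from fractions import Fraction


def diagonal_inverse(diag):
    """Inverse of diag(d_1..d_n) is diag(1/d_1..1/d_n); needs every d_i != 0."""
    if any(d == 0 for d in diag):
        raise ZeroDivisionError("a diagonal matrix with a zero entry is singular")
    return [1 / d for d in diag]


diag = [Fraction(k) for k in range(1, 6)]
inv_diag = diagonal_inverse(diag)

# D @ D^{-1} is again diagonal, with entries d_i * (1 / d_i).
print([d * i for d, i in zip(diag, inv_diag)])
```

Every entry of the product is exactly one, which confirms that the reciprocal diagonal is the inverse.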
You can also override existing functionality.
[5]:
@cola.dispatch
def inv(A: cola.ops.Dense, alg: Algorithm):
    print("I overrided the dense inverse")
    return cola.ops.Dense(torch.linalg.inv(A.to_dense()))

A = cola.ops.Dense(torch.arange(1, 5).reshape(2, 2).float())
invA = inv(A)
print(invA.to_dense())
I overrided the dense inverse
tensor([[-2.0000, 1.0000],
[ 1.5000, -0.5000]])
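As a sanity check on the printed result, the inverse of the 2 x 2 matrix [[1, 2], [3, 4]] can be computed with the closed-form adjugate formula. This is a plain-Python sketch; `inv_2x2` is a hypothetical helper and has nothing to do with how `torch.linalg.inv` computes its result.

```python
def inv_2x2(m):
    """Closed-form inverse of a 2 x 2 matrix given as [[a, b], [c, d]]."""
    (a, b), (c, d) = m
    det = a * d - b * c
    if det == 0:
        raise ValueError("matrix is singular")
    return [[d / det, -b / det], [-c / det, a / det]]


print(inv_2x2([[1.0, 2.0], [3.0, 4.0]]))  # [[-2.0, 1.0], [1.5, -0.5]]
```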
We can also implement entirely new linear algebra functions on existing objects; just make sure to provide a base case that works for any LinearOperator.
For example, let's define a rowsum function that sums each row of a LinearOperator.
[6]:
@cola.dispatch
def rowsum(A: cola.ops.LinearOperator):
    print("dispatched base case")
    # The ones vector needs one entry per column so that A @ ones is defined
    return A @ A.xnp.ones(A.shape[-1:], dtype=A.dtype, device=A.device)

@cola.dispatch
def rowsum(A: MyDiagonal):
    print("dispatched on MyDiagonal")
    return A.diag

A = MyDiagonal(torch.arange(5))
print(rowsum(A))
dispatched on MyDiagonal
tensor([0, 1, 2, 3, 4])
[7]:
print(rowsum(cola.ops.Dense(torch.arange(4).reshape(2, 2))))
dispatched base case
tensor([1, 5])
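The base-case-plus-specialisation pattern can be sketched without CoLA by using `functools.singledispatch` from the standard library as a stand-in for `cola.dispatch`. The two differ: `singledispatch` selects a rule from the type of the first argument only, whereas the CoLA examples above dispatch on both the operator and the algorithm. The classes `Operator` and `DiagonalOperator` are hypothetical and exist only for this illustration.

```python
from functools import singledispatch


class Operator:
    """Hypothetical dense operator stored as a list of rows."""

    def __init__(self, rows):
        self.rows = rows
        self.shape = (len(rows), len(rows[0]))

    def matvec(self, v):
        return [sum(a * x for a, x in zip(row, v)) for row in self.rows]


class DiagonalOperator(Operator):
    """Hypothetical diagonal operator that also keeps its diagonal."""

    def __init__(self, diag):
        n = len(diag)
        super().__init__([[diag[i] if i == j else 0 for j in range(n)]
                          for i in range(n)])
        self.diag = diag


@singledispatch
def rowsum(A):
    raise NotImplementedError(f"no rowsum rule for {type(A).__name__}")


@rowsum.register(Operator)
def _(A):
    # Base case: multiply by a ones vector with one entry per column.
    return A.matvec([1] * A.shape[1])


@rowsum.register(DiagonalOperator)
def _(A):
    # Specialised rule: the row sums of a diagonal matrix are its diagonal.
    return list(A.diag)


print(rowsum(Operator([[0, 1, 2], [3, 4, 5]])))   # [3, 12]
print(rowsum(DiagonalOperator([0, 1, 2, 3, 4])))  # [0, 1, 2, 3, 4]
```

The most specific registered rule wins, so the diagonal operator never builds the ones vector, while any other operator falls back to the base case.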
TODO: Add an example of parametric dispatch for the Woodbury formula