Philosophy
The framework of automatic differentiation has revolutionized machine learning. Although the rules that govern derivatives have long been known, automatically computing derivatives was a nontrivial process that required (1) efficient implementations of base-case primitive derivatives, (2) software abstractions (autograd and computation graphs) to compose these primitives into complex computations, and (3) a mechanism for users to modify or extend compositional rules to new functions.
Once libraries such as PyTorch, Chainer, TensorFlow, JAX, and others figured out the correct abstractions, the impact was enormous. Efforts that previously went into deriving and implementing gradients could be repurposed into developing new models.
In CoLA, we automate another notorious bottleneck for ML methods: performing large-scale linear algebra (e.g. matrix solves, eigenvalue problems, nullspace computations). These ubiquitous operations are at the heart of principal component analysis, Gaussian processes, normalizing flows, equivariant neural networks, and many other applications. As with automatic differentiation, structure-aware linear algebra is ripe for automation. We introduce a general numerical framework that dramatically simplifies implementation efforts while achieving a high degree of computational efficiency.
In code, we represent structured matrices as LinearOperator objects, which adhere to the same API as standard dense matrices. For example, a user can call inverse or eig on any LinearOperator, and under the hood our framework derives a computationally efficient algorithm built from roughly 70 compositional dispatch rules. If little is known about the LinearOperator, the derived algorithm reverts to a general-purpose base case (e.g. Gaussian elimination or GMRES for linear solves). Conversely, if the LinearOperator is known to be, for example, the Kronecker product of a lower triangular matrix and a positive definite Toeplitz matrix, the derived algorithm uses specialized algorithms for Kronecker, triangular, and positive definite matrices. Through this compositional pattern matching, our framework can match or outperform special-purpose implementations across numerous applications despite relying on only 25 base LinearOperator types.
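To make the dispatch idea concrete, here is a minimal pure-Python sketch. It is not CoLA's implementation or API: the classes `LinearOperator`, `Dense`, `Diagonal`, `Kronecker` and the functions `solve` and `gaussian_elimination` are hypothetical stand-ins written for this illustration, and Python's `functools.singledispatch` plays the role of the dispatch mechanism. Operators are square and known only through `matvec`. A `Kronecker` solve is reduced to solves with its two factors, each of which is dispatched again. An operator with no registered rule falls back to dense Gaussian elimination.

```python
from functools import singledispatch


class LinearOperator:
    """Hypothetical minimal base class: an n x n matrix known only through matvec."""

    def __init__(self, n):
        self.n = n

    def matvec(self, x):
        raise NotImplementedError

    def __matmul__(self, x):
        # Mimic the dense-matrix API so that `op @ x` works.
        return self.matvec(x)

    def to_dense(self):
        # Materialize column j as op @ e_j (n matvecs in total).
        cols = [self.matvec([float(i == j) for i in range(self.n)]) for j in range(self.n)]
        return [[cols[j][i] for j in range(self.n)] for i in range(self.n)]


class Dense(LinearOperator):
    def __init__(self, A):
        super().__init__(len(A))
        self.A = A

    def matvec(self, x):
        return [sum(a * xi for a, xi in zip(row, x)) for row in self.A]


class Diagonal(LinearOperator):
    def __init__(self, d):
        super().__init__(len(d))
        self.d = d

    def matvec(self, x):
        return [di * xi for di, xi in zip(self.d, x)]


class Kronecker(LinearOperator):
    """A ⊗ B for square operators A (p x p) and B (q x q); vectors use row-major vec."""

    def __init__(self, A, B):
        super().__init__(A.n * B.n)
        self.A, self.B = A, B

    def matvec(self, x):
        # (A ⊗ B) vec(X) = vec(A X B^T), with X the p x q reshape of x.
        p, q = self.A.n, self.B.n
        XBt = [self.B @ x[i * q:(i + 1) * q] for i in range(p)]            # rows of X B^T
        cols = [self.A @ [XBt[i][k] for i in range(p)] for k in range(q)]  # columns of A X B^T
        return [cols[k][i] for i in range(p) for k in range(q)]


def gaussian_elimination(A, b):
    """General-purpose O(n^3) fallback: Gaussian elimination with partial pivoting."""
    n = len(A)
    M = [list(row) + [bi] for row, bi in zip(A, b)]  # augmented copy [A | b]
    for c in range(n):
        piv = max(range(c, n), key=lambda r: abs(M[r][c]))
        M[c], M[piv] = M[piv], M[c]
        for r in range(c + 1, n):
            f = M[r][c] / M[c][c]
            for k in range(c, n + 1):
                M[r][k] -= f * M[c][k]
    x = [0.0] * n
    for r in reversed(range(n)):
        x[r] = (M[r][n] - sum(M[r][k] * x[k] for k in range(r + 1, n))) / M[r][r]
    return x


@singledispatch
def solve(op, b):
    # Base case: no exploitable structure is known, so materialize and eliminate.
    return gaussian_elimination(op.to_dense(), b)


@solve.register(Diagonal)
def _(op, b):
    # Diagonal rule: O(n) elementwise division.
    return [bi / di for bi, di in zip(b, op.d)]


@solve.register(Kronecker)
def _(op, b):
    # Kronecker rule: A X B^T = Bm  =>  X = A^{-1} Bm B^{-T}, with Bm the p x q reshape of b.
    # Both factor solves dispatch again, so each factor's own structure is used.
    p, q = op.A.n, op.B.n
    cols = [solve(op.A, [b[i * q + k] for i in range(p)]) for k in range(q)]  # Z = A^{-1} Bm
    rows = [solve(op.B, [cols[k][i] for k in range(q)]) for i in range(p)]    # X = Z B^{-T}
    return [v for row in rows for v in row]


K = Kronecker(Dense([[2.0, 1.0], [1.0, 3.0]]), Diagonal([4.0, 5.0]))
b = K @ [1.0, 2.0, 3.0, 4.0]  # [20.0, 40.0, 40.0, 70.0]
x = solve(K, b)  # Kronecker rule -> dense fallback for A, diagonal rule for B
```

The Kronecker rule calls `solve` recursively on its factors. As a result, a nested operator such as a `Kronecker` of a `Diagonal` and another `Kronecker` is also solved structurally without any new code. This recursion is the compositional pattern described above, in miniature.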
The table below marks with a blue square each dispatch rule present in our framework, across linear algebra operations (inverse, eig, diagonal, transpose, exp, determinant) and LinearOperator types. Red squares mark dispatch rules that can be derived from a combination of other rules.
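The following worked derivation illustrates how one rule can follow from others. It does not claim to reproduce which cells the table colors. The determinant rule for a Kronecker product follows directly from the eigenvalue rule. Let A be n × n with eigenvalues λ_1, ..., λ_n, and let B be m × m with eigenvalues μ_1, ..., μ_m. Then A ⊗ B has the nm eigenvalues λ_i μ_j, counted with multiplicity, so:

```latex
\det(A \otimes B)
  = \prod_{i=1}^{n} \prod_{j=1}^{m} \lambda_i \mu_j
  = \Bigl(\prod_{i=1}^{n} \lambda_i\Bigr)^{m} \Bigl(\prod_{j=1}^{m} \mu_j\Bigr)^{n}
  = \det(A)^{m} \, \det(B)^{n}
```

With dense factorizations, the derived rule needs only two determinants, of sizes n and m, at a cost of roughly O(n^3 + m^3). A dense determinant of the nm × nm product would instead cost O((nm)^3).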