Compositional Linear Algebra (CoLA)
CoLA is a multiplatform framework that enables fast linear algebra operations with support for GPU acceleration and autograd. Want to efficiently compute eigenvalues, a matrix inverse, a log determinant, or some other matrix operation, in a framework that supports both JAX and PyTorch? If your matrix has structure – say it has sparsity or can be decomposed as a Kronecker product – then even better. Read on to find out more.
Many areas of machine learning (ML) and science involve large-scale linear algebra problems, such as performing eigendecompositions, solving linear systems, computing matrix exponentials, and doing trace estimation. The linear ops involved often have Kronecker, convolutional, block diagonal, sum, or product structure. Yet to exploit this structure, that is, to use specialized algorithms that run faster than general-purpose ones, a user must implement these efficient routines by hand on a case-by-case basis and know which algorithms exist for which structures. This process leads to a notorious implementation bottleneck!
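To make the cost of ignoring structure concrete, here is a minimal NumPy sketch (not CoLA code) of one such hand-written routine: a matrix-vector product with a Kronecker product A ⊗ B that never forms the full matrix. It relies on the identity (A ⊗ B) vec(X) = vec(A X Bᵀ) under row-major flattening, which for square factors reduces the cost from O((n_A n_B)²) to O(n_A n_B (n_A + n_B)). The helper name `kron_matvec` is hypothetical and used only for illustration.

```python
import numpy as np


def kron_matvec(A, B, x):
    """Compute (A ⊗ B) @ x without materializing the Kronecker product.

    With row-major (C-order) flattening, reshaping x into X of shape
    (A.shape[1], B.shape[1]) gives (A ⊗ B) @ x == (A @ X @ B.T).ravel().
    """
    X = x.reshape(A.shape[1], B.shape[1])
    return (A @ X @ B.T).reshape(-1)


rng = np.random.default_rng(0)
A = rng.standard_normal((30, 30))
B = rng.standard_normal((40, 40))
x = rng.standard_normal(30 * 40)

# Dense reference: builds a 1200 x 1200 matrix explicitly.
y_dense = np.kron(A, B) @ x
# Structured version: two small matmuls instead of one large one.
y_struct = kron_matvec(A, B, x)
print(np.allclose(y_dense, y_struct))  # True
```

Writing and maintaining routines like this for every structure (and every combination of structures) is exactly the manual work described above.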
To eliminate this bottleneck we introduce CoLA, a numerical linear algebra library designed to automatically exploit the structure present in a diverse set of linear ops. To achieve this, CoLA exploits compositional structure through over 70 dispatch rules that select different algorithms for the diverse structures a linear operator can have. Additionally, given our emphasis on ML applications, CoLA supports both PyTorch and JAX, leverages GPU and TPU acceleration, supports low precision, provides automatic computation of gradients, diagonals, transposes and adjoints of linear ops, and incorporates specialty algorithms such as SVRG and a novel variation of Hutchinson's diagonal estimator, which exploit the large-scale sum structure of several linear ops found in ML applications.
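For readers unfamiliar with Hutchinson-style diagonal estimation, the classic estimator can be sketched in a few lines of NumPy. This is the textbook version (Rademacher probes, access to the operator only through matrix-vector products), not CoLA's novel variation, and it does not exploit sum structure; the function name `hutchinson_diag` and its parameters are hypothetical.

```python
import numpy as np


def hutchinson_diag(matvec, n, num_probes, rng):
    """Estimate diag(A) from matrix-vector products only.

    With Rademacher probes z (entries +1 or -1), the i-th entry of z * (A @ z)
    is A_ii + sum_{j != i} A_ij z_i z_j, whose expectation is A_ii.
    """
    total = np.zeros(n)
    for _ in range(num_probes):
        z = rng.choice([-1.0, 1.0], size=n)
        total += z * matvec(z)
    return total / num_probes


rng = np.random.default_rng(0)
n = 50
M = rng.standard_normal((n, n))
# Test matrix whose diagonal dominates its off-diagonal entries.
A = M @ M.T / n + np.diag(np.arange(1.0, n + 1.0))

est = hutchinson_diag(lambda v: A @ v, n, num_probes=2000, rng=rng)
rel_err = np.linalg.norm(est - np.diag(A)) / np.linalg.norm(np.diag(A))
print(f"relative error: {rel_err:.3f}")
```

The estimator's variance for entry i grows with the off-diagonal mass sum_{j != i} A_ij², which is why refinements of the basic scheme are worthwhile for large ML operators.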
Furthermore, whether or not there is structure to exploit, CoLA can be used as a general-purpose numerical linear algebra package for large-scale linear ops. CoLA provides implementations of classical iterative algorithms for solving linear systems, performing eigendecompositions and more, for PSD, symmetric, non-symmetric, real and complex linear ops.
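As an example of a classical iterative solver that needs only matrix-vector products, here is a textbook conjugate gradient (CG) sketch in NumPy for symmetric positive definite systems. It is not CoLA's implementation; `conjugate_gradient` and its `tol`/`max_iters` parameters are hypothetical names chosen for this sketch.

```python
import numpy as np


def conjugate_gradient(matvec, b, tol=1e-10, max_iters=None):
    """Solve A x = b for symmetric positive definite A using only matvecs."""
    n = b.shape[0]
    max_iters = n if max_iters is None else max_iters
    x = np.zeros_like(b, dtype=float)
    r = b - matvec(x)   # residual
    p = r.copy()        # search direction
    rs_old = r @ r
    threshold = tol * np.linalg.norm(b)
    for _ in range(max_iters):
        if np.sqrt(rs_old) <= threshold:
            break
        Ap = matvec(p)
        alpha = rs_old / (p @ Ap)           # step length along p
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs_old) * p       # next A-conjugate direction
        rs_old = rs_new
    return x


rng = np.random.default_rng(0)
n = 100
M = rng.standard_normal((n, n))
A = M @ M.T + n * np.eye(n)   # symmetric positive definite
b = rng.standard_normal(n)
x = conjugate_gradient(lambda v: A @ v, b)
print(np.allclose(A @ x, b))  # True
```

Because the loop touches A only through `matvec`, each iteration costs one matrix-vector product, so passing a cheap structured matvec (for example a Kronecker one) speeds up the whole solve without changing the algorithm.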
Below we highlight some of the important features that CoLA has and how they compare with alternatives.
We recommend installing via
pip install cola-ml
To install locally instead, clone the repository and install via
git clone https://github.com/wilson-labs/cola
cd cola
pip install -e .
CoLA requires Python >= 3.10
- The installation requires PyTorch or JAX to be installed, and these requirements will not be installed automatically.
- If neither PyTorch nor JAX is installed, CoLA can also use a NumPy backend for most operations, excluding advanced features such as automatic differentiation, GPU support, vmap, jit, and autograd transposes.
CoLA is designed with the following criteria in mind:
- We enable easy extensibility by allowing users to define dispatch rules and linear ops.
- We adhere to the same API used for dense matrix operations.
- We use multiple dispatch to exploit the structure of a linear operator.
- We provide support for both PyTorch and JAX.
- We leverage automatic differentiation to define operations like transposes, and also to derive gradients of linear ops.
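The dispatch and extensibility criteria can be illustrated with a toy example. CoLA itself uses multiple dispatch; Python's standard `functools.singledispatch`, used here, dispatches only on the first argument, so this is a simplified single-dispatch analogue. The operator classes (`Dense`, `Diagonal`, `Kronecker`, `ScaledIdentity`) and the `logabsdet` function are hypothetical and are not CoLA's API.

```python
from dataclasses import dataclass
from functools import singledispatch

import numpy as np


# Toy operator types, for illustration only (hypothetical, not CoLA classes).
@dataclass
class Dense:
    A: np.ndarray

    @property
    def n(self):
        return self.A.shape[0]

    def to_dense(self):
        return self.A


@dataclass
class Diagonal:
    d: np.ndarray

    @property
    def n(self):
        return self.d.shape[0]

    def to_dense(self):
        return np.diag(self.d)


@dataclass
class Kronecker:
    left: object
    right: object

    @property
    def n(self):
        return self.left.n * self.right.n

    def to_dense(self):
        return np.kron(self.left.to_dense(), self.right.to_dense())


@singledispatch
def logabsdet(op):
    """log|det(op)|, with the algorithm chosen by the operator's type."""
    raise NotImplementedError(f"no logabsdet rule for {type(op).__name__}")


@logabsdet.register
def _(op: Dense):
    # General fallback: O(n^3) factorization via slogdet.
    return np.linalg.slogdet(op.A)[1]


@logabsdet.register
def _(op: Diagonal):
    # O(n): sum of log |d_i|.
    return np.sum(np.log(np.abs(op.d)))


@logabsdet.register
def _(op: Kronecker):
    # |det(L ⊗ R)| = |det L|^{n_R} * |det R|^{n_L}; each factor is then
    # dispatched to its own most specific rule.
    return op.right.n * logabsdet(op.left) + op.left.n * logabsdet(op.right)


# Extending the system: a new operator type plus its rule, no core edits.
@dataclass
class ScaledIdentity:
    c: float
    size: int

    @property
    def n(self):
        return self.size

    def to_dense(self):
        return self.c * np.eye(self.size)


@logabsdet.register
def _(op: ScaledIdentity):
    return op.size * np.log(abs(op.c))


rng = np.random.default_rng(0)
K = Kronecker(Dense(rng.standard_normal((4, 4))), Diagonal(rng.uniform(1.0, 2.0, size=3)))
print(np.isclose(logabsdet(K), np.linalg.slogdet(K.to_dense())[1]))  # True
```

Because the Kronecker rule recurses into its factors, a Kronecker product of a dense and a diagonal factor costs one small dense factorization plus an O(n) sum, instead of a factorization of the full product.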
Example applications include:
- Gaussian processes from scratch
- Second-order optimization of neural networks using Gauss-Newton
- Computing the eigenspectrum of the Hessian of a neural network
- Boundary value PDEs
- Diagonalizing a Hamiltonian (PDE eigenvalue problems)
- Spectral clustering