Walkthrough of Linear Algebra Functionality
In this section we will explore how to use various linear algebra operations, with an emphasis on the dense and iterative base cases rather than the dispatch rules. High level operations like solve, sqrt, logdet, eigs, exp, and trace are listed in the high-level API, while low level iterative functions that help implement these operations include cg, gmres, lanczos, approx_diag, and stochastic_lanczos_quadrature. These low level algorithms can also be called explicitly, but they will not be able to leverage our dispatch rules, and they sometimes require more involvement and knowledge from the user.
First let's set up some linear operators to test these operations on. As we want to focus on the base cases, we will use a simple low rank + diagonal linear operator. We will use \(A=UU^T+D\) as a prototype for a PSD matrix, \(B=UV^T+D\) as a prototype for a generic square matrix, and \(C=V^TD\) as a prototype for a generic rectangular matrix.
[1]:
import torch
import numpy as np
import cola
N = 200
U = cola.lazify(torch.randn(N, 5))
V = cola.lazify(torch.randn(N, 5))
D = cola.ops.Diagonal(torch.linspace(1, 100, N))
A = U @ U.T + D # a PSD matrix
B = U @ V.T + D # a generic square matrix
C = V.T @ D # a generic rectangular matrix
x = torch.ones(N) # test vector x
Sometimes CoLA is able to infer additional properties of a Linear Operator such as PSD, SelfAdjoint, or Unitary, but not always, so it's best to annotate these properties explicitly.
In order to let CoLA know that A is PSD, we will annotate it with the PSD annotation:
[2]:
print("Properties before annotating:", A.annotations)
A = cola.PSD(A)
print("Properties after annotating:", A.annotations)
Properties before annotating: set()
Properties after annotating: {PSD}
Let's plot the spectrum of \(A\) to get a sense for the kind of object that we are dealing with.
[3]:
import matplotlib.pyplot as plt
plt.plot(torch.linalg.eigh(A.to_dense())[0])
plt.yscale('log')
plt.ylabel("eigenvalues")
plt.xlabel("index")
[3]:
Text(0.5, 0, 'index')
(figure: the eigenvalues of A plotted against index, on a log scale)
Inverses / Linear Solves
For solving a linear system \(Ax=b\), we consider two cases: when \(A\) is a positive semi-definite symmetric matrix (PSD), and when it is not.
cola.linalg.inv(A) represents the linear operator \(A^{-1}\): when applied to a vector \(b\), it solves the linear system \(Ax = b\) and outputs \(x\). It does not, however, compute the dense inverse \(A^{-1}\) and then multiply it by the vector \(b\), and it is exactly equivalent to calling cola.linalg.solve as shown below.
PSD
[4]:
# these two are exactly equivalent in CoLA
y = cola.linalg.solve(A, x)
y = cola.linalg.inv(A) @ x
However, with cola.linalg.inv we can examine properties of the solve, such as the number of iterations it has taken to converge. Instead of using the default alg=Auto(), let's explicitly pick alg=Cholesky() to use a dense \(O(n^3)\) method, or alg=CG() to use an iterative \(O(\tau \sqrt{\kappa}\log 1/\epsilon)\) method, where \(\tau\) is the time for an MVM with \(A\), \(\kappa\) is the condition number of \(A\), and \(\epsilon\) is the desired error tolerance.
[5]:
for alg in [cola.Cholesky(), cola.CG()]:
    Ainv = cola.linalg.inv(A, alg=alg)
    y = Ainv @ x
    print(f"With {alg}: Ainv of type {type(Ainv)}")
    if isinstance(alg, cola.CG):
        print(f"Computed inverse in {Ainv.info['iterations']} iters with error {Ainv.info['errors'][-1]:.1e}")
With <cola.linalg.decompositions.decompositions.Cholesky object at 0x7f8b60107dc0>: Ainv of type <class 'cola.ops.operators.Product[cola.linalg.inverse.inv.TriangularInv[cola.ops.operators.Triangular], cola.linalg.inverse.inv.TriangularInv[cola.ops.operators.Triangular]]'>
With CG(tol=1e-06, max_iters=1000, pbar=False, x0=None, P=None): Ainv of type <class 'cola.linalg.algorithm_base.IterativeOperatorWInfo[cola.ops.operators.Sum[cola.ops.operators.Product[cola.ops.operators.Dense, cola.ops.operators.Dense], cola.ops.operators.Diagonal], cola.linalg.inverse.cg.CG]'>
Computed inverse in 69 iters with error 1.9e-06
In the dense case the inverse LinearOperator is computed via the Cholesky decomposition \(A = LL^T\), then using \(A^{-1} = L^{-T}L^{-1}\), where \(L^{-1}\) simply denotes a Linear Operator that performs triangular solves with \(L\). We can see this reflected in the type of Ainv.
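To make this concrete, here is a quick sanity check written directly in torch (a reference computation, not CoLA's internal code): form the Cholesky factor of the dense \(A\) and apply the two triangular solves by hand.
L_chol = torch.linalg.cholesky(A.to_dense())                                 # A = L L^T
z = torch.linalg.solve_triangular(L_chol, x.unsqueeze(-1), upper=False)      # z = L^{-1} x
y_ref = torch.linalg.solve_triangular(L_chol.T, z, upper=True).squeeze(-1)   # y = L^{-T} z
print(torch.linalg.norm(A @ y_ref - x) / torch.linalg.norm(x))               # residual of the solve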
Meanwhile, the iterative algorithm uses Conjugate Gradients to perform multiplies with \(A^{-1}\). With iterative methods one often wants to specify an error tolerance or a maximum number of iterations to limit the computation.
[6]:
Ainv = cola.linalg.inv(A, cola.CG(tol=1e-8, max_iters=1_000))
y = Ainv @ x
print(f"Computed inverse in {Ainv.info['iterations']} iters with error {Ainv.info['errors'][-1]:.1e}")
Computed inverse in 85 iters with error 1.4e-08
With this tighter error tolerance, CG needs more iterations to converge; in general, one should choose the loosest (largest) error tolerance that is acceptable for the application.
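As a rough illustration of that trade-off, we can sweep the tolerance and watch the iteration count (a sketch; the exact counts depend on the spectrum of \(A\)).
for tol in [1e-2, 1e-4, 1e-8]:
    Ainv = cola.linalg.inv(A, cola.CG(tol=tol))
    _ = Ainv @ x  # trigger the solve so that .info is populated
    print(f"tol={tol:.0e}: {Ainv.info['iterations']} iterations")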
Non PSD
Let's call the same functions but with the non-PSD operator \(B\).
[7]:
for alg in [cola.LU(), cola.GMRES(tol=1e-3, max_iters=100)]:
    Binv = cola.linalg.inv(B, alg=alg)
    y = Binv @ x
    print(f"With {alg}: Ainv of type {type(Binv)}")
    if isinstance(alg, cola.GMRES):
        print(f"Computed inverse in {Binv.info['iterations']} iters with error {Binv.info['errors'][-1]:.1e}")
        print(f"Actual residual error: {torch.linalg.norm(B @ y - x) / torch.linalg.norm(x):1.3e}")
With <cola.linalg.decompositions.decompositions.LU object at 0x7f8b60192380>: Ainv of type <class 'cola.ops.operators.Product[cola.linalg.inverse.inv.TriangularInv[cola.ops.operators.Triangular], cola.linalg.inverse.inv.TriangularInv[cola.ops.operators.Triangular], cola.ops.operators.Permutation]'>
With GMRES(tol=0.001, max_iters=100, pbar=False, x0=None, P=None): Ainv of type <class 'cola.linalg.algorithm_base.IterativeOperatorWInfo[cola.ops.operators.Sum[cola.ops.operators.Product[cola.ops.operators.Dense, cola.ops.operators.Dense], cola.ops.operators.Diagonal], cola.linalg.inverse.gmres.GMRES]'>
Computed inverse in 101 iters with error 2.5e+01
Actual residual error: 1.490e-04
In the first case, CoLA performs a PLU decomposition \(B=P^{-1}LU\) and then computes the inverse as \(U^{-1}L^{-1}P\) again using the implicit triangular solves, but this time combined with a permutation inverse.
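As a rough reference (again plain torch rather than CoLA's implementation, and assuming a recent PyTorch where torch.linalg.lu is available; note that it factors as \(B = PLU\) rather than \(P^{-1}LU\)), the same two triangular solves plus a permutation look like this:
P_lu, L_lu, U_lu = torch.linalg.lu(B.to_dense())                                    # B = P L U
z = torch.linalg.solve_triangular(L_lu, (P_lu.T @ x).unsqueeze(-1), upper=False)    # z = L^{-1} P^T x
y_lu = torch.linalg.solve_triangular(U_lu, z, upper=True).squeeze(-1)               # y = U^{-1} z
print(torch.linalg.norm(B @ y_lu - x) / torch.linalg.norm(x))                       # residual of the solve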
In the second case, CoLA uses GMRES. Notably, GMRES uses Arnoldi as a component of the algorithm, and the convergence criterion for Arnoldi depends on more than just the residual errors that we care about for the linear solve. As a result of this more stringent convergence criterion, GMRES hits max_iters before stopping due to the Arnoldi convergence tolerance, and the errors are more easily controlled by changing max_iters.
In general CG has much more favorable properties than GMRES: with GMRES both the memory requirements and the compute grow with the number of iterations. Hence, if a matrix is PSD, annotating it as such is highly preferable.
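A minimal sketch of why the annotation matters for dispatch: the operator below is numerically identical to \(A\), but without the PSD tag Auto() cannot assume symmetric positive-definite structure when choosing a solver.
A_plain = U @ U.T + D                              # same matrix as A, but with no PSD annotation
y_generic = cola.linalg.solve(A_plain, x)          # dispatch without the PSD assumption
y_psd = cola.linalg.solve(cola.PSD(A_plain), x)    # with the PSD tag, Auto() may dispatch to Cholesky/CG
print(torch.linalg.norm(y_generic - y_psd) / torch.linalg.norm(y_psd))  # both solve the same system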
Matrix Functions: \(A^{1/2}\), \(\exp(A)\), \(\log(A)\), \(f(A)\)
Broadly, we can consider many linear algebra operations such as \(A^{1/2}\), \(\exp(A)\), and \(\log(A)\) as instances of \(f(A)\), where \(f\) is assumed to have a convergent Taylor expansion within the bounds of the spectrum of \(A\).
For alg = Auto(), we compute these functions by performing an eigendecomposition \(A=P\Lambda P^{-1}\) and evaluating \(f(A)=P f(\Lambda) P^{-1}\), which runs in time \(O(n^3)\).
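For instance, for \(f(x)=\sqrt{x}\) on our self-adjoint \(A\) we can check this identity against a plain torch eigendecomposition (a reference computation, not CoLA's code path):
evals, P = torch.linalg.eigh(A.to_dense())             # A = P diag(evals) P^T
sqrtA_ref = P @ torch.diag(torch.sqrt(evals)) @ P.T    # f(A) = P f(Lambda) P^{-1}
S_dense = cola.linalg.sqrt(A, alg=cola.Auto())
print(torch.linalg.norm(S_dense @ x - sqrtA_ref @ x) / torch.linalg.norm(sqrtA_ref @ x))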
For alg = Lanczos() or alg = Arnoldi(), we evaluate matrix-vector products \(f(A)v\) using the Lanczos or Arnoldi process starting with the vector \(v\). Running in time \(O(\tau \sqrt{\kappa}\log 1/\epsilon+m^3+mn)\), where \(m\) is max_iters, this approach gives an \(\epsilon\)-accurate evaluation of \(f(A)v\). (With some extra effort the \(m^3\) term can be reduced to \(m^2\), but we have not yet implemented this optimization.)
For a different \(v\), the process will be run again, as Lanczos needs \(v\) to construct a low-error approximation.
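In other words, each matrix-vector product with the Lanczos-based operator reruns the iterative process for that particular vector, roughly as in this sketch (using the \(A\) defined above):
S_lanczos = cola.linalg.sqrt(A, alg=cola.Lanczos(tol=1e-4, max_iters=100))
v1, v2 = torch.randn(N), torch.randn(N)
y1 = S_lanczos @ v1  # one Lanczos run, started from v1
y2 = S_lanczos @ v2  # a separate run for v2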
SelfAdjoint
Whether CoLA uses the more efficient Lanczos process or the more costly Arnoldi process depends on whether or not the matrix is SelfAdjoint (which guarantees an orthogonal eigenbasis). SelfAdjoint is considered a superset of PSD.
[8]:
print(A.annotations)
print(f"PSD: {A.isa(cola.PSD)}, SelfAdjoint: {A.isa(cola.SelfAdjoint)}")
{PSD}
PSD: True, SelfAdjoint: True
[9]:
for alg in [cola.Auto(), cola.Lanczos(tol=1e-4, max_iters=100)]:
    S = cola.linalg.sqrt(A, alg=alg)
    print(f"S with method={alg} is of type {type(S)}")
    print("error in sqrt:", torch.linalg.norm(S @ (S @ x) - A @ x) / torch.linalg.norm(A @ x))
S with method=Auto() is of type <class 'cola.ops.operators.Product[cola.ops.operators.Dense, cola.ops.operators.Diagonal, cola.ops.operators.Dense]'>
error in sqrt: tensor(2.0802e-06)
S with method=Lanczos(start_vector=None, max_iters=100, tol=0.0001, pbar=False) is of type <class 'cola.linalg.unary.unary.LanczosUnary[cola.ops.operators.Sum[cola.ops.operators.Product[cola.ops.operators.Dense, cola.ops.operators.Dense], cola.ops.operators.Diagonal], function]'>
error in sqrt: tensor(1.9483e-06)
Likewise we can use exp, log, pow, and apply_unary:
[10]:
expA = cola.linalg.exp(-A)
logA = cola.linalg.log(A)
Apow10 = cola.linalg.pow(A, 10)
resolvent = cola.linalg.apply_unary(lambda x: 1 / (x - 1), A)
for op in [expA, logA, Apow10, resolvent]:
    print(op[:2, :2].to_dense())
tensor([[ 0.2103, -0.0531],
[-0.0531, 0.1666]])
tensor([[0.6478, 0.3039],
[0.3039, 0.7369]])
tensor([[1.7403e+23, 1.8028e+23],
[1.8028e+23, 1.9524e+23]])
tensor([[ 2.2539, -0.7511],
[-0.7511, 1.4565]])
Non SelfAdjoint
For non-self-adjoint matrices the situation is the same, except that Arnoldi iterations are used. However, some of the operations like log and sqrt require that the eigenvalues are \(>0\) and may return NaN, Inf, or complex values depending on the inputs.
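A quick dense check of \(B\)'s spectrum with torch (a reference computation) shows why complex outputs can appear: a non-symmetric matrix can have complex eigenvalues, and eigenvalues with non-positive real part are problematic for log and sqrt.
evalsB = torch.linalg.eigvals(B.to_dense())   # complex in general for non-symmetric B
print(evalsB.real.min(), evalsB.imag.abs().max())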
[11]:
ops = [cola.linalg.sqrt(B), cola.linalg.exp(-B), cola.linalg.pow(B, 10)]
for op in ops:
    print(op[:2, :2].to_dense())
/home/ubu/cola/cola/backends/torch_fns.py:101: UserWarning: Casting complex values to real discards the imaginary part (Triggered internally at ../aten/src/ATen/native/Copy.cpp:299.)
p_ids = (P @ torch.arange(P.shape[-1]).to(P.device, P.dtype)).to(torch.long)
tensor([[ 2.3205+4.6194e-07j, -0.6258+2.6822e-07j],
[ 1.8426+4.8429e-07j, 0.8871+2.2352e-07j]])
tensor([[ 686.1265-7.6294e-05j, -168.1782+3.0816e-05j],
[ 567.8098-4.5776e-05j, -364.0045-1.8716e-05j]])
tensor([[ 1.4590e+18+2.1475e+10j, -9.3808e+18+2.3782e+12j],
[ 1.0919e+18-9.4489e+10j, -4.7317e+18+1.8495e+12j]])
Trace and Diagonal estimation
Unlike for dense matrices, evaluating the trace or diagonal of a LinearOperator can be difficult (consider, for example, the diagonal of the Hessian of a neural network). Depending on the needs of the problem at hand, we provide multiple different solutions:
an exact \(O(n\tau)\) compute and \(O(\tau)\) memory evaluation which loops over the basis elements \(A_{ii} = e_i^TAe_i\)
a stochastic Hutchinson estimator (sketched below) which is always unbiased, but has runtime \(O(1/\delta^2)\) where \(\delta\) is the desired tolerance for the standard deviation of the estimate
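For intuition, the Hutchinson estimator is based on \(\mathrm{Tr}(A) = \mathbb{E}[z^T A z]\) for random probes \(z\) with \(\mathbb{E}[zz^T]=I\). A minimal sketch of the idea (not CoLA's implementation, with an arbitrary probe count) looks like this:
def hutch_trace(op, num_probes=100):
    total = 0.0
    for _ in range(num_probes):
        z = torch.randn(op.shape[-1])   # Gaussian probe vector
        total += z @ (op @ z)           # one matrix-vector product per probe
    return total / num_probes

print(hutch_trace(A), cola.linalg.trace(A, alg=cola.Exact()))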
These can be chosen directly by specifying alg = Exact() or alg = Hutch(), or by letting CoLA decide which will be faster with alg = Auto(), based on the specified tolerance tol for the standard deviation. In general, for use cases that only need stochastic estimates (such as inside SGD) or only one or two digits of precision, the stochastic estimator will be faster.
We can use either of these two methods for
computing the trace
computing the diagonal
computing off-diagonals (the k-th diagonal for k ≠ 0)
[12]:
exact = cola.linalg.trace(A, alg=cola.Exact())
approx = cola.linalg.trace(A, alg=cola.Hutch(tol=3e-2))
print(f"exact trace: {exact}, stochastic trace: {approx}")
exact trace: 11207.5498046875, stochastic trace: 11210.9033203125
[13]:
diag = cola.linalg.diag(A)
off_diag = cola.linalg.diag(A, k=1)
print(f"diagonal: {diag}")
print(f"off-diagonal: {off_diag}")
diagonal: tensor([ 8.5118, 7.3608, 7.6542, 5.3548, 19.0612, 6.4097, 9.7575,
9.3554, 8.1646, 8.7194, 9.9059, 18.2854, 12.0774, 10.2313,
11.9585, 11.6125, 21.9750, 11.9402, 15.4233, 15.2707, 12.4610,
22.2939, 17.6396, 17.7406, 17.1853, 18.1922, 16.9665, 16.7785,
17.1129, 24.1279, 24.2655, 20.6261, 30.2835, 25.2488, 22.8829,
21.5087, 32.3452, 21.1568, 22.0920, 26.7968, 23.7173, 25.8120,
28.0823, 25.8921, 29.1320, 25.5530, 26.5214, 30.8620, 25.8763,
29.4752, 27.4093, 29.1308, 29.8036, 31.7998, 31.0898, 37.9161,
36.4030, 31.3916, 31.9861, 37.5207, 40.0409, 38.4947, 39.8926,
36.0543, 34.5295, 40.3930, 36.6227, 39.6822, 37.7233, 35.6859,
43.5714, 46.2801, 46.4217, 41.1966, 45.3633, 40.9981, 41.2265,
44.7463, 53.7195, 47.1256, 46.2465, 44.3040, 42.5806, 52.0961,
45.1239, 47.3316, 48.3270, 56.3554, 48.5199, 51.2425, 50.5862,
60.9226, 49.5557, 47.8160, 53.8036, 59.4363, 57.2788, 56.2419,
56.7123, 57.7423, 53.7828, 59.3607, 56.2272, 64.2254, 56.6092,
65.7697, 54.5191, 59.3984, 58.8306, 57.5462, 62.9409, 65.3171,
75.3433, 59.4735, 65.2470, 60.6562, 65.1781, 65.6146, 68.9406,
69.9885, 73.6313, 65.3591, 64.2916, 66.5440, 72.6440, 68.2763,
66.7878, 66.3294, 71.9146, 71.1690, 70.9741, 68.7452, 82.5140,
71.5895, 78.3163, 70.6314, 71.6406, 77.4104, 72.2250, 77.0280,
73.4117, 75.4781, 82.2615, 77.7552, 74.7615, 77.1487, 84.0647,
77.4321, 82.0002, 82.1438, 78.9036, 78.4603, 78.7670, 85.3516,
88.9480, 80.9072, 82.9142, 82.9673, 90.1816, 86.2896, 82.0797,
85.9348, 82.7690, 87.6654, 86.7496, 90.5796, 88.2644, 86.2070,
89.0124, 88.8021, 86.6497, 93.3623, 91.8149, 95.2812, 96.5947,
89.7338, 93.2476, 92.0821, 97.3608, 95.8187, 95.3042, 97.2232,
93.6238, 96.4429, 103.5226, 95.2086, 96.9995, 99.2942, 96.2102,
99.1804, 99.8973, 107.0359, 97.8000, 103.5394, 101.4755, 99.7733,
101.5736, 108.3602, 104.8761, 102.9599])
off-diagonal: tensor([ 5.4098e+00, -2.3565e+00, 2.1375e+00, -1.5303e+00, 5.4326e+00,
2.6653e+00, 3.3794e+00, 2.5149e+00, 2.7378e+00, 2.5573e+00,
-8.3340e-01, -3.4569e+00, -2.2258e-01, -2.4627e+00, -5.5316e-01,
-3.9379e-01, 2.5767e+00, -3.2219e+00, -2.7628e+00, 1.2039e+00,
3.1580e+00, -2.4235e+00, -5.5459e-01, -8.2220e-01, -1.1991e+00,
-1.5784e+00, 2.3253e-01, 1.1821e+00, -2.5081e+00, -3.2412e+00,
-9.1381e-01, -3.2497e+00, 8.0047e+00, -5.4489e+00, -1.3657e+00,
-3.4058e+00, 4.3618e-02, -2.6919e-01, 6.9709e-02, 1.9282e+00,
-8.2027e-01, 1.1378e+00, -2.2030e+00, 3.8169e+00, -1.5222e+00,
5.2708e-01, 2.7213e+00, -1.5804e-01, -1.7926e+00, -6.7458e-01,
-3.9837e-02, -2.7240e+00, -1.6068e+00, 1.4020e+00, -1.3658e+00,
-3.5227e-01, 3.8906e-01, 1.5199e+00, -1.5219e+00, -1.7110e+00,
7.8942e-01, 4.7879e-01, 1.8430e+00, 2.0734e+00, -4.5239e-01,
-1.7555e+00, -3.0878e+00, -5.6006e-01, -1.3585e-01, -9.7462e-01,
2.5394e+00, -6.5051e+00, -1.5850e+00, -8.1986e-02, -3.6340e+00,
7.5364e-02, 2.3323e-02, -5.0638e+00, 9.3752e+00, 2.6730e+00,
-4.8874e-01, 6.4025e-01, -2.4747e+00, -1.7823e+00, -9.1593e-02,
-7.7965e-01, 5.6591e+00, -3.8323e+00, -2.0298e+00, -1.4195e+00,
-6.2768e-01, 4.0493e+00, -3.2138e-01, -1.0671e+00, 1.0682e+00,
6.8565e+00, 4.9748e+00, -3.2068e+00, 5.4496e+00, -2.8033e-03,
-3.6221e+00, -1.1001e+00, 1.4346e+00, -7.6809e-01, 4.9448e+00,
2.6304e+00, 6.8537e-01, -1.2727e+00, 8.8673e-01, 1.2259e+00,
-9.7530e-02, 4.8972e+00, 1.9835e+00, -3.0567e+00, 1.9151e+00,
2.3309e+00, -2.0069e+00, -6.4053e+00, -2.1000e+00, -6.4842e+00,
1.6885e+00, -2.2382e+00, -1.4906e+00, 1.7551e+00, -1.2747e+00,
-8.2681e-01, -2.4029e+00, 4.7710e-01, -3.3622e+00, 7.8385e-02,
-1.1262e+00, -3.8599e-02, -4.8654e+00, -6.5824e+00, -7.0692e-01,
1.5344e+00, -1.4729e+00, 1.9749e+00, -1.6118e+00, -2.2130e+00,
2.6464e+00, -1.0389e+00, -2.9799e+00, 2.4764e+00, -1.8658e+00,
2.1516e+00, -3.2745e+00, 1.9874e+00, -2.1297e+00, 1.9302e+00,
8.3467e-01, 1.2279e+00, 3.3645e-01, -1.4170e+00, -3.3533e-01,
9.3108e-01, -3.6564e+00, -8.5537e-01, 5.9034e+00, -4.1498e-01,
-1.6407e-01, 1.7689e+00, -7.5038e-01, 7.1956e-01, 2.3514e+00,
-2.6611e-01, -1.5008e+00, 2.9730e-01, 2.8366e+00, -8.0547e-01,
1.3469e+00, 7.3975e-01, -1.9789e+00, 1.8038e+00, 3.1107e-02,
1.5660e+00, 3.9531e-01, -1.7827e+00, -4.5978e+00, 3.6314e-01,
3.0618e+00, -1.5430e+00, -1.3355e+00, -4.0151e+00, -1.1331e+00,
1.1944e+00, 7.2848e-01, 2.5606e+00, -3.9122e-01, -3.8557e+00,
-3.7839e+00, -3.3416e-01, -1.8735e+00, 1.9559e+00, -2.3395e+00,
-1.7063e-01, 3.3771e+00, -2.2266e+00, 1.4054e+00])
Log Determinants
In the dense case, we compute log determinants from the Cholesky or LU decompositions depending on whether the matrix is PSD or not, and this runs in time \(O(n^3)\).
For the iterative case we compute log determinants using the formula \(\log\mathrm{det}(A) = \mathrm{Tr}(\log(A))\) combining together the \(\log\) and \(\mathrm{Tr}\) functions discussed above.
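We can sanity-check that identity on the dense matrix with plain torch (a reference computation; \(A\) is symmetric, so its eigenvalues are real):
evals = torch.linalg.eigvalsh(A.to_dense())
print(torch.log(evals).sum(), torch.logdet(A.to_dense()))   # Tr(log A) vs log det A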
In the special case where only unbiased estimates (or a small number of significant digits) are required, we can leverage stochastic Lanczos quadrature (SLQ) to get an improved convergence rate.
This choice is exposed to the user through the log_alg and trace_alg options in logdet and slogdet.
[14]:
print("Tr(log(A))", cola.linalg.logdet(A, trace_alg=cola.Hutch(tol=1e-4), log_alg=cola.Lanczos(tol=1e-4, max_iters=30)))
print("SLQ:", cola.linalg.logdet(A, log_alg=cola.Lanczos(tol=1e-3, max_iters=10), trace_alg=cola.Auto()))
print("Dense:", cola.linalg.logdet(A))
Tr(log(A)) tensor(740.9275)
SLQ: tensor(740.9714)
Dense: tensor(740.9275)
And we can do this for non-PSD matrices too, using \(\log |\mathrm{det}(A)| = \frac{1}{2}\log \mathrm{det}(A^TA)\); however, the phase (sign) is lost in this process.
[15]:
print("iterative:", cola.linalg.slogdet(B, log_alg=cola.Arnoldi(tol=1e-3, max_iters=10)))
print("Dense:", cola.linalg.slogdet(B, log_alg=cola.LU()))
iterative: (tensor(1.0000-0.0003j), tensor(727.3146))
Dense: (tensor(1.), tensor(733.4941))
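For reference, we can compare against a dense torch computation (independent of CoLA):
sign_ref, logabsdet_ref = torch.linalg.slogdet(B.to_dense())
print("torch dense:", sign_ref, logabsdet_ref)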
Eigendecomposition
Another popular linear algebra operation is to find an approximation for the eigenvalues or eigenvectors of a given linear operator \(A\). That is, find \(V\) and \(\Lambda\) such that \(AV = V \Lambda\).
[16]:
eig_vals, eig_vecs = cola.eig(A, k=A.shape[0])
Nevertheless, when \(A\) is quite large we cannot afford the time or memory for a full decomposition. We could instead use an iterative eigenvalue algorithm such as power iteration (if we only want the maximum eigenvalue), Lanczos (if we have a symmetric operator), or Arnoldi (which works in general). Let's see how to run each of these options. To compute the maximum eigenvalue we could run:
[17]:
from cola.linalg.eig.power_iteration import power_iteration
_, eig_max, _ = power_iteration(A, tol=1e-4)
print(f"eigmax: {eig_max}")
print(f"eig_vals[0]: {eig_vals[-1]}")
eigmax: 331.973388671875
eig_vals[-1]: 332.0385437011719
Indeed we recover the largest eigenvalue. The algorithm that CoLA ran in this case was the power method, which runs in time \(O(\tau\frac{1}{\Delta} \log 1/\epsilon)\) to compute the top eigenvalue, where \(\Delta\) is the gap between the two largest eigenvalues. We can achieve an accelerated convergence rate of \(O(\tau\frac{1}{\sqrt{\Delta}} \log 1/\epsilon)\) if we use Lanczos or Arnoldi; however, this comes at an additional memory cost of \(O(nm+m^2)\), where \(m\) is max_iters.
[18]:
e0, v0 = cola.eig(A, k=A.shape[0], which="SM", alg=cola.Lanczos(tol=1e-4, max_iters=15))
print(f"{e0[-1]}")
WARNING:root:Non keyed randn used. To be deprecated soon.
332.03857421875
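For a non-symmetric operator like \(B\) we would use Arnoldi instead. Here is a hedged sketch; the which="LM" value (largest-magnitude eigenvalues) is assumed by analogy with the "SM" option above.
# Arnoldi-based eigendecomposition of the non-symmetric B (eigenvalues may be complex);
# which="LM" is assumed here rather than confirmed by the docs above.
eig_vals_B, eig_vecs_B = cola.eig(B, k=5, which="LM", alg=cola.Arnoldi(tol=1e-4, max_iters=30))
print(eig_vals_B)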