from functools import reduce
import numpy as np
from cola.linalg.algorithm_base import Algorithm, Auto
from cola.linalg.trace.diagonal_estimation import Exact, Hutch, HutchPP
from cola.ops.operators import (
BlockDiag,
Dense,
Diagonal,
I_like,
Identity,
Kronecker,
KronSum,
LinearOperator,
ScalarMul,
Sum,
)
from cola.utils import dispatch, export
@export
@dispatch.abstract
def diag(A: LinearOperator, k: int = 0, alg: Algorithm = Auto()):
    r"""Extract the (kth) diagonal of a linear operator.

    Can use either the :math:`O(\tfrac{1}{\delta^2})` time stochastic estimation
    (``alg=Hutch()`` or ``alg=HutchPP()``) or a deterministic :math:`O(n)` time
    algorithm (``alg=Exact()``). With the default ``alg=Auto()``, either ``Exact``
    or ``Hutch`` is selected based on the requested tolerance and the operator size.
    If only unbiased estimates of the diagonal are needed, use the Hutchinson algorithm.

    Structured operators (Dense, Identity, Diagonal, Sum, BlockDiag, ScalarMul,
    Kronecker, KronSum) have dedicated rules; the BlockDiag, Kronecker and KronSum
    rules currently support only ``k == 0``.

    Args:
        A (LinearOperator): The linear operator to compute the diagonal of.
        k (int, optional): Offset of the diagonal to extract; ``0`` (the default)
            is the main diagonal.
        alg (Algorithm, optional): The algorithm to use for the diagonal computation
            (Hutch, HutchPP, Exact, or Auto).

    Returns:
        Array: the extracted diagonal.
    """
# ########### BASE CASES #############
@dispatch(precedence=-1)
def diag(A: LinearOperator, k: int, alg: Auto):
    """Fallback for ``alg=Auto()``: choose between Exact and Hutch.

    The tolerance is read from the Auto instance (default ``1e-6``). Exact is
    used when it is below ``1 / sqrt(10 * prod(A.shape))``. Otherwise a Hutch
    estimator is built from the Auto instance's attributes.
    """
    settings = alg.__dict__
    tolerance = settings.get("tol", 1e-6)
    threshold = 1 / np.sqrt(10 * np.prod(A.shape))
    if not tolerance < threshold:
        # Loose tolerance: stochastic estimation, forwarding Auto's settings.
        return diag(A, k, Hutch(**settings))
    return diag(A, k, Exact())
@dispatch(precedence=-1)
def diag(A: LinearOperator, k: int, alg: Hutch | HutchPP | Exact):
    """Base case for explicitly chosen diagonal algorithms.

    The algorithm object is callable and performs the computation itself.
    """
    estimate = alg(A, k)
    return estimate
# ############ DISPATCH RULES ############
@dispatch
def diag(A: Dense, k: int, alg: Algorithm):
    """Dense rule: read the kth diagonal straight from the stored array ``A.A``.

    ``alg`` is unused because the entries are available explicitly.
    """
    return A.xnp.diag(A.A, diagonal=k)
@dispatch
def diag(A: Identity, k: int, alg: Algorithm):
    """kth diagonal of the identity: ones for ``k == 0``, zeros otherwise.

    An offset with ``|k|`` larger than the operator size yields an empty array
    instead of failing on a negative length. That matches what ``np.diag``
    returns for an out-of-range offset.
    """
    xnp = A.xnp
    n = A.shape[0]
    if k == 0:
        return xnp.ones((n, ), A.dtype, device=A.device)
    # Clamp at 0: n - |k| is negative when the offset lies outside the matrix.
    return xnp.zeros((max(n - abs(k), 0), ), A.dtype, device=A.device)
@dispatch
def diag(A: Diagonal, k: int, alg: Algorithm):
    """kth diagonal of a diagonal operator.

    For ``k == 0`` the stored diagonal ``A.diag`` is returned; any other offset
    is all zeros. An offset with ``|k|`` larger than the operator size yields an
    empty array instead of failing on a negative length (as ``np.diag`` does).
    """
    if k == 0:
        return A.diag
    n = A.shape[0]
    # Clamp at 0: n - |k| is negative when the offset lies outside the matrix.
    return A.xnp.zeros((max(n - abs(k), 0), ), A.dtype, device=A.device)
@dispatch
def diag(A: Sum, k: int, alg: Algorithm):
    """Sum rule: the kth diagonal of a sum is the sum of the terms' kth diagonals.

    Starts from the integer 0, so an empty ``A.Ms`` yields ``0``.
    """
    total = 0
    for term in A.Ms:
        total = total + diag(term, k, alg)
    return total
@dispatch
def diag(A: BlockDiag, k: int, alg: Algorithm):
    """Main diagonal of a block-diagonal operator.

    Each block ``M`` with multiplicity ``m`` contributes its diagonal ``m`` times,
    in the order of ``A.Ms``. The pieces are concatenated with ``A.xnp.concat``.

    Raises:
        NotImplementedError: If ``k != 0``. Off-diagonals would need zero
            padding between blocks, which is not implemented yet.
    """
    # Explicit raise rather than ``assert``: asserts vanish under ``python -O``,
    # which would silently return wrong off-diagonals.
    if k != 0:
        raise NotImplementedError("Havent filled this case yet, need to pad with 0s")
    # Each block's diagonal is computed once and repeated for its copies.
    pieces = [d for M, m in zip(A.Ms, A.multiplicities) for d in [diag(M, k, alg)] * m]
    return A.xnp.concat(pieces)
@dispatch
def diag(A: ScalarMul, k: int, alg: Algorithm):
    """ScalarMul rule: ``A.c`` times the kth diagonal of an identity shaped like ``A``."""
    scale = A.c
    return scale * diag(I_like(A), k, alg)
def product(c):
    """Return the left-to-right product ``c[0] * c[1] * ... * c[-1]``.

    Works for any objects supporting ``*``, including broadcasting arrays.
    Raises ``TypeError`` for an empty iterable, because ``reduce`` is given no
    initial value.
    """
    def _times(left, right):
        return left * right

    return reduce(_times, c)
@dispatch
def diag(A: Kronecker, k: int, alg: Algorithm):
    """Main diagonal of a Kronecker product ``M_1 ⊗ M_2 ⊗ ... ⊗ M_p``.

    It is the flattened outer product of the factors' diagonals.

    Raises:
        NotImplementedError: If ``k != 0``; the off-diagonal rule is unverified.
    """
    # Explicit raise rather than ``assert`` so the guard survives ``python -O``.
    if k != 0:
        raise NotImplementedError("Need to verify correctness of rule for off diagonal case")
    ds = [diag(M, k, alg) for M in A.Ms]
    n_factors = len(ds)
    # Give factor i its own axis (None elsewhere), so multiplying the
    # reshaped diagonals broadcasts into their outer product.
    slices = [[None] * i + [slice(None)] + [None] * (n_factors - i - 1) for i in range(n_factors)]
    return product([d[tuple(s)] for d, s in zip(ds, slices)]).reshape(-1)
@dispatch
def diag(A: KronSum, k: int, alg: Algorithm):
    """Main diagonal of a Kronecker sum ``M_1 ⊕ M_2 ⊕ ... ⊕ M_p``.

    It is the flattened outer *sum* of the factors' diagonals.

    Raises:
        NotImplementedError: If ``k != 0``; the off-diagonal rule is unverified.
    """
    # Explicit raise rather than ``assert`` so the guard survives ``python -O``.
    if k != 0:
        raise NotImplementedError("Need to verify correctness of rule for off diagonal case")
    ds = [diag(M, k, alg) for M in A.Ms]
    n_factors = len(ds)
    # Give factor i its own axis (None elsewhere), so adding the reshaped
    # diagonals broadcasts into their outer sum.
    slices = [[None] * i + [slice(None)] + [None] * (n_factors - i - 1) for i in range(n_factors)]
    return sum(d[tuple(s)] for d, s in zip(ds, slices)).reshape(-1)
@export
@dispatch.abstract
def trace(A: LinearOperator, alg: Algorithm = Auto()):
    r"""Compute the trace of a linear operator tr(A).

    Can use either the :math:`O(\tfrac{1}{\delta^2})` time stochastic estimation (alg=Hutch())
    or a deterministic :math:`O(n)` time algorithm (alg=Exact()).
    If only an unbiased estimate of the trace is needed, use the Hutchinson algorithm.

    Args:
        A (LinearOperator): The square linear operator to compute the trace of.
        alg (Algorithm, optional): The algorithm used for the underlying diagonal
            computation (Hutch or Exact).

    Returns:
        Array: trace
    """
@dispatch
def trace(A: LinearOperator, alg: Algorithm):
    """Generic trace rule: sum of the main diagonal computed with ``alg``.

    Raises:
        ValueError: If ``A`` is not square.
    """
    # Explicit raise rather than ``assert``: under ``python -O`` an assert is
    # stripped and a non-square operator would silently get a "trace".
    if A.shape[0] != A.shape[1]:
        raise ValueError(f"Can't trace non square matrix (shape {A.shape})")
    return diag(A, 0, alg).sum()
@dispatch
def trace(A: Kronecker, alg: Algorithm):
    """Trace of a Kronecker product: the product of the factors' traces.

    Uses the identity tr(A ⊗ B) = tr(A) · tr(B).
    """
    return product([trace(M, alg) for M in A.Ms])