from dataclasses import dataclass
from typing import Any, Optional
import numpy as np
from cola.linalg.algorithm_base import Algorithm
from cola.ops import I_like, LinearOperator
from cola.utils import export
PRNGKey = Any
@export
@dataclass
class Exact(Algorithm):
"""
    Exact algorithm that computes the diagonal of a matrix element by element.
    That is, each diagonal element :math:`A_{i,i}` is computed as
    :math:`e_{i}^{T} A(e_{i})`.
    For efficiency, this procedure is carried out in blocks of columns,
    :math:`A [e_{i_{1}}, \\cdots, e_{i_{B}}]`, where :math:`B` is the block size.
Args:
bs (int, optional): Block size.
pbar (bool, optional): Whether to show progress bar.
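
    Example:
        A minimal usage sketch (assumes ``A`` is a square CoLA ``LinearOperator``
        already in scope)::

            alg = Exact(bs=64)
            d = alg(A, 0)   # main diagonal; alg(A, k) extracts the k-th diagonal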
"""
bs: int = 100
pbar: bool = False
def __call__(self, A, k):
return exact_diag(A, k, self.bs)
@export
@dataclass
class Hutch(Algorithm):
"""
    Hutchinson's algorithm for stochastically estimating the diagonal of a matrix
    (summing the estimate recovers the classic trace estimator). The algorithm uses
    random probes :math:`z` with unit-variance entries, for which
    :math:`\\mathbb{E}[z \\odot (Az)] = \\text{diag}(A)`.
Args:
tol (float, optional): Approximation relative tolerance.
max_iters (int, optional): The maximum number of iterations to run.
bs (int, optional): Number of probes.
        rand (str, optional): Type of random probes, either 'normal' or 'rademacher'.
pbar (bool, optional): Whether to show progress bar.
key (xnp.PRNGKey, optional): Random key (default None).
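
    Example:
        A minimal usage sketch (assumes ``A`` is a square CoLA ``LinearOperator``;
        with ``key=None`` a fixed key is created internally)::

            alg = Hutch(tol=1e-2, rand='rademacher')
            d = alg(A, 0)   # stochastic estimate of the main diagonal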
"""
tol: float = 3e-2
max_iters: int = 10_000
bs: int = 100
rand: str = 'normal'
pbar: bool = False
key: Optional[PRNGKey] = None
def __call__(self, A, k):
return hutchinson_diag_estimate(A, k, **self.__dict__)[0]
@export
@dataclass
class HutchPP(Algorithm):
"""
    Hutch++ is an improvement on Hutchinson's estimator introduced in
    Meyer et al. 2020, "Hutch++: Optimal Stochastic Trace Estimation".
Args:
tol (float, optional): Approximation relative tolerance.
max_iters (int, optional): The maximum number of iterations to run.
bs (int, optional): Number of probes.
        rand (str, optional): Type of random probes, either 'normal' or 'rademacher'.
pbar (bool, optional): Whether to show progress bar.
key (xnp.PRNGKey, optional): Random key (default None).
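
    Note:
        Hutch++ splits the estimate into an exact trace over a sketched low-rank
        subspace plus a Hutchinson estimate of the residual,
        :math:`\\text{tr}(A) = \\text{tr}(Q^{T} A Q)
        + \\mathbb{E}[z^{T} (I - Q Q^{T}) A (I - Q Q^{T}) z]`,
        where :math:`Q` is an orthonormal basis for the range of the sketch
        :math:`A S`.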
"""
tol: float = 3e-2
    max_iters: int = 10_000
bs: int = 100
rand: str = 'normal'
pbar: bool = False
key: Optional[PRNGKey] = None
def __call__(self, A, k):
raise NotImplementedError
def get_I_chunk_like(A: LinearOperator, i, bs, shift=0):
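    """Build a block of identity columns and its shifted counterpart.

    ``chunk`` holds the identity columns ``e_i, ..., e_{i+bs-1}``; ``shifted_chunk``
    holds the same columns shifted by ``shift`` rows and zero-padded at the
    boundary, so that ``((A @ chunk) * shifted_chunk).sum(-1)`` picks out entries
    of the ``shift``-th diagonal of ``A``.
    """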
xnp = A.xnp
k = shift
Id = I_like(A)
if k == 0:
I_chunk = Id[:, i:i + bs].to_dense()
chunk = I_chunk
shifted_chunk = I_chunk
    elif k < 0:
        k = abs(k)
        I_chunk = Id[:, i:i + bs + k].to_dense()
        padded_chunk = xnp.zeros((A.shape[0], bs + k), dtype=A.dtype, device=A.device)
        slc = np.s_[:I_chunk.shape[-1]]
        padded_chunk = xnp.update_array(padded_chunk, I_chunk, slice(0, None), slc)
chunk = I_chunk[:, :bs]
shifted_chunk = padded_chunk[:, k:k + bs]
else:
I_chunk = Id[:, max(i - k, 0):i + bs].to_dense()
        padded_chunk = xnp.zeros((A.shape[0], bs + k), dtype=A.dtype, device=A.device)
slc = np.s_[-I_chunk.shape[-1]:]
padded_chunk = xnp.update_array(padded_chunk, I_chunk, slice(0, None), slc)
chunk = I_chunk[:, -bs:]
shifted_chunk = padded_chunk[:, :bs]
return chunk, shifted_chunk
# disable backwards for now, TODO: add tests then add back in
# @iterative_autograd(exact_diag_bwd)
def exact_diag(A, k, bs):
    bs = min(bs, A.shape[0])
# lazily create chunks of the identity matrix of size bs
diag_sum = 0.
for i in range(0, A.shape[0], bs):
chunk, shifted_chunk = get_I_chunk_like(A, i, bs, k)
diag_sum += ((A @ chunk) * shifted_chunk).sum(-1)
if k <= 0:
out = diag_sum[abs(k):]
    else:
        out = diag_sum[:-k]
return out
def exact_diag_bwd(res, grads, unflatten, *args, **kwargs):
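    """VJP rule for `exact_diag` (currently disabled, see the TODO above): replays
    the chunked forward products and accumulates parameter gradients with
    `xnp.vjp_derivs`."""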
v = grads[0] if isinstance(grads, (tuple, list)) else grads
op_args, _ = res
A = unflatten(op_args)
xnp = A.xnp
k = kwargs.get('k')
bs = kwargs.get('bs')
def fun(theta, C, shifted_C):
Aop = unflatten(theta)
out = ((Aop @ C) * shifted_C).sum(-1)
        return out[abs(k):] if k <= 0 else out[:-k]
def d_params(C, shifted_C):
d_paramsv, _, _ = xnp.vjp_derivs(fun, (op_args, C, shifted_C), v)
return d_paramsv
d_p = type(op_args)([0. * arg for arg in op_args])
for i in range(0, A.shape[0], bs):
chunk, shifted_chunk = get_I_chunk_like(A, i, bs, k)
dp_all = d_params(chunk, shifted_chunk)
        for j in range(len(d_p)):
            d_p[j] += dp_all[j]
dA = unflatten(d_p)
return (dA, )
@export
def hutchinson_diag_estimate(A: LinearOperator, k=0, bs=100, tol=3e-2, max_iters=10000, pbar=False, rand='normal',
key=None):
""" Extract the (kth) diagonal of a linear operator using stochastic estimation
Args:
A (LinearOperator): Linear operator.
k (int, optional): Index of the diagonal to extract (default 0).
bs (int, optional): Chunk size (default 100).
tol (float, optional): Tolerance (default 3e-2).
max_iters (int, optional): Maximum number of iterations (default 10000).
        pbar (bool, optional): Flag for showing progress bar.
        rand (str, optional): Type of random probes, either 'normal' or 'rademacher' (default 'normal').
key (xnp.PRNGKey, optional): Random key (default None).
Returns:
Array: Extracted diagonal elements.
Info: Dictionary with information about the method used.
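
    Example:
        A minimal sketch (assumes ``A`` is a square ``LinearOperator``)::

            diag, info = hutchinson_diag_estimate(A, k=0, tol=1e-2)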
"""
    bs = min(bs, A.shape[0])
# lazily create chunks of the identity matrix of size bs
xnp = A.xnp
    assert tol > 1e-3, "tolerance chosen too small for stochastic diagonal estimation"
assert rand in ['normal', 'rademacher'], "rand must be 'normal' or 'rademacher'"
key = xnp.PRNGKey(42) if key is None else key
@xnp.jit
def body(state):
# TODO: fix randomness when using with JAX
i, diag_sum, diag_sumsq, key = state
key = xnp.next_key(key)
z = xnp.randn(A.shape[0], bs, dtype=A.dtype, key=key, device=A.device)
if rand == 'rademacher':
z = xnp.sign(z)
z2 = xnp.roll(z, -k, 0)
z2 = xnp.update_array(z2, 0, slice(0, abs(k)) if k <= 0 else slice(-abs(k), None))
        slc = slice(abs(k), None) if k < 0 else slice(None, -abs(k) or None)
estimator = ((A @ z) * z2)[slc]
return i + 1, diag_sum + estimator.sum(-1), diag_sumsq + (estimator**2).sum(-1), key
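
    # Stopping rule: `err` is the mean relative standard error of the running
    # diagonal estimate, with the mean clamped away from zero (at 0.1 in
    # magnitude) to avoid dividing by near-zero entries.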
def err(state):
i, diag_sum, diag_sumsq, _ = state
mean = diag_sum / (i * bs)
stderr = xnp.sqrt((diag_sumsq / (i * bs) - mean**2) / (i * bs))
return xnp.mean(stderr / xnp.maximum(xnp.abs(mean), .1 * xnp.ones_like(mean)))
def cond(state):
return (state[0] == 0) | ((state[0] < max_iters) & (err(state) > tol))
while_loop, infos = xnp.while_loop_winfo(err, tol, max_iters, pbar=pbar)
zeros = xnp.zeros((A.shape[0] - abs(k), ), dtype=A.dtype, device=A.device)
n, diag_sum, *_ = while_loop(cond, body, (0, zeros, zeros, key))
mean = diag_sum / (n * bs)
return mean, infos
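

if __name__ == "__main__":
    # A minimal smoke-test sketch, not part of the library API. It assumes a CoLA
    # backend whose `Dense` operator accepts the NumPy array below; swap in your
    # backend's array type if needed.
    from cola.ops import Dense

    M = np.arange(16, dtype=np.float64).reshape(4, 4)
    A = Dense(M)
    print(exact_diag(A, k=0, bs=2))  # expected: np.diag(M)    -> [ 0.,  5., 10., 15.]
    print(exact_diag(A, k=1, bs=2))  # expected: np.diag(M, 1) -> [ 1.,  6., 11.]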