from dataclasses import dataclass
from cola.linalg.algorithm_base import Algorithm
from cola.linalg.decompositions.arnoldi import arnoldi
from cola.ops import Array, LinearOperator
from cola.utils import export
from cola.utils.custom_autodiff import iterative_autograd
[docs]
@export
@dataclass
class GMRES(Algorithm):
    """
    Generalized Minimal Residual algorithm (GMRES) for solving Ax=b or AX=B (multiple rhs).

    The runtime is bounded by :math:`O(\\sqrt{\\kappa})` and
    it uses :math:`O(m n)` memory,
    where :math:`\\kappa` is the condition number of the linear operator,
    n is the size of A and m represents the max iters.
    Optionally, you can use a preconditioner (an approximation of :math:`A^{-1}`)
    to accelerate convergence.

    Calling an instance with ``(A, b)`` forwards every field below as a keyword
    argument to :func:`gmres`.

    Args:
        tol (float, optional): Relative error tolerance for Arnoldi.
        max_iters (int, optional): The maximum number of iterations to run in Arnoldi.
        pbar (bool, optional): Whether to show progress bar.
        x0 (Array, optional): (n,) or (n, b) guess for initial solution.
        P (LinearOperator, optional): Preconditioner. Defaults to the identity.
    """
    tol: float = 1e-6
    max_iters: int = 1000
    pbar: bool = False
    x0: Array = None
    P: LinearOperator = None

    def __call__(self, A, b):
        # The instance attributes double as the solver's keyword arguments.
        return gmres(A, b, **vars(self))
def gmres(A: LinearOperator, rhs: Array, x0=None, max_iters=100, tol=1e-7, P=None, use_householder=False,
          use_triangular=False, pbar=False):
    """
    Solves Ax=b or AX=B using GMRES.

    A one-dimensional ``rhs`` (and its initial guess) is temporarily given a
    trailing axis so the solver always sees a matrix of right hand sides; that
    axis is removed again from the returned solution.

    Args:
        A (LinearOperator): A linear operator of size (n, n).
        rhs (Array): A single right hand side (n,) or multiple right hand sides (n, b).
        x0 (Array, optional): (n,) or (n, b) initial solution guess.
         Defaults to the zero vector.
        max_iters (int, optional): The maximum number of iterations to run.
        tol (float, optional): The tolerance for convergence.
        P (array, optional): Preconditioner. Defaults to the Identity.
        use_householder (bool, optional): Use Householder Arnoldi variant.
        use_triangular (bool, optional): Use triangular QR factorization.
        pbar (bool, optional): show a progress bar.

    Returns:
        tuple:
            - soln (Array): solution to the linear system, either (n,) or (n, b)
            - info (dict): general information about the iterative procedure.
    """
    backend = A.xnp
    single_rhs = len(rhs.shape) == 1
    x0 = backend.zeros_like(rhs) if x0 is None else x0
    if single_rhs:
        # Promote (n,) inputs to (n, 1) columns.
        rhs, x0 = rhs[..., None], x0[..., None]
    # Everything is passed by keyword, exactly as the differentiable forward pass expects.
    solver_kwargs = dict(A=A, rhs=rhs, x0=x0, max_iters=max_iters, tol=tol, P=P,
                         use_householder=use_householder, use_triangular=use_triangular, pbar=pbar)
    soln, infodict = gmres_fwd(**solver_kwargs)
    return (soln[:, 0] if single_rhs else soln), infodict
def gmres_bwd(res, grads, unflatten, *args, **kwargs):
    """
    Backward rule registered for ``gmres_fwd`` through ``iterative_autograd``.

    Args:
        res: Pair ``(op_args, output)``; ``op_args`` are the flattened operator
            parameters and ``output[0]`` is the forward solution.
        grads: Incoming cotangents; only ``grads[0]`` is used.
        unflatten: Callable rebuilding the linear operator from flattened parameters.
        *args: Forward positional arguments; the first one is skipped and the
            rest are forwarded to ``gmres_fwd``.
        **kwargs: Forward keyword arguments, forwarded unchanged to ``gmres_fwd``.

    Returns:
        tuple: ``(dA, db)`` where ``db`` is the GMRES solve against the incoming
        cotangent and ``dA`` is ``unflatten`` applied to the parameter derivatives.
    """
    cotangent = grads[0]
    op_args, output = res
    forward_soln = output[0]
    A = unflatten(op_args)
    backend = A.xnp
    # NOTE(review): this solve uses A itself rather than its adjoint -- confirm
    # that is intended when A is not symmetric/Hermitian.
    db, _ = gmres_fwd(A, cotangent, *args[1:], **kwargs)

    def apply_to_soln(*theta):
        # Operator rebuilt from candidate parameters, applied to the fixed forward solution.
        return unflatten(theta) @ forward_soln

    d_params = backend.vjp_derivs(apply_to_soln, op_args, -db)
    return (unflatten(d_params), db)
@iterative_autograd(gmres_bwd)
def gmres_fwd(A, rhs, x0, max_iters, tol, P, use_householder, use_triangular, pbar):
    """
    Forward GMRES pass: run Arnoldi on the initial residual and solve the
    resulting small Hessenberg problem for the correction added to ``x0``.

    Args:
        A: Linear operator; supplies the array backend (``A.xnp``) and device.
        rhs: Right hand sides.
        x0: Initial guess, same shape as ``rhs``.
        max_iters: Forwarded to ``arnoldi``.
        tol: Forwarded to ``arnoldi``; also scales the threshold below which
            rows of H are treated as zero.
        P: Preconditioner. NOTE: it is accepted but never referenced in this body.
        use_householder: Forwarded to ``arnoldi``.
        use_triangular: Solve through Givens-rotation QR instead of the
            regularised normal equations.
        pbar: Forwarded to ``arnoldi``.

    Returns:
        tuple: ``(soln, infodict)`` with ``soln = x0 + correction`` and
        ``infodict`` as returned by ``arnoldi``.
    """
    xnp = A.xnp
    residual = rhs - A @ x0  # (m,k)
    Q, H, infodict = arnoldi(A=A, start_vector=residual, max_iters=max_iters, tol=tol, pbar=pbar,
                             use_householder=use_householder)
    # Densify, then drop the last slice of Q's final axis and the last row of H.
    # (Leading axis presumably indexes the right hand sides -- confirm against arnoldi.)
    Q = Q.to_dense()[:, :, :-1]
    H = H.to_dense()[:, :-1, :]
    beta = xnp.norm(residual, axis=-2)
    # Right hand side of the small problem: residual norms placed at index 0, zeros elsewhere.
    e1 = xnp.zeros(shape=(H.shape[1], beta.shape[0]), dtype=rhs.dtype, device=A.device)
    e1 = xnp.update_array(e1, beta, 0)
    if use_triangular:
        # NOTE: only index 0 of the leading axis is used here, so this path
        # does not handle multiple right hand sides (known limitation, to be fixed).
        R, rotations = get_hessenberg_triangular_qr(H[0, :, :], xnp=xnp)
        rotated_rhs = apply_givens_fwd(rotations, e1, xnp)
        coeffs = xnp.solvetri(R, rotated_rhs, lower=False)
        correction = Q[0, :, :] @ coeffs
    else:
        H_adj = xnp.conj(xnp.permute(H, axes=[0, 2, 1]))
        row_max = xnp.max(xnp.abs(H), -1)
        batch_max = xnp.max(row_max.reshape(row_max.shape[0], -1), -1)
        # Rows whose largest magnitude falls below 10 * tol * (largest entry of that H).
        negligible = row_max < 10 * tol * batch_max[:, None]
        # Put ones on the diagonal at those positions before solving.
        ridge = xnp.where(negligible, xnp.ones_like(row_max), xnp.zeros_like(row_max))
        gram = H_adj @ H + xnp.vmap(xnp.diag)(ridge)
        coeffs = xnp.solve(gram, H_adj[..., 0, None]).squeeze(-1) * beta[:, None]
        # Force the coefficients at the padded positions to zero.
        coeffs = xnp.where(negligible, xnp.zeros_like(coeffs), coeffs)
        correction = xnp.permute(Q @ coeffs[..., None], axes=[1, 0, 2])[:, :, 0]
    return x0 + correction, infodict
def get_hessenberg_triangular_qr(H, xnp):
    """
    Sweep 2x2 Givens rotations down the diagonal of a single matrix ``H``.

    For every consecutive row pair ``(j, j + 1)`` a rotation is built from
    ``R[j, j]`` and ``R[j + 1, j]`` and its transpose is applied to those two rows.

    Args:
        H: Two-dimensional array to reduce; it is copied, not modified.
        xnp: Array backend providing ``get_device``, ``copy``, ``array`` and
            ``update_array``.

    Returns:
        tuple: ``(R, Gs)`` -- the rotated copy of ``H`` and the list of 2x2
        rotation matrices in the order they were applied.
    """
    device = xnp.get_device(H)
    R = xnp.copy(H)
    Gs = []
    for top in range(H.shape[0] - 1):
        rows = [top, top + 1]
        cx, sx = get_givens_cos_sin(R[top, top], R[top + 1, top], xnp)
        rotation = xnp.array([[cx, sx], [-sx, cx]], dtype=H.dtype, device=device)
        Gs.append(rotation)
        # Rotate the two affected rows and write them back.
        R = xnp.update_array(R, rotation.T @ R[rows, :], rows)
    return R, Gs
def apply_givens_fwd(Gs, vec, xnp):
    """
    Apply a sequence of 2x2 rotations to consecutive row pairs of ``vec``.

    The k-th rotation's transpose is multiplied into rows ``k`` and ``k + 1``,
    in list order.

    Args:
        Gs: Sequence of 2x2 rotation matrices.
        vec: Two-dimensional array whose rows are rotated.
        xnp: Array backend providing ``update_array``.

    Returns:
        The array obtained after all rotations (``vec`` itself when ``Gs`` is empty).
    """
    for start, rotation in enumerate(Gs):
        rows = [start, start + 1]
        vec = xnp.update_array(vec, rotation.T @ vec[rows, :], rows)
    return vec
def get_givens_cos_sin(a, b, xnp):
    """
    Return the ``(cos, sin)`` pair of the Givens rotation built from ``a`` and ``b``.

    Args:
        a: Entry that is kept (the diagonal entry at the call site in this file).
        b: Entry to eliminate (the subdiagonal entry at the call site in this file).
        xnp: Backend providing ``sqrt``; unused when ``b == 0``.

    Returns:
        tuple: ``(c, s)`` with ``c = a / sqrt(a**2 + b**2)`` and
        ``s = -b / sqrt(a**2 + b**2)``, or ``(1, 0)`` when ``b`` is zero.
    """
    # Nothing to eliminate: identity rotation.
    if b == 0:
        return 1, 0
    length = xnp.sqrt(a**2. + b**2.)
    return a / length, -b / length