Source code for cola.linalg.decompositions.decompositions

""" Decompositions of linear operators, e.g. LU, Cholesky, Lanczos, Arnoldi, SVD"""
from dataclasses import dataclass
from typing import Any, Optional

from plum import dispatch

import cola.linalg
from cola.annotations import Unitary
from cola.linalg.algorithm_base import Algorithm, Auto
from cola.linalg.decompositions.arnoldi import arnoldi
from cola.linalg.decompositions.lanczos import lanczos, lanczos_eigs
from cola.ops import (
    Array,
    BlockDiag,
    Diagonal,
    Identity,
    Kronecker,
    LinearOperator,
    Permutation,
    ScalarMul,
    Triangular,
)
from cola.utils import export

PRNGKey = Any


@export
class Cholesky(Algorithm):
    """
    Algorithm object selecting the Cholesky factorization
    :math:`A = L L^{*}` of a positive definite operator, with :math:`L`
    a lower triangular operator.

    Calling the instance on an operator forwards to :func:`cholesky`.

    Example:
        >>> A = MyLinearOperator()
        >>> L = cola.linalg.decompositions.Cholesky()(A)
    """

    def __call__(self, A: LinearOperator):
        # The structure-aware overload of ``cholesky`` is picked by dispatch.
        factor = cholesky(A)
        return factor
@export
class LU(Algorithm):
    """
    Algorithm object selecting the pivoted LU factorization
    :math:`A = PLU` of a general square operator, where :math:`P` is a
    permutation operator, :math:`L` is a lower triangular operator and
    :math:`U` is an upper triangular operator.

    Calling the instance on an operator forwards to :func:`plu`.

    Example:
        >>> A = MyLinearOperator()
        >>> P,L,U = cola.linalg.decompositions.LU()(A)
    """

    def __call__(self, A: LinearOperator):
        # The structure-aware overload of ``plu`` is picked by dispatch.
        factors = plu(A)
        return factors
@export
@dataclass
class Arnoldi(Algorithm):
    """
    Arnoldi decomposition of a general square operator,
    :math:`H \\approx Q^{*} A Q`, with :math:`H` an upper Hessenberg
    operator. eig(H) serves as an approximation of eig(A).

    Every field of this dataclass is passed as a keyword argument to
    ``arnoldi`` when the instance is called.

    Args:
        start_vector (Array, optional): (n,) or (n, b) vector to start the algorithm.
        max_iters (int, optional): The maximum number of iterations to run.
        tol (float, optional): Relative error tolerance.
        pbar (bool, optional): Whether to show progress bar.
        key (PRNGKey, optional): Random key to use for the algorithm.
         PRNGKey for jax, long integer for numpy or pytorch.

    Example:
        >>> A = MyLinearOperator()
        >>> Q,H,info = Arnoldi(max_iters=100,pbar=True)(A)
    """

    start_vector: Array = None
    max_iters: int = 1_000
    tol: float = 1e-6
    pbar: bool = False
    key: Optional[PRNGKey] = None

    def __call__(self, A: LinearOperator):
        # The instance attributes are exactly the configuration fields above.
        options = vars(self)
        return arnoldi(A, **options)
@dataclass
class LanczosSVD(Algorithm):
    """
    Approximate SVD built from two Lanczos eigendecompositions.

    Calling the instance on ``A`` runs ``lanczos_eigs`` on ``A.H @ A`` and on
    ``A @ A.H``. The square roots of the eigenvalues of ``A.H @ A`` become the
    singular values, and the two sets of eigenvectors become ``V`` and ``U``.
    The result is the tuple ``(Unitary(U), Diagonal(Sigma), Unitary(V))``.

    Every field of this dataclass is passed as a keyword argument to both
    ``lanczos_eigs`` calls.

    Args:
        start_vector (Array, optional): (n,) or (n, b) vector to start the algorithm.
        max_iters (int, optional): The maximum number of iterations to run.
        tol (float, optional): Relative error tolerance.
        pbar (bool, optional): Whether to show progress bar.
        key (PRNGKey, optional): Random key to use for the algorithm.
         PRNGKey for jax, long integer for numpy or pytorch.
    """

    start_vector: Array = None
    max_iters: int = 1_000
    tol: float = 1e-6
    pbar: bool = False
    key: Optional[PRNGKey] = None

    def __call__(self, A: LinearOperator):
        options = self.__dict__
        # Eigenpairs of A^H A: squared singular values and right vectors.
        Sigma0_2, V, _ = lanczos_eigs(A.H @ A, **options)
        # Eigenvectors of A A^H: left vectors. The eigenvalues of this second
        # solve are not needed, so they are discarded.
        _, U, _ = lanczos_eigs(A @ A.H, **options)
        # NOTE(review): U and V come from two independent solves, so their
        # column signs/phases and ordering are presumably not aligned with
        # each other; confirm before relying on U @ Sigma @ V^H == A.
        # NOTE(review): sqrt is applied without clamping; eigenvalue
        # estimates that round slightly below zero may yield NaN - confirm.
        Sigma = A.xnp.sqrt(Sigma0_2)
        return Unitary(U), Diagonal(Sigma), Unitary(V)
@export
@dataclass
class Lanczos(Algorithm):
    """
    Lanczos decomposition of a Symmetric (or Hermitian) operator,
    :math:`T = Q^{*} A Q`, with :math:`T` a tridiagonal operator.
    eig(T) serves as an approximation of eig(A).

    Every field of this dataclass is passed as a keyword argument to
    ``lanczos`` when the instance is called.

    Args:
        start_vector (Array, optional): (n,) or (n, b) vector to start the algorithm.
        max_iters (int, optional): The maximum number of iterations to run.
        tol (float, optional): Relative error tolerance.
        pbar (bool, optional): Whether to show progress bar.
        key (PRNGKey, optional): Random key to use for the algorithm.
         PRNGKey for jax, long integer for numpy or pytorch.

    Example:
        >>> A = MyLinearOperator()
        >>> Q,T,info = Lanczos(max_iters=100,pbar=True)(A)
    """

    start_vector: Array = None
    max_iters: int = 1_000
    tol: float = 1e-6
    pbar: bool = False
    key: Optional[PRNGKey] = None

    def __call__(self, A: LinearOperator):
        # The instance attributes are exactly the configuration fields above.
        options = vars(self)
        return lanczos(A, **options)
@dispatch
def cholesky(A: LinearOperator):
    """
    Performs a Cholesky decomposition A=LL* of a linear operator A.

    Generic fallback: densifies A, factorizes the dense array and wraps the
    result as a lower ``Triangular`` operator.
    (Implicitly assumes A is PSD)
    """
    return Triangular(A.xnp.cholesky(A.to_dense()), lower=True)


@dispatch
def cholesky(A: Identity):
    # The identity is returned as its own factor.
    return A


@dispatch
def cholesky(A: Diagonal | ScalarMul):
    # Delegates to the operator square root.
    return cola.linalg.sqrt(A, Auto())


@dispatch
def cholesky(A: Kronecker):
    # see https://www.math.uwaterloo.ca/~hwolkowi/henry/reports/kronthesisschaecke04.pdf
    return Kronecker(*[cholesky(Ai) for Ai in A.Ms])


@dispatch
def cholesky(A: BlockDiag):
    # Factor each block independently, keeping the block multiplicities.
    return BlockDiag(*[cholesky(Ai) for Ai in A.Ms], multiplicities=A.multiplicities)


@dispatch
def plu(A: LinearOperator):
    """
    Performs a PLU decomposition A=PLU of a linear operator A.

    Generic fallback: densifies A and factorizes the dense array. Returns
    P, L, U as linear operators represented using the permutation and
    triangular structure, moved to the device of A.
    """
    p, L, U = A.xnp.lu(A.to_dense())
    P, L, U = Permutation(p), Triangular(L, lower=True), Triangular(U, lower=False)
    P, L, U = P.to(A.device), L.to(A.device), U.to(A.device)
    return P, L, U


@dispatch
def plu(A: Identity):
    # I = I I I, so the identity serves as all three factors.
    return A, A, A


@dispatch
def plu(A: Diagonal | ScalarMul):
    # A = I * I * A: identity permutation, unit-diagonal L and U = A.
    # Unlike splitting A into two square-root factors, this needs no square
    # root of the entries and so stays valid for negative diagonal entries.
    eye = cola.ops.I_like(A)
    return eye, eye, A


@dispatch
def plu(A: Kronecker):
    # see https://www.math.uwaterloo.ca/~hwolkowi/henry/reports/kronthesisschaecke04.pdf
    P, L, U = zip(*[plu(Ai) for Ai in A.Ms])
    return Kronecker(*P), Kronecker(*L), Kronecker(*U)


@dispatch
def plu(A: BlockDiag):
    # Factor each block independently, then reassemble each family of factors
    # as a block-diagonal operator with the multiplicities of A.
    P, L, U = zip(*[plu(Ai) for Ai in A.Ms])

    def _block_diag(*blocks):
        return BlockDiag(*blocks, multiplicities=A.multiplicities)

    return _block_diag(*P), _block_diag(*L), _block_diag(*U)


def get_slice(num, which):
    """
    Builds the slice that picks ``num`` entries from one end of a sequence.

    Args:
        num: Number of entries to select. ``-1`` is rejected. ``None`` yields
         ``slice(0, None)`` for "SM" and ``slice(-1, None)`` for "LM".
        which (str): "SM" slices from the front, "LM" slices from the back.

    Returns:
        slice: ``slice(0, num)`` for "SM"; ``slice(-num, None)`` for "LM",
        or an empty slice when ``num`` is 0.

    Raises:
        ValueError: If ``num`` is -1.
        NotImplementedError: If ``which`` is neither "SM" nor "LM".
    """
    if num == -1:
        raise ValueError(f"Number of eigenvalues {num} must be explicitly specified")
    if which == "SM":
        return slice(0, num, None)
    if which == "LM":
        if num is None:
            return slice(-1, None, None)
        if num == 0:
            # -0 == 0 would turn slice(-num, None) into "everything".
            return slice(0, 0, None)
        return slice(-num, None, None)
    raise NotImplementedError(f"which={which} is not implemented")