# Source code for cola.linalg.inverse.inv

import numpy as np
from plum import dispatch, parametric

from cola.annotations import PSD, Unitary
from cola.linalg.algorithm_base import Algorithm, Auto, IterativeOperatorWInfo
from cola.linalg.decompositions.decompositions import LU, Cholesky, cholesky, plu
from cola.linalg.inverse.cg import CG
from cola.linalg.inverse.gmres import GMRES
from cola.ops.operators import (
    BlockDiag,
    Diagonal,
    Identity,
    Kronecker,
    LinearOperator,
    Permutation,
    Product,
    ScalarMul,
    Triangular,
)
from cola.utils import export


@export
def solve(A, b, alg=Auto()):
    """Solve the linear system ``A x = b`` for ``x``.

    This is a thin convenience wrapper over :func:`inv`: the operator is
    (lazily) inverted with ``alg`` and the inverse is applied to ``b``, so
    ``solve(A, b, alg)`` is the same as ``inv(A, alg) @ b`` (cola.inv).

    Args:
        A (LinearOperator): The linear operator to compute the inverse of.
        b (Array): The right hand side of the linear system of shape (d, ) or (d, k)
        alg (Algorithm, optional): The algorithm to use for the solves.

    Returns:
        Array: The solution of the linear system of shape (d, ) or (d, k)

    Example:
        >>> A = MyLinearOperator()
        >>> x = cola.solve(A, b, alg=Auto(max_iters=10, pbar=True))
    """
    A_inv = inv(A, alg)
    return A_inv @ b


@export
@dispatch.abstract
def inv(A: LinearOperator, alg: Algorithm = Auto()):
    """(Lazily) computes the inverse of a linear operator, equivalent to solve.

    The concrete behavior is provided by the ``@dispatch`` overloads below,
    selected on the runtime types of ``A`` and ``alg`` (and, for some
    overloads, on annotations such as ``Unitary``).

    Args:
        A (LinearOperator): The linear operator to compute the inverse of.
        alg (Algorithm, optional): The algorithm to use for the solves.

    Returns:
        LinearOperator: The inverse of the linear operator.

    Example:
        >>> A = MyLinearOperator()
        >>> x = cola.inv(A, alg=CG(tol=1e-3)) @ b
    """


@dispatch
def inv(A: LinearOperator, alg: GMRES):
    # Wrap A together with the GMRES settings; the solves are presumably
    # carried out when the returned operator is applied (see
    # IterativeOperatorWInfo, defined elsewhere).
    return IterativeOperatorWInfo(A, alg)


@dispatch(precedence=-1)
def inv(A: LinearOperator, alg: CG):
    # NOTE(review): `assert` is stripped under `python -O`; kept as-is so
    # callers that expect AssertionError keep working.
    assert A.isa(PSD), "CG only valid for PSD matrices, wrap in cola.PSD if desired"
    return IterativeOperatorWInfo(A, alg)


@export
@dispatch(precedence=-1)
def inv(A: LinearOperator, alg: Auto):
    """
    Auto:
        - if A is PSD and small, use Cholesky
        - if A is PSD and large, use CG
        - if A is not PSD and small, use LU
        - if A is not PSD and large, use GMRES

    "Small" means the operator has at most 1e6 entries
    (``prod(A.shape) <= 1e6``). The iterative algorithms inherit the
    settings stored on ``alg``.
    """
    # Plain truthiness tests instead of `match` on True/False literal
    # patterns (those compare by identity, so a non-bool truthy result would
    # have fallen through to an unreachable-looking `assert False`).
    is_psd = bool(A.isa(PSD))
    is_small = bool(np.prod(A.shape) <= 1e6)
    if is_psd:
        chosen = Cholesky() if is_small else CG(**alg.__dict__)
    else:
        chosen = LU() if is_small else GMRES(**alg.__dict__)
    return inv(A, chosen)


@dispatch(precedence=-1)
def inv(A: LinearOperator, alg: Cholesky):
    assert A.isa(PSD), "Cholesky only valid for PSD matrices, wrap in cola.PSD if desired"
    L = cholesky(A)
    # inv(L.H) @ inv(L) == (L @ L.H)^{-1}, i.e. this treats A as L L^H.
    return inv(L.H) @ inv(L)


@dispatch(precedence=-1)
def inv(A: LinearOperator, alg: LU):
    P, L, U = plu(A)
    # inv(U) @ inv(L) @ inv(P) == (P @ L @ U)^{-1}, i.e. this treats A as P L U.
    return inv(U) @ inv(L) @ inv(P)


@dispatch(cond=lambda A, *_: A.isa(Unitary))
def inv(A: LinearOperator, alg: Algorithm):
    # For a unitary operator the inverse is its adjoint.
    return Unitary(A.H)


@dispatch
def inv(A: Identity, alg: Algorithm):
    return A


@dispatch
def inv(A: ScalarMul, alg: Algorithm):
    # (c I)^{-1} = (1/c) I; keep the scalar on the same device as the original.
    return ScalarMul(1 / A.c, shape=A.shape, dtype=A.dtype, device=A.c.device)


@dispatch
def inv(A: Permutation, alg: Algorithm):
    # argsort of a permutation index array yields the inverse permutation.
    return Permutation(A.xnp.argsort(A.perm), A.dtype)


@dispatch(cond=lambda A, *_: all(M.shape[-2] == M.shape[-1] for M in A.Ms))
def inv(A: Product, alg: Algorithm):
    # (M1 M2 ... Mk)^{-1} = Mk^{-1} ... M2^{-1} M1^{-1}; only valid when every
    # factor is square, which the dispatch condition enforces.
    return Product(*reversed([inv(M, alg) for M in A.Ms]))


@dispatch
def inv(A: BlockDiag, alg: Algorithm):
    # A block-diagonal operator is inverted block by block.
    return BlockDiag(*[inv(M, alg) for M in A.Ms], multiplicities=A.multiplicities)


@dispatch
def inv(A: Kronecker, alg: Algorithm):
    # (A1 ⊗ A2 ⊗ ...)^{-1} = A1^{-1} ⊗ A2^{-1} ⊗ ...
    return Kronecker(*[inv(M, alg) for M in A.Ms])


@dispatch
def inv(A: Diagonal, alg: Algorithm):
    return Diagonal(1. / A.diag)


@dispatch
def inv(A: Triangular, alg: Algorithm):
    """Invert a triangular operator lazily via triangular solves; ``alg`` is unused."""
    inverse = TriangularInv(A)
    return inverse


@parametric
class TriangularInv(LinearOperator):
    """Lazy inverse of a :class:`Triangular` operator.

    Holds a dense copy of the triangular matrix and applies its inverse with
    ``xnp.solvetri`` instead of ever forming the inverse explicitly.
    """

    def __init__(self, A: Triangular):
        super().__init__(A.dtype, A.shape)
        # Dense triangular matrix plus its orientation, both consumed by solvetri.
        self.A = A.to_dense()
        self.lower = A.lower

    def _matmat(self, X):
        """Apply the inverse on the left by solving the triangular system ``A Y = X``."""
        tri, lower = self.A, self.lower
        return self.xnp.solvetri(tri, X, lower=lower)

    def _rmatmat(self, X):
        """Apply the inverse on the right, i.e. ``X A^{-1}``.

        Assuming ``solvetri(M, B)`` returns ``M^{-1} B``, this solves
        ``A^T Y = X^T`` and transposes back. Transposition swaps the
        lower/upper structure, hence ``not self.lower``.
        """
        transposed = self.A.T
        solved = self.xnp.solvetri(transposed, X.T, lower=not self.lower)
        return solved.T