# Source code for cola.linalg.eig.eigs

import numpy as np
from plum import dispatch

from cola.annotations import SelfAdjoint, Stiefel, Unitary
from cola.fns import lazify
from cola.linalg.algorithm_base import Algorithm, Auto
from cola.linalg.decompositions.arnoldi import arnoldi_eigs
from cola.linalg.decompositions.decompositions import Arnoldi, Lanczos, get_slice
from cola.linalg.decompositions.lanczos import lanczos_eigs
from cola.linalg.eig.lobpcg import LOBPCG, lobpcg
from cola.linalg.eig.power_iteration import PowerIteration
from cola.linalg.trace.diag_trace import diag
from cola.linalg.unary.unary import Eig, Eigh
from cola.ops.operator_base import LinearOperator
from cola.ops.operators import Diagonal, I_like, Identity, Triangular
from cola.utils import export


@export
@dispatch.abstract
def eig(A: LinearOperator, k: int, which: str = "LM", alg: Algorithm = Auto()):
    """
    Computes eigenvalues and eigenvectors of a linear operator.

    Args:
        A (LinearOperator): The linear operator for which eigenvalues and eigenvectors are computed.
        k (int): The desired number of eigenvalues and eigenvectors. Must be specified.
        which (str): Which part of the spectrum the eigenvalues are taken from.
         Default is 'LM' (largest in magnitude); alternatively use 'SM'
         (smallest in magnitude).
        alg (Algorithm): (Auto, Eig, Eigh, Arnoldi, Lanczos, LOBPCG, PowerIteration).
         Lanczos and LOBPCG require ``A`` to be annotated ``SelfAdjoint``;
         PowerIteration only supports ``k=1`` together with ``which='LM'``.

    Returns:
        Tuple[Array, LinearOperator]: A tuple containing eigenvalues and eigenvectors.
         The eigenvalues are given by eig_vals and the eigenvectors are given
         by eig_vecs.

    Example:
        >>> A = MyLinearOperator()
        >>> eig_vals, eig_vecs = eig(A, k=6, which='LM', alg=Auto(tol=1e-4))
    """


@export
def eigmax(A: LinearOperator, alg: Algorithm = Auto()):
    """Return the eigenvalue of ``A`` that is largest in magnitude.

    Shorthand for ``eig(A, k=1, which='LM', alg=alg)``. Only the eigenvalue is
    kept; the accompanying eigenvector is discarded.

    Args:
        A (LinearOperator): The operator whose dominant eigenvalue is wanted.
        alg (Algorithm): (Auto, PowerIteration, Eig, Eigh, Arnoldi, Lanczos)

    Returns:
        float: The maximum magnitude eigenvalue.
    """
    eig_vals, _ = eig(A, k=1, which='LM', alg=alg)
    return eig_vals[0]


@export
def eigmin(A: LinearOperator, alg: Algorithm = Auto()):
    """Return the eigenvalue of ``A`` that is smallest in magnitude.

    Shorthand for ``eig(A, k=1, which='SM', alg=alg)``. Only the eigenvalue is
    kept; the accompanying eigenvector is discarded.

    Args:
        A (LinearOperator): The linear operator for which eigenvalues and eigenvectors are computed.
        alg (Algorithm): (Auto, Eig, Eigh, Arnoldi, Lanczos). PowerIteration is
         not accepted here: it only supports which='LM', and this function
         always requests which='SM'.

    Returns:
        float: The minimum magnitude eigenvalue.
    """
    eig_vals, _ = eig(A, k=1, which='SM', alg=alg)
    return eig_vals[0]


@dispatch(precedence=-1)
def eig(A: LinearOperator, k: int, which: str, alg: Auto):
    """ Auto:
        - if k == 1 and which == 'LM', use PowerIteration (checked first,
          regardless of size or self-adjointness)
        - if A is Hermitian and small, use Eigh
        - if A is Hermitian and large, use Lanczos
        - if A is not Hermitian and small, use Eig
        - if A is not Hermitian and large, use Arnoldi

        "Small" means the operator has at most 1e6 entries (prod(A.shape)).
        The fields stored on ``alg`` are forwarded to the iterative algorithms
        (PowerIteration, Lanczos, Arnoldi); the dense Eig/Eigh solvers are
        constructed with no arguments.
    """
    self_adjoint = A.isa(SelfAdjoint)
    small = bool(np.prod(A.shape) <= 1e6)
    if k == 1 and which == 'LM':
        algorithm = PowerIteration(**alg.__dict__)
    elif small:
        # Dense path: materializing the operator is affordable.
        algorithm = Eigh() if self_adjoint else Eig()
    elif self_adjoint:
        algorithm = Lanczos(**alg.__dict__)
    else:
        algorithm = Arnoldi(**alg.__dict__)
    return eig(A, k, which, algorithm)


@dispatch
def eig(A: LinearOperator, k: int, which: str, alg: Arnoldi):
    """Eigenpairs via the Arnoldi iteration, restricted to the ``which``/``k`` selection.

    Every field stored on ``alg`` is forwarded to ``arnoldi_eigs``; the third
    value it returns is ignored.
    """
    chosen = get_slice(k, which)
    vals, vecs, _info = arnoldi_eigs(A, **alg.__dict__)
    return vals[chosen], vecs[:, chosen]


@dispatch
def eig(A: LinearOperator, k: int, which: str, alg: Lanczos):
    """Eigenpairs via the Lanczos iteration, restricted to the ``which``/``k`` selection.

    Asserts that ``A`` is annotated ``SelfAdjoint``. Every field stored on
    ``alg`` is forwarded to ``lanczos_eigs``; the third value it returns is
    ignored.
    """
    assert A.isa(SelfAdjoint), "Lanczos only valid for SelfAdjoint, wrap in cola.SelfAdjoint if desired"
    chosen = get_slice(k, which)
    vals, vecs, _info = lanczos_eigs(A, **alg.__dict__)
    return vals[chosen], vecs[:, chosen]


@dispatch
def eig(A: LinearOperator, k: int, which: str, alg: LOBPCG):
    """Eigenpairs via LOBPCG, restricted to the ``which``/``k`` selection.

    Asserts that ``A`` is annotated ``SelfAdjoint``. Every field stored on
    ``alg`` is forwarded to ``lobpcg``.
    """
    assert A.isa(SelfAdjoint), "LOBPCG only valid for SelfAdjoint, wrap in cola.SelfAdjoint if desired"
    chosen = get_slice(k, which)
    vals, vecs = lobpcg(A, **alg.__dict__)
    return vals[chosen], vecs[:, chosen]


@dispatch(precedence=-1)
def eig(A: LinearOperator, k: int, which: str, alg: Eig):
    """Dense fallback: materialize ``A`` and use the backend's general ``eig``.

    The selected eigenvector columns are wrapped with ``lazify``.
    NOTE(review): unlike the Diagonal/Triangular overloads, no explicit sort is
    done here, so the selection relies on ``xnp.eig`` returning eigenvalues in
    the order ``get_slice`` expects. Confirm against the backend.
    """
    chosen = get_slice(k, which)
    dense = A.to_dense()
    vals, vecs = A.xnp.eig(dense)
    return vals[chosen], lazify(vecs[:, chosen])


@dispatch(precedence=-1)
def eig(A: LinearOperator, k: int, which: str, alg: Eigh):
    """Dense Hermitian fallback: materialize ``A`` and use the backend's ``eigh``.

    The selected eigenvector columns are wrapped with ``lazify`` and annotated
    ``Stiefel`` (orthonormal columns).
    """
    chosen = get_slice(k, which)
    dense = A.to_dense()
    vals, vecs = A.xnp.eigh(dense)
    basis = lazify(vecs[:, chosen])
    return vals[chosen], Stiefel(basis)


@dispatch(precedence=-1)
def eig(A: LinearOperator, k: int, which: str, alg: PowerIteration):
    """Dominant eigenpair via power iteration (only k=1, which='LM' is supported).

    ``alg(A)`` yields ``(vector, eigenvalue, info)``. The eigenvalue gains a
    leading axis and the vector becomes a single column.
    NOTE(review): the column is returned as produced by ``alg(A)``; it is not
    wrapped with ``lazify`` as in the Eig/Eigh overloads.
    """
    assert k == 1 and which == 'LM', "PowerIteration only valid for k=1 and which='LM'"
    top_vec, top_val, _info = alg(A)
    return top_val[None], top_vec[:, None]


@dispatch
def eig(A: Identity, k: int, which: str, alg: Algorithm):
    """Identity operator: every eigenvalue is 1 and the columns of I are eigenvectors.

    ``alg`` is not used. The eigenvectors are the selected columns of the dense
    identity, annotated ``Unitary``.
    """
    chosen = get_slice(k, which)
    n = A.shape[0]
    ones = A.xnp.ones(shape=(n, ), dtype=A.dtype, device=A.device)
    basis = A.to_dense()
    return ones[chosen], Unitary(lazify(basis[:, chosen]))


@dispatch
def eig(A: Triangular, k: int, which: str, alg: Algorithm):
    """Triangular operator: eigenvalues are the diagonal entries.

    The diagonal is sorted with ``xnp.argsort``, and the eigenvector columns
    (computed on a NumPy copy of ``A.A``) are permuted to match before slicing.
    ``alg`` is not used.

    The eigenvector matrix of a triangular (generally non-normal) matrix is
    neither orthogonal nor unit-norm. For example, [[1, 2], [0, 3]] has
    eigenvectors [[1, 1], [0, 1]]. The result is therefore returned as a
    plain lazified operator rather than being annotated ``Unitary``.
    """
    # TODO: take out compute_lower_triangular_eigvecs
    eig_slice = get_slice(k, which)
    xnp = A.xnp
    eig_vals = diag(A)
    sorted_ind = xnp.argsort(eig_vals)
    eig_vals = eig_vals[sorted_ind]
    # Eigenvectors are computed in NumPy and moved back to A's backend/device.
    eig_vecs = compute_lower_triangular_eigvecs(np.array(A.A))
    eig_vecs = xnp.array(eig_vecs, dtype=A.dtype, device=A.device)[:, sorted_ind]
    return eig_vals[eig_slice], lazify(eig_vecs[:, eig_slice])


def compute_lower_triangular_eigvecs(L):
    """Eigenvectors of a dense triangular matrix by substitution.

    Column ``i`` of the result is an eigenvector for the eigenvalue ``L[i, i]``,
    scaled so that its ``i``-th entry is 1 (columns are not unit-norm).

    The upper-triangular recurrence reads the entries above the diagonal
    (``L[:i, i]``). For a lower-triangular matrix those entries are all zero,
    and that recurrence would wrongly return the identity. A matrix whose
    strictly-upper part is exactly zero but whose strictly-lower part is not is
    therefore solved with the lower-triangular recurrence instead. Any other
    input (upper-triangular, diagonal) uses the upper recurrence.

    Args:
        L (array_like): square triangular matrix.

    Returns:
        np.ndarray: (n, n) eigenvector matrix. The dtype is float64 for real
         input; complex input yields a complex result, so the imaginary part
         of the solved components is not discarded.

    Raises:
        numpy.linalg.LinAlgError: may be raised when a diagonal value repeats,
         because the shifted sub-block being solved is then singular.
    """
    L = np.asarray(L)
    n = L.shape[0]
    eigvals = np.diag(L)
    # Promote to complex when L is complex; otherwise this stays float64.
    eigvecs = np.eye(n, dtype=np.result_type(L.dtype, np.float64))
    is_lower = not np.any(np.triu(L, 1)) and bool(np.any(np.tril(L, -1)))
    if is_lower:
        # v[:i] = 0, v[i] = 1, and (L[i+1:, i+1:] - lambda_i I) v[i+1:] = -L[i+1:, i]
        for i in range(n - 1):
            shifted = L[i + 1:, i + 1:] - eigvals[i] * np.eye(n - i - 1)
            eigvecs[i + 1:, i] = np.linalg.solve(shifted, -L[i + 1:, i])
    else:
        # v[i+1:] = 0, v[i] = 1, and (L[:i, :i] - lambda_i I) v[:i] = -L[:i, i]
        for i in range(1, n):
            shifted = L[:i, :i] - eigvals[i] * np.eye(i)
            eigvecs[:i, i] = np.linalg.solve(shifted, -L[:i, i])
    return eigvecs


@dispatch
def eig(A: Diagonal, k: int, which: str, alg: Algorithm):
    """Diagonal operator: eigenvalues are the diagonal entries.

    The entries are sorted with ``xnp.argsort``, and the columns of the identity
    are permuted the same way. The eigenvector matrix is therefore a permutation
    of the identity and is annotated ``Unitary``. ``alg`` is not used.
    """
    eig_slice = get_slice(k, which)
    xnp = A.xnp
    sorted_ind = xnp.argsort(A.diag)
    eig_vals = A.diag[sorted_ind]
    eig_vecs = I_like(A).to_dense()[:, sorted_ind]
    return eig_vals[eig_slice], Unitary(lazify(eig_vecs[:, eig_slice]))