from abc import abstractmethod
from numbers import Number
from typing import Any, Tuple, Union
import numpy as np
import cola
from cola.backends import AutoRegisteringPyTree, get_library_fns, np_fns
from cola.utils import export
# ``Array`` and ``Dtype`` are loose placeholder annotations used throughout this
# module; both are plain aliases of ``typing.Any``.
Array = Any
Dtype = Any
# Hand ``Array`` to the project's ``export`` helper
# (presumably registers it as part of the public API -- confirm in cola.utils).
export(Array)
# (stray "[docs]" link text from the rendered documentation page removed; it is not code)
@export
class LinearOperator(metaclass=AutoRegisteringPyTree):
""" Linear Operator base class """
_dynamic = {key: False for key in ['xnp', 'shape', 'dtype', 'device', 'annotations']}
__array_ufunc__ = None
def __new__(cls, *args, **kwargs):
""" Creates attributes for the flatten and unflatten functionality. """
obj = super().__new__(cls)
obj.device = find_device([args, kwargs])
return obj
[docs]
def __init__(self, dtype: Dtype, shape: Tuple, matmat=None, annotations={}):
self.dtype = dtype
self.shape = shape
self.xnp = get_library_fns(dtype)
if matmat is not None:
self._matmat = matmat
self.annotations = cola.annotations.get_annotations(self)
# TODO: reform matrices with the new annotations?
self.annotations.update(annotations)
self.device = self.device or self.xnp.get_default_device()
def __setattr__(self, name, value):
if name not in self.__class__._dynamic:
# don't split this into two lines, we want the short circuiting
cond = definitely_dynamic(value) or any(map(is_array, np_fns.tree_flatten(value)[0]))
self.__class__._dynamic[name] = cond
return super().__setattr__(name, value)
def to(self, device, dtype=None):
""" Returns a new linear operator with given device and dtype
WARNING: dtype change is not supported yet. """
params, unflatten = self.flatten()
params = [self.xnp.move_to(p, device=device, dtype=dtype) if self.xnp.is_array(p) else p for p in params]
return unflatten(params)
def isa(self, annotation) -> bool:
""" Returns True if the LinearOperator has the given annotation. """
return any(issubclass(a, annotation) for a in self.annotations)
[docs]
@abstractmethod
def _matmat(self, X: Array) -> Array:
""" Defines multiplication AX of the LinearOperator A with a dense array X (d,k)
where A (self) is shape (c,d)"""
raise NotImplementedError
[docs]
def _rmatmat(self, X: Array) -> Array:
""" Defines multiplication XA of the LinearOperator A with a dense array X (k,d)
where A (self) is shape (d,c). By default uses jvp to compute the transpose."""
XT = X.T
if self.isa(cola.annotations.SelfAdjoint):
return self.xnp.conj(self._matmat(self.xnp.conj(XT)).T)
primals = self.xnp.zeros(shape=(self.shape[1], XT.shape[1]), dtype=XT.dtype, device=self.device)
out = self.xnp.linear_transpose(self._matmat, primals=primals, duals=XT)
return out.T
[docs]
def to_dense(self) -> Array:
""" Produces a dense array representation of the linear operator. """
if 8 * self.shape[-2] < self.shape[-1]:
return self.xnp.eye(self.shape[-2], self.shape[-2], dtype=self.dtype, device=self.device) @ self
else:
return self @ self.xnp.eye(self.shape[-1], self.shape[-1], dtype=self.dtype, device=self.device)
@property
def T(self):
""" Matrix Transpose """
return cola.fns.transpose(self)
@property
def H(self):
""" Matrix complex conjugate transpose (aka hermitian conjugate, adjoint)"""
return cola.fns.adjoint(self)
def flatten(self) -> Tuple[Array, ...]:
vals, tree = self.xnp.tree_flatten(self)
def unflatten(params):
return self.xnp.tree_unflatten(tree, params)
return vals, unflatten
def __matmul__(self, X: Array) -> Array:
assert X.shape[0] == self.shape[-1], f"dimension mismatch {self.shape} vs {X.shape}"
if isinstance(X, LinearOperator):
return cola.fns.dot(self, X)
elif len(X.shape) == 1:
return self._matmat(X.reshape(-1, 1)).reshape(-1)
elif len(X.shape) >= 2:
return self._matmat(X)
else:
raise NotImplementedError
def __rmatmul__(self, X: Array) -> Array:
assert X.shape[-1] == self.shape[-2], f"dimension mismatch {self.shape} vs {X.shape}"
if isinstance(X, LinearOperator):
return cola.fns.dot(X, self)
elif len(X.shape) == 1:
return self._rmatmat(X.reshape(1, -1)).reshape(-1)
elif len(X.shape) >= 2:
return self._rmatmat(X)
else:
raise NotImplementedError
def __add__(self, other):
# check if is numbers.Number
if isinstance(other, Number) and other == 0:
return self
return cola.fns.add(self, other)
def __radd__(self, other):
return self.__add__(other)
def __mul__(self, c):
# assert isinstance(c, (int, float)), "c must be a scalar"
return cola.fns.mul(self, c)
def __rmul__(self, c):
return self * c
def __neg__(self):
return -1 * self
def __sub__(self, x):
return self.__add__(-x)
def __truediv__(self, x):
return self.__mul__(1 / x)
def __rtruediv__(self, x):
return self.__mul__(1 / x)
def __str__(self):
# check if class is LinearOperator
if self.__class__.__name__ != 'LinearOperator':
return self.__class__.__name__
alphabet = 'ABCDEFGHJKLMNPQRSTUVWXYZ'
return alphabet[hash(id(self)) % 24]
def __repr__(self):
M, N = self.shape
dt = 'dtype=' + str(self.dtype)
return '<%dx%d %s with %s>' % (M, N, self.__class__.__name__, dt)
def __getitem__(self, ids: Union[Tuple[int, ...], Tuple[slice, ...]]) -> Union[Array, 'LinearOperator']:
# TODO: add Tuple[List[int],...] and List[Tuple[int,int]] cases
# print(type(ids))
# print(type(ids[0]), type(ids[1]))
# check if first element is ellipsis
xnp = self.xnp
from cola.ops import Sliced
match ids:
case int(i):
ei = xnp.canonical(loc=i, shape=(self.shape[-1], ), dtype=self.dtype, device=self.device)
return (self.T @ ei)
case (slice() | xnp.ndarray() | np.ndarray()) as s_i:
return Sliced(A=self, slices=(s_i, slice(None)))
case b, int(j):
ej = xnp.canonical(loc=j, shape=(self.shape[-1], ), dtype=self.dtype, device=self.device)
return (self @ ej)[b]
case int(i), b:
ei = xnp.canonical(loc=i, shape=(self.shape[-1], ), dtype=self.dtype, device=self.device)
return (self.T @ ei)[b]
case (slice() | xnp.ndarray() | np.ndarray()) as s_i, \
(slice() | xnp.ndarray() | np.ndarray()) as s_j:
return Sliced(A=self, slices=(s_i, s_j))
case list(li), list(lj):
out = []
for idx, jdx in zip(li, lj):
# TODO: batch jdx
ej = xnp.canonical(loc=jdx, shape=(self.A.shape[-1], ), dtype=self.dtype, device=self.device)
out.append((self.A @ ej)[idx])
return xnp.stack(out)
case _:
raise NotImplementedError(f"__getitem__ not implemented for this case {type(ids)}")
def tree_flatten(self):
# separate all_elems into pytrees and aux
pytrees, aux = [], []
for key, val in sorted(vars(self).items()):
if self._dynamic[key]:
pytrees.append(val)
aux.append((key, ))
else:
aux.append((key, val))
return pytrees, aux
@classmethod
def tree_unflatten(cls, aux, children):
fields = {}
child_iter = iter(children)
for keyv in aux:
if len(keyv) == 1:
fields[keyv[0]] = next(child_iter)
else:
fields[keyv[0]] = keyv[1]
obj = object.__new__(cls)
for k, v in fields.items():
if k in ['device']: # ,'dtype']: TODO: also separate dtype in case .to was called
continue
setattr(obj, k, v)
# dtypes = [dt for dt in map(maybe_get_dtype, children) if dt is not None]
# obj.dtype = reduce(obj.xnp.promote_types, dtypes) if len(dtypes) > 0 else None
obj.device = find_device(fields) or fields['device']
return obj
def maybe_get_dtype(obj):
    """ Returns ``obj.dtype`` if the attribute exists, otherwise ``None``. """
    return getattr(obj, 'dtype', None)
def is_array(obj):
    """ Returns whether the backend selected by ``obj.dtype`` considers ``obj`` an array.

    Objects without a ``dtype`` attribute, or whose backend lookup raises
    ``ImportError``/``AttributeError``, are reported as not being arrays.
    """
    try:
        backend = get_library_fns(obj.dtype)
        return backend.is_array(obj)
    except (ImportError, AttributeError):
        return False
def is_xnp_array(obj, xnp):
    """ Returns True iff ``obj`` has a ``dtype`` attribute and ``xnp.is_array(obj)`` is truthy. """
    return bool(hasattr(obj, 'dtype') and xnp.is_array(obj))
def find_xnp(obj):
    """ Searches ``obj`` for a backend (``xnp``) and returns the first one found.

    Arrays yield the backend of their dtype, LinearOperators their own ``xnp``
    (when it is not None), and tuples/lists/sets/dict values are searched
    recursively. If nothing is found, ``get_library_fns(obj)`` is tried on the
    object itself; ``None`` is returned when that raises
    ``ImportError``/``AttributeError``.
    """
    if is_array(obj):
        return get_library_fns(obj.dtype)
    if isinstance(obj, LinearOperator) and obj.xnp is not None:
        return obj.xnp

    # Pick the children to recurse into (none for non-container objects).
    if isinstance(obj, (tuple, list, set)):
        children = obj
    elif isinstance(obj, dict):
        children = obj.values()
    else:
        children = ()
    for child in children:
        found = find_xnp(child)
        if found is not None:
            return found

    # Last resort: ask the backend registry about the object itself.
    try:
        return get_library_fns(obj)
    except (ImportError, AttributeError):
        return None
def find_device(obj):
    """ Searches ``obj`` for a device and returns the first non-None one found.

    Arrays report the device given by their backend's ``get_device``,
    LinearOperators their ``device`` attribute, and tuples/lists/sets/dict
    values are searched recursively. Returns ``None`` when nothing is found.
    """
    if is_array(obj):
        backend = get_library_fns(obj.dtype)
        return backend.get_device(obj)
    if isinstance(obj, LinearOperator):
        return obj.device

    # Pick the children to recurse into (none for non-container objects).
    if isinstance(obj, (tuple, list, set)):
        children = obj
    elif isinstance(obj, dict):
        children = obj.values()
    else:
        children = ()
    for child in children:
        found = find_device(child)
        if found is not None:
            return found
    return None
def definitely_dynamic(obj):
    """ Returns a truthy value when ``obj`` is an array or a LinearOperator. """
    if is_array(obj):
        return True
    return isinstance(obj, LinearOperator)