Source code for cola.ops.operators

from functools import partial, reduce

import numpy as np
from plum import parametric
from scipy.sparse import coo_array

import cola
from cola.backends import get_library_fns
from cola.ops.operator_base import Array, LinearOperator


class Dense(LinearOperator):
    """ LinearOperator wrapping of a dense matrix. O(n^2) memory and time mvms.

    Both operands of every product are cast to the promoted dtype of the
    operator and the right-hand side before multiplying.

    Args:
        A (array_like): Dense matrix to be wrapped.

    Example:
        >>> A = jnp.array([[1., 2.], [3., 4.]])
        >>> op = Dense(A)
    """
    def __init__(self, A: Array):
        self.A = A
        super().__init__(dtype=A.dtype, shape=A.shape)

    def _matmat(self, X: Array) -> Array:
        xnp = self.xnp
        common = xnp.promote_types(self.dtype, X.dtype)
        lhs = xnp.cast(self.A, common)
        rhs = xnp.cast(X, common)
        return lhs @ rhs

    def _rmatmat(self, X: Array) -> Array:
        # NOTE(review): an earlier draft moved self.A onto X's device here; confirm
        # whether cross-device inputs need handling.
        xnp = self.xnp
        common = xnp.promote_types(self.dtype, X.dtype)
        lhs = xnp.cast(X, common)
        rhs = xnp.cast(self.A, common)
        return lhs @ rhs

    def to_dense(self):
        """Return the wrapped matrix itself (no copy)."""
        return self.A
class Triangular(Dense):
    """ Triangular Linear Operator.

    A Dense operator that additionally records which triangle holds the entries.

    Args:
        A (array_like): Dense matrix to be wrapped (not checked for triangularity).
        lower (bool): Stored as ``self.lower``; presumably True for a lower-triangular
            A and False for an upper-triangular one.
    """
    def __init__(self, A: Array, lower=True):
        super().__init__(A)
        self.lower = lower
class Sparse(LinearOperator):
    """ Sparse linear operator stored in CSR format.

    Args:
        data (array_like): 1-D array representing the nonzero values of the sparse matrix.
        row_indices (array_like): 1-D array representing the row indices of the nonzero values.
        col_indices (array_like): 1-D array representing the column indices of the nonzero values.
        shape (tuple): Shape of the sparse matrix.

    Raises:
        ValueError: if the same (row, col) position appears more than once.

    Example:
        >>> data = jnp.array([1, 2, 3, 4, 5, 6])
        >>> row_indices = jnp.array([0, 0, 1, 2, 2, 2])
        >>> col_indices = jnp.array([1, 3, 3, 0, 1, 2])
        >>> shape = (3, 4)
        >>> op = Sparse(data, row_indices, col_indices, shape)
    """
    def __init__(self, data, row_indices, col_indices, shape):
        super().__init__(dtype=data.dtype, shape=shape)
        self.data, self.row_indices, self.col_indices, self.A = self._build_csr(
            data, row_indices, col_indices, shape)
        # CSR of the transpose, built lazily the first time _rmatmat is used.
        self._A_T = None

    def _build_csr(self, data, row_indices, col_indices, shape):
        """Sort COO triplets into canonical CSR order and build a backend CSR matrix.

        Returns:
            (sorted data, sorted row_indices, sorted col_indices, csr matrix)
        """
        xnp = self.xnp
        rows_np, cols_np = xnp.to_np(row_indices), xnp.to_np(col_indices)
        # scipy's tocsr() sorts column indices inside each row, so sort row-major with
        # ties broken by column; otherwise `data` would not line up with `indices`.
        order = np.lexsort((cols_np, rows_np))
        # argsort of the inverse permutation (unique keys, so sort stability is
        # irrelevant) reproduces `order` in the backend's native index dtype.
        rank = xnp.array(np.argsort(order), dtype=xnp.int32, device=data.device)
        indx = xnp.argsort(rank)
        data, row_indices, col_indices = data[indx], row_indices[indx], col_indices[indx]
        csr = coo_array((xnp.to_np(data), (rows_np[order], cols_np[order])), shape=shape).tocsr()
        if csr.nnz != len(order):
            # tocsr() sums duplicates, which would leave `data` longer than the index arrays
            raise ValueError("Sparse does not support duplicate (row, col) entries")
        row_pointers = xnp.array(csr.indptr, dtype=xnp.int32, device=data.device)
        indices = xnp.array(csr.indices, dtype=xnp.int32, device=data.device)
        return data, row_indices, col_indices, xnp.sparse_csr(row_pointers, indices, data, shape)

    def _matmat(self, V):
        return self.A @ V

    def _rmatmat(self, V):
        # V @ A == (A^T @ V^T)^T. Going through self.T would presumably wrap this operator
        # in Transpose, whose _matmat calls back into _rmatmat (infinite recursion), so
        # multiply by an explicit CSR of the transpose instead.
        if self._A_T is None:
            transposed_shape = (self.shape[1], self.shape[0])
            self._A_T = self._build_csr(self.data, self.col_indices, self.row_indices,
                                        transposed_shape)[-1]
        return (self._A_T @ V.T).T
class ScalarMul(LinearOperator):
    """ Linear Operator representing scalar multiplication

    Args:
        c: the scalar; stored as a backend array in ``self.c``.
        shape (tuple): shape of the operator.
        dtype (optional): operator dtype; when falsy, ``type(c)`` is used for the
            operator while the array conversion of ``c`` receives the original value.
        device (optional): device for ``self.c``; also stored as ``self.device``.
    """
    def __init__(self, c, shape, dtype=None, device=None):
        resolved_dtype = dtype if dtype else type(c)
        super().__init__(dtype=resolved_dtype, shape=shape)
        self.c = self.xnp.array(c, dtype=dtype, device=device)
        self.device = device

    def _matmat(self, v):
        return self.c * v

    def __str__(self):
        return f"{self.c}"
class Identity(LinearOperator):
    """ Linear Operator representing the identity matrix.
    Can also be created from I_like(A)

    Args:
        shape (tuple): Shape of the identity matrix.
        dtype: Data type of the identity matrix.

    Example:
        >>> shape = (3, 3)
        >>> dtype = jnp.float64
        >>> op = Identity(shape, dtype)
    """
    def __init__(self, shape, dtype):
        super().__init__(dtype=dtype, shape=shape)

    def _matmat(self, X):
        # identity: the input is returned unchanged (same object)
        return X

    def to(self, device):
        """Record ``device`` on the operator and return the operator itself.

        Identity holds no arrays, so nothing needs to be moved.
        """
        self.device = device
        return self

    def __str__(self):
        return "I"
def I_like(A: LinearOperator) -> Identity:
    """Build an Identity operator matching ``A``'s shape, dtype and device.

    Args:
        A: operator whose ``shape``, ``dtype`` and ``device`` are copied.

    Returns:
        Identity: the new operator (``Identity.to`` returns the operator itself).
    """
    return Identity(shape=A.shape, dtype=A.dtype).to(A.device)
@parametric
class Product(LinearOperator):
    """ Matrix Multiply Product of Linear ops

    Product(M1, ..., Mk) represents M1 @ M2 @ ... @ Mk. Every operand is lazified,
    all operands must share a device, adjacent inner dimensions must agree, and the
    dtype is the promotion of all operand dtypes.

    Raises:
        ValueError: on an inner-dimension mismatch between neighbouring factors.
    """
    def __init__(self, *Ms):
        self.Ms = tuple(cola.fns.lazify(M) for M in Ms)
        devices = [M.device for M in self.Ms]
        assert all(x == devices[0] for x in devices), "There is a device mismatch in Product"
        for left, right in zip(Ms, Ms[1:]):
            if left.shape[-1] != right.shape[-2]:
                raise ValueError(f"dimension mismatch {left.shape} vs {right.shape}")
        n_rows, n_cols = Ms[0].shape[-2], Ms[-1].shape[-1]
        dtype = reduce(self.Ms[0].xnp.promote_types, (M.dtype for M in Ms))
        super().__init__(dtype, (n_rows, n_cols))
        self.device = devices[0]

    def _matmat(self, v):
        # the rightmost factor is applied first
        return reduce(lambda acc, M: M @ acc, reversed(self.Ms), v)

    def _rmatmat(self, v):
        # v @ M1 @ M2 @ ... evaluated left to right
        return reduce(lambda acc, M: acc @ M, self.Ms, v)

    def __str__(self):
        return "".join(map(str, self.Ms))
@parametric
class Sum(LinearOperator):
    """ Sum of Linear ops

    Sum(M1, ..., Mk) represents M1 + ... + Mk. Every operand is lazified, all operands
    must share a device and a shape, and the dtype is the promotion of all operand
    dtypes (matching Product, Kronecker, KronSum and BlockDiag).

    Raises:
        ValueError: if the operand shapes differ.
    """
    def __init__(self, *Ms):
        self.Ms = tuple(cola.fns.lazify(M) for M in Ms)
        devices = [M.device for M in self.Ms]
        assert all(x == devices[0] for x in devices), "There is a device mismatch in Sum"
        shape = Ms[0].shape
        for M in Ms:
            if M.shape != shape:
                raise ValueError(f"dimension mismatch {M.shape} vs {shape}")
        # Promote across every operand: e.g. a float + complex sum must be complex.
        dtype = reduce(self.Ms[0].xnp.promote_types, (M.dtype for M in Ms))
        super().__init__(dtype, shape)
        self.device = devices[0]

    def _matmat(self, v):
        return sum(M @ v for M in self.Ms)

    def _rmatmat(self, v):
        return sum(v @ M for M in self.Ms)

    def __str__(self):
        if len(self.Ms) > 5:
            return "Sum({}...)".format(", ".join(str(M) for M in self.Ms[:2]))
        return "+".join(str(M) for M in self.Ms)
def product(c):
    """Return the product of all elements of the iterable ``c``.

    An empty iterable yields 1 (the empty product) instead of raising TypeError.
    """
    return reduce(lambda a, b: a * b, c, 1)
@parametric
class Kronecker(LinearOperator):
    """ Kronecker product of linear ops Kronecker([M1,M2]):= M1⊗M2

    Args:
        *Ms (array_like): Sequence of linear operators representing the Kronecker product operands.

    Example:
        >>> M1 = jnp.array([[1, 2], [3, 4]])
        >>> M2 = jnp.array([[5, 6], [7, 8]])
        >>> op = Kronecker(M1, M2)
    """
    def __init__(self, *Ms):
        self.Ms = tuple(cola.fns.lazify(M) for M in Ms)
        shape = product([Mi.shape[-2] for Mi in Ms]), product([Mi.shape[-1] for Mi in Ms])
        dtype = reduce(self.Ms[0].xnp.promote_types, (M.dtype for M in Ms))
        super().__init__(dtype, shape)

    def _matmat(self, v):
        # View the input as a tensor with one axis per factor (+ a trailing batch axis)
        # and apply each factor along its own axis.
        ev = v.reshape(*[Mi.shape[-1] for Mi in self.Ms], -1)
        for i, M in enumerate(self.Ms):
            ev_front = self.xnp.moveaxis(ev, i, 0)
            shape = M.shape[0], *ev_front.shape[1:]
            Mev_front = (M @ ev_front.reshape(M.shape[-1], -1)).reshape(shape)
            ev = self.xnp.moveaxis(Mev_front, 0, i)
        return ev.reshape(self.shape[-2], ev.shape[-1])

    def to_dense(self):
        Ms = [M.to_dense() if isinstance(M, LinearOperator) else M for M in self.Ms]
        return reduce(self.xnp.kron, Ms)

    def __str__(self):
        return "⊗".join(str(M) for M in self.Ms)
def kronsum(A, B):
    """Dense Kronecker sum ``A ⊗ I + I ⊗ B`` of two matrices.

    Identity sizes are taken from ``shape[-2]``; presumably A and B are square.
    Both identities are allocated on A's device.
    """
    xnp = get_library_fns(A.dtype)
    device = xnp.get_device(A)
    n_a, n_b = A.shape[-2], B.shape[-2]
    eye_a = xnp.eye(n_a, n_a, dtype=A.dtype, device=device)
    eye_b = xnp.eye(n_b, n_b, dtype=B.dtype, device=device)
    return xnp.kron(A, eye_b) + xnp.kron(eye_a, B)
@parametric
class KronSum(LinearOperator):
    """ Kronecker Sum Linear Operator, KronSum(A,B):= A ⊕ B = A ⊗ I + I ⊗ B

    Args:
        *Ms (array_like): Sequence of matrices representing the Kronecker sum operands.

    Example:
        >>> M1 = jnp.array([[1, 2], [3, 4]])
        >>> M2 = jnp.array([[5, 6], [7, 8]])
        >>> op = KronSum(M1, M2)
    """
    def __init__(self, *Ms):
        self.Ms = tuple(cola.fns.lazify(M) for M in Ms)
        shape = product([Mi.shape[-2] for Mi in Ms]), product([Mi.shape[-1] for Mi in Ms])
        dtype = reduce(self.Ms[0].xnp.promote_types, (M.dtype for M in Ms))
        super().__init__(dtype, shape)

    def _matmat(self, v):
        xnp = self.xnp
        # View the input as a tensor with one axis per factor (+ a trailing batch axis).
        ev = v.reshape(*[Mi.shape[-1] for Mi in self.Ms], -1)

        def apply_factor(i, M):
            # Apply M along axis i of ev, leaving every other axis untouched.
            ev_front = xnp.moveaxis(ev, i, 0)
            Mev_front = (M @ ev_front.reshape(M.shape[-1], -1)).reshape(M.shape[0], *ev_front.shape[1:])
            return xnp.moveaxis(Mev_front, 0, i)

        # Accumulate out of place: an in-place `+=` into a buffer of v's dtype cannot
        # upcast (e.g. real v with a complex factor), and a `0 * ev` buffer turns inf into nan.
        out = sum(apply_factor(i, M) for i, M in enumerate(self.Ms))
        return out.reshape(self.shape[-2], ev.shape[-1])

    def to_dense(self):
        Ms = [M.to_dense() if isinstance(M, LinearOperator) else M for M in self.Ms]
        return reduce(kronsum, Ms)

    def __str__(self):
        return "⊕ₖ".join(str(M) for M in self.Ms)
@parametric
class BlockDiag(LinearOperator):
    """ Block Diagonal Linear Operator. BlockDiag([A,B]):= [A 0; 0 B]

    Args:
        *Ms (array_like): Sequence of matrices representing the blocks.
        multiplicities (list, optional): List of integers representing the multiplicities
            of the corresponding blocks in *Ms. Default is None, which assigns a
            multiplicity of 1 to each block.

    Raises:
        ValueError: if ``multiplicities`` does not have one entry per block.

    Example:
        >>> M1 = jnp.array([[1, 2], [3, 4]])
        >>> M2 = jnp.array([[5, 6], [7, 8]])
        >>> op = BlockDiag(M1, M2, multiplicities=[2, 3])
    """
    def __init__(self, *Ms, multiplicities=None):
        self.Ms = tuple(cola.fns.lazify(M) for M in Ms)
        self.multiplicities = [1 for _ in Ms] if multiplicities is None else multiplicities
        # zip() below would silently drop unmatched blocks or multiplicities
        if len(self.multiplicities) != len(self.Ms):
            raise ValueError(f"Got {len(self.multiplicities)} multiplicities for {len(self.Ms)} blocks")
        shape = (sum(Mi.shape[-2] * c for Mi, c in zip(Ms, self.multiplicities)),
                 sum(Mi.shape[-1] * c for Mi, c in zip(Ms, self.multiplicities)))
        dtype = reduce(self.Ms[0].xnp.promote_types, (M.dtype for M in Ms))
        super().__init__(dtype, shape)

    def _matmat(self, v):
        # v is (n, k); a 1-D v is handled with k = 1
        k = v.shape[1] if len(v.shape) > 1 else 1
        i = 0
        y = []
        for M, multiplicity in zip(self.Ms, self.multiplicities):
            i_end = i + multiplicity * M.shape[-1]
            # gather all repeated copies of this block's input slice as columns, apply M once
            elems = M @ v[i:i_end].T.reshape(k * multiplicity, M.shape[-1]).T
            y.append(elems.T.reshape(k, multiplicity * M.shape[0]).T)
            i = i_end
        return self.xnp.concat(y, axis=0)  # concatenate over rep axis

    def to_dense(self):
        Ms_all = [M for M, c in zip(self.Ms, self.multiplicities) for _ in range(c)]
        Ms_all = [Mi.to_dense() if isinstance(Mi, LinearOperator) else Mi for Mi in Ms_all]
        return self.xnp.block_diag(*Ms_all)

    def __str__(self):
        if len(self.Ms) > 5:
            return "BlockDiag({}...)".format(", ".join(str(M) for M in self.Ms[:2]))
        return "⊕".join(str(M) for M in self.Ms)
class Diagonal(LinearOperator):
    """ Diagonal LinearOperator. O(n) time and space matmuls.

    Args:
        diag (array_like): 1-D array representing the diagonal elements of the matrix.

    Example:
        >>> d = jnp.array([1, 2, 3])
        >>> op = Diagonal(d)
    """
    def __init__(self, diag):
        assert len(diag.shape) == 1, f"diagonal is not a vector, it is of shape {diag.shape=}"
        self.diag = diag
        size = len(diag)
        super().__init__(dtype=diag.dtype, shape=(size, size))

    def _matmat(self, X: Array) -> Array:
        # scale row i of X by diag[i]
        as_column = self.diag[:, None]
        return as_column * X

    def _rmatmat(self, X: Array) -> Array:
        # scale column j of X by diag[j]
        as_row = self.diag[None, :]
        return as_row * X

    def to_dense(self):
        return self.xnp.diag(self.diag)

    def __str__(self):
        return f"diag({self.diag})"
class Tridiagonal(LinearOperator):
    """ Tridiagonal linear operator. O(n) time and space matmuls.

    Each band is reshaped to a column (via ensure_vec_is_matrix) so that it broadcasts
    against the rows of X.

    Args:
        alpha (array_like): 1-D array representing lower band of the operator.
        beta (array_like): 1-D array representing diagonal of the operator.
        gamma (array_like): 1-D array representing upper band of the operator.
    """
    def __init__(self, alpha: Array, beta: Array, gamma: Array):
        self.alpha, self.beta, self.gamma = (ensure_vec_is_matrix(band) for band in (alpha, beta, gamma))
        n = self.beta.shape[0]
        super().__init__(dtype=self.beta.dtype, shape=(n, n))

    def _matmat(self, X: Array) -> Array:
        xnp = self.xnp
        # one zero row to pad the shifted band products back to n rows
        pad = xnp.zeros(shape=(1, X.shape[-1]), dtype=X.dtype, device=xnp.get_device(X))
        # row i receives gamma[i] * X[i + 1] (upper band) ...
        upper = xnp.concat([self.gamma * X[1:], pad], axis=0)
        # ... and alpha[i - 1] * X[i - 1] (lower band)
        lower = xnp.concat([pad, self.alpha * X[:-1]], axis=0)
        return self.beta * X + lower + upper
def ensure_vec_is_matrix(vec):
    """Reshape a 1-D array into an (n, 1) column; return any other array unchanged."""
    if len(vec.shape) != 1:
        return vec
    return vec.reshape(-1, 1)
@parametric
class Transpose(LinearOperator):
    """ Transpose of a Linear Operator

    Each product is forwarded to the wrapped operator's opposite-side matmul:
    A^T X = (X^T A)^T and X A^T = (A X^T)^T.
    """
    def __init__(self, A):
        self.A = A
        super().__init__(dtype=A.dtype, shape=(A.shape[1], A.shape[0]))
        self.device = A.device

    def _matmat(self, x):
        return self.A._rmatmat(x.T).T

    def _rmatmat(self, x):
        return self.A._matmat(x.T).T

    def __str__(self):
        return f"{str(self.A)}ᵀ"
@parametric
class Adjoint(LinearOperator):
    """ Complex conjugate transpose of a Linear Operator (aka adjoint)

    Uses A^H X = conj(conj(X)^T A)^T and X A^H = conj(A conj(X)^T)^T, so each product
    is forwarded to the wrapped operator's opposite-side matmul.
    """
    def __init__(self, A):
        self.A = A
        n_rows, n_cols = A.shape[1], A.shape[0]
        super().__init__(dtype=A.dtype, shape=(n_rows, n_cols))
        self.device = A.device

    def _matmat(self, x):
        conj = self.xnp.conj
        return conj(self.A._rmatmat(conj(x).T)).T

    def _rmatmat(self, x):
        conj = self.xnp.conj
        return conj(self.A._matmat(conj(x).T)).T

    def __str__(self):
        return f"{str(self.A)}*"
@parametric
class Sliced(LinearOperator):
    """ Slicing of another linear operator A.
    Equivalent to A[slices[0], :][:, slices[1]]

    Args:
        A (LinearOperator): operator being sliced.
        slices (tuple): (row_index, col_index); each entry may be a ``slice`` or an index array.
    """
    def __init__(self, A, slices):
        self.A = A
        self.slices = slices
        # The shape is computed with numpy, so index arrays that carry a device are
        # moved with .cpu() first.
        slices = tuple([sl.cpu() if hasattr(sl, "device") else sl for sl in slices])
        new_shape = np.arange(A.shape[0])[slices[0]].shape + np.arange(A.shape[1])[slices[1]].shape
        super().__init__(dtype=A.dtype, shape=new_shape)
        self.device = A.device

    def _matmat(self, X: Array) -> Array:
        # Scatter X into the selected columns of a zero matrix, apply A, keep the selected rows.
        xnp = self.xnp
        start_slices, end_slices = self.slices
        device = xnp.get_device(X)
        Y = xnp.zeros(shape=(self.A.shape[-1], X.shape[-1]), dtype=self.dtype, device=device)
        Y = xnp.update_array(Y, X, end_slices)
        output = self.A @ Y
        return output[start_slices]

    def _rmatmat(self, X: Array) -> Array:
        # Scatter X into the selected rows' positions, apply A from the right, keep the selected columns.
        xnp = self.xnp
        start_slices, end_slices = self.slices
        device = xnp.get_device(X)
        Y = xnp.zeros(shape=(X.shape[0], self.A.shape[0]), dtype=self.dtype, device=device)
        Y = xnp.update_array(Y, X, ..., start_slices)
        output = Y @ self.A
        return output[..., end_slices]

    def __str__(self):
        # Abbreviate long index arrays. Each index is checked on its own: a plain
        # `slice` has no len(), e.g. the second index of A[[0, 1, 2], :].
        if any(hasattr(sl, '__len__') and len(sl) > 5 for sl in self.slices):
            return f"{str(self.A)}[slc1, slc2]"
        return f"{str(self.A)}[{self.slices[0]},{self.slices[1]}]"
class Jacobian(LinearOperator):
    """ Jacobian (linearization) of a function f: R^n -> R^m at point x.
    Matrix has shape (m, n)

    Args:
        f (callable): Function representing the mapping from R^n to R^m.
        x (array_like): 1-D array representing the point at which to compute the Jacobian.

    Example:
        >>> def f(x):
        ...     return jnp.array([x[0]**2, x[1]**3, jnp.sin(x[2])])
        >>> x = jnp.array([1., 2., 3.])
        >>> op = Jacobian(f, x)
    """
    def __init__(self, f, x):
        self.f = f
        self.x = x
        # could perhaps relax this with automatic reshaping of x and y
        # (f(x) is evaluated once here just to learn the output size)
        y_shape = f(x).shape
        super().__init__(dtype=x.dtype, shape=(y_shape[0], x.shape[0]))

    def _match_dtype(self, out):
        # PyTorch silently converts to double inside jvp/vjp; cast back to the operator
        # dtype. Match on the module-name suffix so both 'cola.torch_fns' and a
        # package-qualified name such as 'cola.backends.torch_fns' are recognised.
        if self.xnp.__name__.endswith('torch_fns'):
            out = out.to(dtype=self.dtype)
        return out

    def _matmat(self, X):
        # one forward-mode JVP per column of X
        fn = self.xnp.vmap(partial(self.xnp.jvp_derivs, self.f, (self.x, )))
        return self._match_dtype(fn((X.T, )).T)

    def _rmatmat(self, X):
        # one reverse-mode VJP per row of X
        def vjp(v):
            return self.xnp.vjp_derivs(self.f, (self.x, ), v)

        out = self.xnp.vmap(vjp)(X)[0]
        return self._match_dtype(out)

    def __str__(self):
        return "J"
class Hessian(LinearOperator):
    """ Hessian of a scalar function f: R^n -> R at point x.
    Matrix has shape (n, n)

    Matmuls are Hessian-vector products: a JVP of grad(f) at x, vmapped over the
    columns of X.

    Args:
        f (callable): Function representing the mapping from R^n to R.
        x (array_like): 1-D array representing the point at which to compute the Hessian.

    Example:
        >>> def f(x):
        ...     return x[1]**3 + jnp.sin(x[2])
        >>> x = jnp.array([1., 2., 3.])
        >>> op = Hessian(f, x)
    """
    def __init__(self, f, x):
        self.f, self.x = f, x
        assert len(x.shape) == 1, "x must be a vector"
        n = x.shape[0]
        super().__init__(dtype=x.dtype, shape=(n, n))

    def _matmat(self, X):
        xnp = self.xnp
        grad_f = xnp.grad(self.f)
        hvp = partial(xnp.jvp_derivs, grad_f, (self.x, ), create_graph=False)
        return xnp.vmap(hvp)((X.T, )).T

    def __str__(self):
        return "H"
class Permutation(LinearOperator):
    """ Permutation matrix.

    P @ v returns the rows of v reordered as v[perm].

    Args:
        perm (array_like): 1-D array representing the permutation.
        dtype (optional): specify the dtype to operate on (not int); defaults to the
            float32 of the backend that ``perm`` belongs to.

    Example:
        >>> P = Permutation(np.array([1, 0, 3, 2]))
    """
    def __init__(self, perm, dtype=None):
        self.perm = perm
        backend = get_library_fns(self.perm.dtype)
        # perm itself is integer-valued, so fall back to a float dtype for the operator
        if dtype is None:
            dtype = backend.float32
        n = len(perm)
        super().__init__(dtype=dtype, shape=(n, n))

    def _matmat(self, v):
        return v[self.perm]
@parametric
class Concatenated(LinearOperator):
    """ Produces a linear operator equivalent to concatenating a collection of matrices Ms along specified axis

    Args:
        *Ms (array_like): Sequence of matrices representing the blocks.
        axis (int, optional): specify which axis to concatenate on (0 or 1; -2 and -1
            are accepted as aliases).

    Raises:
        ValueError: if ``axis`` is not one of -2, -1, 0, 1.

    Example:
        >>> M1 = jnp.array([[1, 2], [3, 4]])
        >>> M2 = jnp.array([[5, 6], [7, 8]])
        >>> A = Concatenated(M1, M2, axis=1)
        >>> A.shape
        (2, 4)
    """
    def __init__(self, *Ms, axis=0):
        if axis not in (-2, -1, 0, 1):
            raise ValueError(f"axis must be 0 or 1, got {axis}")
        axis %= 2
        self.Ms = Ms
        # Blocks must agree along the axis that is NOT being concatenated.
        other = 1 - axis
        assert all(M.shape[other] == Ms[0].shape[other] for M in Ms), \
            f"Trying to concatenate matrices of different sizes {[M.shape for M in Ms]}"
        concat_size = sum(M.shape[axis] for M in Ms)
        shape = (Ms[0].shape[0], concat_size) if axis == 1 else (concat_size, Ms[0].shape[1])
        self.axis = axis
        super().__init__(Ms[0].dtype, shape)

    def _matmat(self, V):
        if self.axis == 0:
            # [M1; M2] @ V stacks the individual products
            return self.xnp.concat([M @ V for M in self.Ms], axis=0)
        # [M1 M2] @ V = M1 @ V[:n1] + M2 @ V[n1:]
        partial_products, start = [], 0
        for M in self.Ms:
            stop = start + M.shape[1]
            partial_products.append(M @ V[start:stop])
            start = stop
        return sum(partial_products)
class ConvolveND(LinearOperator):
    """ n-Dimensional convolution Linear operator (only works in jax right now.)

    Columns of the input are reshaped to ``array_shape``, convolved with ``filter``
    and flattened back.

    Args:
        filter (array_like): convolution filter (float32 or float64).
        array_shape (tuple): shape of a single (unflattened) input array.
        mode (str): only 'same' is supported, so the operator is square.
    """
    def __init__(self, filter, array_shape, mode='same'):
        assert filter.dtype in [np.float32, np.float64], "Only supporting jax right now"
        self.filter = filter
        self.array_shape = array_shape
        assert mode == 'same'
        # Both dimensions must be plain Python ints ('same' mode preserves the size).
        size = int(np.prod(array_shape))
        super().__init__(dtype=filter.dtype, shape=(size, size))
        self.conv = self.xnp.vmap(partial(self.xnp.convolve, in2=filter, mode=mode))

    def _matmat(self, X):
        Z = X.T.reshape(X.shape[-1], *self.array_shape)
        return self.conv(Z).reshape(X.shape[-1], -1).T
class Householder(LinearOperator):
    """ Householder rotation matrix.

    Applies H X = X - beta * vec (vec^H X). The size n is taken from
    ``vec.shape[-2]`` (presumably vec is an (n, 1) column — confirm at call sites).

    Args:
        vec (array_like): Householder vector.
        beta (float): scaling of the rank-one update (default 2.).
    """
    def __init__(self, vec, beta=2.):
        n = vec.shape[-2]
        super().__init__(shape=(n, n), dtype=vec.dtype)
        self.vec = vec
        self.beta = self.xnp.array(beta, dtype=vec.dtype, device=self.device)

    def _matmat(self, X: Array) -> Array:
        xnp = self.xnp
        # inner products vec^H X, one per column of X
        projection = xnp.sum(X * xnp.conj(self.vec), axis=-2, keepdims=True)
        return X - self.beta * projection * self.vec
class Kernel(LinearOperator):
    """ Kernel operator based on a given function f where the matvec is evaluated on the fly.
    That is, [Kv]_i = \\sum_{j} f(x1_i, x2_j) v_j.
    The variables block_size1 and block_size2 determine the memory usage of the matvec
    and matmat operations.

    Args:
        x1 (array): N-D array
        x2 (array): N-D array
        fn (callable): function that defines the kernel
        block_size1 (int): block size for x1
        block_size2 (int): block size for x2
    """
    def __init__(self, x1, x2, fn, block_size1, block_size2):
        self.x1 = x1
        self.x2 = x2
        self.fn = fn
        self.block_size1 = block_size1
        self.block_size2 = block_size2
        super().__init__(dtype=x1.dtype, shape=(x1.shape[0], x2.shape[0]))
        # Ceil-divide so a dimension smaller than its block size still gets one block
        # (floor division gave zero blocks and the matvec returned all zeros); the last
        # block then holds the remainder, so no block exceeds the requested size.
        self.iters1 = -(-self.shape[0] // block_size1)
        self.iters2 = -(-self.shape[1] // block_size2)

    def _matmat(self, V):
        xnp = self.xnp
        device = xnp.get_device(V)
        # One output row per x1 point (V has one row per x2 point).
        out = xnp.zeros(shape=(self.shape[0], V.shape[1]), dtype=V.dtype, device=device)
        for idx in range(self.iters1):
            stop1 = None if idx + 1 == self.iters1 else (idx + 1) * self.block_size1
            loc1 = slice(idx * self.block_size1, stop1)
            update = xnp.zeros(shape=(self.x1[loc1].shape[0], V.shape[1]), dtype=V.dtype, device=device)
            for jdx in range(self.iters2):
                stop2 = None if jdx + 1 == self.iters2 else (jdx + 1) * self.block_size2
                loc2 = slice(jdx * self.block_size2, stop2)
                # only a (block1 x block2) tile of the kernel matrix is materialized at a time
                update = update + self.fn(self.x1[loc1], self.x2[loc2]) @ V[loc2]
            out = xnp.update_array(out, update, loc1)
        return out

    def __str__(self):
        return "Ker(x1, x2, fn)"
class FFT(LinearOperator):
    """ FFT matrix. Uses convention so matrix is unitary.

    Products use the backend's fft/ifft with ``norm='ortho'``.

    Args:
        n (int): transform length; the operator is (n, n).
        dtype (optional): operator dtype.
    """
    def __init__(self, n, dtype=None):
        super().__init__(shape=(n, n), dtype=dtype, annotations={cola.Unitary})

    def _matmat(self, X):
        # transform every column of X
        return self.xnp.fft(X, axis=0, norm='ortho')

    def _rmatmat(self, X):
        # X @ F, computed as conj(ifft(conj(X))) along the rows
        conj_X = X.conj()
        return self.xnp.ifft(conj_X, axis=1, norm='ortho').conj()
def FIM(logits_fn, theta): """ Fisher information matrix for a probability model log p(y|theta) where p is a classifier probability distribution. Averages over batch dimensions. Args: logit_fn function that maps parameters to logits of shape (*, n_classes) theta (array_like): parameter vector to eval Fisher at Returns: Hessian(KL(p(y|theta')||p(y|theta))) (w.r.t. theta) """ xnp = get_library_fns(theta.dtype) probs = xnp.softmax(logits_fn(theta), axis=-1) def entropy(theta): log_probs = xnp.log_softmax(logits_fn(theta), axis=-1) return -xnp.sum(probs * log_probs, axis=-1).mean()