
Sharp Bits

As you have seen in the previous sections, CoLA has many useful features for large-scale linear algebra, such as compatibility with both PyTorch and JAX and the automatic selection of an efficient algorithm for a problem through dispatch. However, these features also come with some unintuitive consequences, which we now explain.

Balancing JAX and PyTorch differences

Dynamic slicing

In CoLA we write code that runs in both JAX and PyTorch, so we must adhere to the particular rules of each framework. For example, when using jit or vmap in JAX we cannot return arrays whose size depends on runtime values. As a consequence, iterative algorithms like Lanczos or Arnoldi have memory requirements proportional to max_iters. Even when the algorithm terminates much earlier than max_iters, we cannot return array[:idx], where idx is the iteration at which it stopped.
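To make the memory cost concrete, here is a minimal NumPy sketch (not CoLA's implementation) of symmetric Lanczos written in the fixed-shape style. All buffers are sized by max_iters. The function returns them whole, together with the number k of iterations actually performed. NumPy itself does not impose this restriction; the sketch only mimics what jit/vmap require. In JAX the Python loop would typically become a jax.lax.while_loop. The name lanczos_fixed_shape is hypothetical.

```python
import numpy as np

def lanczos_fixed_shape(A, v0, max_iters, tol=1e-10):
    """Symmetric Lanczos written in the shape-static style that jit/vmap require.

    All buffers are allocated with max_iters entries up front. Instead of
    returning alphas[:k], the function returns the full buffers plus the
    number k of valid iterations; slicing with k happens afterwards,
    outside any jit/vmap region.
    """
    n = A.shape[0]
    Q = np.zeros((n, max_iters))   # n * max_iters memory, even if we stop early
    alphas = np.zeros(max_iters)   # diagonal of the tridiagonal matrix T
    betas = np.zeros(max_iters)    # off-diagonal of T
    q = v0 / np.linalg.norm(v0)
    q_prev = np.zeros(n)
    beta = 0.0
    k = 0
    for j in range(max_iters):
        Q[:, j] = q
        w = A @ q - beta * q_prev
        alphas[j] = q @ w
        w = w - alphas[j] * q
        beta = np.linalg.norm(w)
        betas[j] = beta
        k = j + 1
        if beta < tol:  # Krylov subspace exhausted: stop early
            break
        q_prev, q = q, w / beta
    return Q, alphas, betas, k
```

The caller can take alphas[:k] and betas[:k - 1] once it is back outside the transformed function. The allocated buffers, however, always have max_iters entries.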

GPU allocation

Additionally, PyTorch and JAX take completely different strategies for placing arrays on a GPU. PyTorch expects users to move an array explicitly via array.to(device) or array.cuda(); by default, arrays are created on the CPU. In contrast, JAX implicitly places arrays on a GPU whenever one is available. In principle this is not a big problem for CoLA, since a LinearOperator reports the device of the parameters / arrays it was built from.

However, the situation becomes more involved when we combine LinearOperators that are not on the same device. For example, suppose you have D = cola.ops.Diagonal(torch.ones(N).cuda()) and L = cola.ops.Tridiagonal(torch.ones(N-1), torch.ones(N), torch.ones(N-1)), and you combine them as A = 3 * D + L. This would not throw any error, but as soon as you compute A @ v you would get a device mismatch error. Fortunately, in CoLA this is easy to circumvent with A.to(device), which moves each of the parameters of the LinearOperator (namely the 3 and the four arrays of ones) to the device of your choice.
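The mechanism behind A.to(device) can be illustrated with a toy model that assumes nothing about CoLA's actual classes. Every operator is a tree of parameters and sub-operators. A matrix-vector product checks that all parameters share one device, and to() walks the tree moving each parameter. The classes Param, ToyOp, Diag, Tridiag, Scaled and Sum are hypothetical. "Devices" are just string tags, so no GPU is needed.

```python
import numpy as np

class Param:
    """Stand-in for a tensor: an array plus a device tag (no real GPU involved)."""
    def __init__(self, value, device="cpu"):
        self.value = np.asarray(value, dtype=float)
        self.device = device

class ToyOp:
    """Hypothetical operator base class: a tree of Params and sub-operators."""
    def children(self):
        raise NotImplementedError

    def devices(self):
        # Collect the device tags of every parameter in the tree.
        devs = set()
        for c in self.children():
            devs |= ({c.device} if isinstance(c, Param) else c.devices())
        return devs

    def to(self, device):
        # Recursively move every parameter, mirroring what A.to(device) does.
        for c in self.children():
            if isinstance(c, Param):
                c.device = device
            else:
                c.to(device)
        return self

    def __matmul__(self, v):
        devs = self.devices() | {v.device}
        if len(devs) > 1:
            raise RuntimeError(f"device mismatch: {sorted(devs)}")
        return Param(self._mv(v.value), v.device)

    def __rmul__(self, c):
        return Scaled(c, self)

    def __add__(self, other):
        return Sum(self, other)

class Diag(ToyOp):
    def __init__(self, d):
        self.d = d
    def children(self):
        return [self.d]
    def _mv(self, x):
        return self.d.value * x

class Tridiag(ToyOp):
    def __init__(self, lower, diag, upper):
        self.lower, self.diag, self.upper = lower, diag, upper
    def children(self):
        return [self.lower, self.diag, self.upper]
    def _mv(self, x):
        y = self.diag.value * x
        y[:-1] += self.upper.value * x[1:]  # superdiagonal
        y[1:] += self.lower.value * x[:-1]  # subdiagonal
        return y

class Scaled(ToyOp):
    def __init__(self, c, op):
        self.c, self.op = Param(c), op  # the scalar is a parameter too
    def children(self):
        return [self.c, self.op]
    def _mv(self, x):
        return self.c.value * self.op._mv(x)

class Sum(ToyOp):
    def __init__(self, a, b):
        self.a, self.b = a, b
    def children(self):
        return [self.a, self.b]
    def _mv(self, x):
        return self.a._mv(x) + self.b._mv(x)
```

With D built from a "cuda"-tagged array and L from "cpu"-tagged arrays, A = 3 * D + L constructs without complaint, and A @ v raises. After A.to("cuda"), the product succeeds because every leaf parameter, including the scalar 3, has been moved.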

Iterative algorithms

Algorithm parameters and defaults

CoLA has several iterative algorithms, such as CG, GMRES, Lanczos, Arnoldi, LOBPCG and SLQ. Each algorithm has its own set of parameters, and a parameter with the same name need not have the same interpretation across algorithms. For example, in CG the tolerance parameter tol bounds the size of the residual of the linear system, \(||A x_t - b||\). In GMRES, however, tol is the stopping criterion of the inner call to Arnoldi. In that context, tol indicates how good the Arnoldi decomposition of the operator is. It is therefore important to read each algorithm's API documentation to understand the role of each parameter.
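To see what a tol on Arnoldi measures, here is a minimal NumPy sketch of the Arnoldi iteration (one common formulation, not necessarily CoLA's). It stops when the subdiagonal entry H[j + 1, j] falls below tol. That entry is exactly the norm of the error in the square approximation \(A Q_k \approx Q_k H_k\). It therefore controls the quality of the decomposition, not the residual \(||A x_t - b||\) of any linear system. The function name arnoldi and its signature are assumptions for illustration.

```python
import numpy as np

def arnoldi(A, v0, max_iters, tol=1e-8):
    """Arnoldi iteration with modified Gram-Schmidt.

    Builds an orthonormal Krylov basis Q and an upper Hessenberg H with
    A @ Q[:, :k] == Q @ H. Here tol is compared with H[j + 1, j], the norm of
    the component of A @ Q[:, j] outside the current basis, which is also the
    size of the error in the square approximation A Q_k ~ Q_k H_k.
    """
    n = A.shape[0]
    Q = np.zeros((n, max_iters + 1))
    H = np.zeros((max_iters + 1, max_iters))
    Q[:, 0] = v0 / np.linalg.norm(v0)
    k = max_iters
    for j in range(max_iters):
        w = A @ Q[:, j]
        for i in range(j + 1):  # orthogonalize against the existing basis
            H[i, j] = Q[:, i] @ w
            w = w - H[i, j] * Q[:, i]
        H[j + 1, j] = np.linalg.norm(w)
        if H[j + 1, j] < tol:  # decomposition is accurate to tol: stop
            k = j + 1
            break
        Q[:, j + 1] = w / H[j + 1, j]
    return Q[:, :k + 1], H[:k + 1, :k], k
```

A GMRES built on top of this would then solve a small least-squares problem with H. How close that solution is to satisfying \(A x = b\) is a separate question from whether H[k, k - 1] < tol.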

Moreover, we have selected defaults for each algorithm that have shown good performance across the applications in which we have used CoLA. Yet a good default in one application might not be a good default in another. For example, when training Neural PDEs it is common to set the CG tolerance to tol=1e-7. In contrast, when using CG to train Gaussian processes, tolerances on the order of tol=1 are used, and asking for a lower tolerance brings no clear performance benefit.

Relative tolerances

Currently CoLA uses absolute tolerances as stopping criteria. This is not ideal. The norm of the residual scales with the norm of the right-hand side, which typically grows with the size of the problem. A small absolute tolerance can therefore be an unreasonable criterion to meet in a large problem, in which case the algorithm effectively always runs for max_iters. We are in the process of switching to relative stopping criteria.
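The difference between the two criteria can be seen with a plain conjugate gradient sketch in NumPy. This is not CoLA's CG; the function cg and its relative flag are hypothetical. Scaling b by a power of two scales every residual exactly in floating point. A relative test \(||r|| \le \mathrm{tol} \cdot ||b||\) therefore stops after the same number of iterations, whereas an absolute test \(||r|| \le \mathrm{tol}\) needs at least as many iterations and possibly runs into max_iters.

```python
import numpy as np

def cg(A, b, tol, max_iters, relative=False):
    """Plain conjugate gradients from x0 = 0; returns (x, iterations used).

    relative=False stops when ||r|| <= tol            (absolute criterion)
    relative=True  stops when ||r|| <= tol * ||b||    (scale invariant)
    where r is the recursively updated residual b - A x.
    """
    threshold = tol * np.linalg.norm(b) if relative else tol
    x = np.zeros_like(b, dtype=float)
    r = b.astype(float)  # residual b - A x for x = 0
    p = r.copy()
    rs = r @ r
    for k in range(max_iters):
        if np.sqrt(rs) <= threshold:
            return x, k
        Ap = A @ p
        alpha = rs / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x, max_iters
```

Running it on b and on 2**20 * b with the same tol shows that the relative criterion is unaffected by the rescaling, while the absolute criterion asks for a residual about a million times smaller relative to ||b||.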

Conditioning of linear systems

The number of iterations an iterative solver like CG or GMRES needs to converge usually grows with the condition number of the underlying matrix. Consider two matrices of the same size, one ill-conditioned and one well-conditioned: the iterative solver may need many iterations to converge on the former and only a few on the latter.
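For CG on a symmetric positive definite matrix with condition number \(\kappa\), a standard way to quantify this is the classical bound \(||x - x_k||_A \le 2\left(\frac{\sqrt{\kappa}-1}{\sqrt{\kappa}+1}\right)^k ||x - x_0||_A\). The small helper below turns that bound into an iteration count; its name, cg_iteration_bound, is hypothetical. It is a worst-case estimate, and actual runs can be much faster, for instance when the eigenvalues are clustered.

```python
import math

def cg_iteration_bound(kappa, eps):
    """Iterations after which the classical CG bound guarantees that the
    A-norm error has shrunk by a factor eps, for an SPD matrix with
    condition number kappa:

        ||x - x_k||_A <= 2 * ((sqrt(kappa) - 1) / (sqrt(kappa) + 1))**k * ||x - x_0||_A
    """
    if kappa < 1.0:
        raise ValueError("condition number must be >= 1")
    if kappa == 1.0:
        return 1  # A is a multiple of the identity: CG converges in one step
    rate = (math.sqrt(kappa) - 1.0) / (math.sqrt(kappa) + 1.0)
    return max(1, math.ceil(math.log(eps / 2.0) / math.log(rate)))
```

For a target reduction of 1e-6, the bound gives 9 iterations for \(\kappa = 2\) but 726 for \(\kappa = 10^4\). The bound itself does not depend on the size of the matrix.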

Tracking dispatch

Right now it is not easy to know which dispatch rules or algorithms are called on a complex operator that combines different structures at different sizes. For example, when decomposing the operator, one part might call a dense solver while the remaining parts call an iterative solver. We already have a preliminary tracker that logs which dispatch rules are called, and we are improving this tool.
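A tracker of this kind can be sketched with Python's functools.singledispatch. Each dispatch rule is wrapped so that it appends its name and the operator's shape to a log before running. Everything here (ToyDense, ToyDiagonal, ToyBlockDiag, solve, tracked, DISPATCH_LOG) is hypothetical. None of it is CoLA's dispatch machinery or its tracker API. The point is only to show how a nested operator can route different blocks to different algorithms, and how a log exposes that.

```python
import functools
import numpy as np

DISPATCH_LOG = []  # hypothetical tracker: records (rule name, operator shape)

def tracked(rule):
    """Wrap a dispatch rule so that every call is appended to DISPATCH_LOG."""
    @functools.wraps(rule)
    def wrapper(op, b):
        DISPATCH_LOG.append((rule.__name__, op.shape))
        return rule(op, b)
    return wrapper

class ToyDense:
    def __init__(self, M):
        self.M = np.asarray(M, dtype=float)
        self.shape = self.M.shape

class ToyDiagonal:
    def __init__(self, d):
        self.d = np.asarray(d, dtype=float)
        self.shape = (len(self.d), len(self.d))

class ToyBlockDiag:
    def __init__(self, *blocks):
        self.blocks = blocks
        n = sum(blk.shape[0] for blk in blocks)
        self.shape = (n, n)

@functools.singledispatch
def solve(op, b):
    raise NotImplementedError(f"no rule for {type(op).__name__}")

@solve.register(ToyDense)
@tracked
def dense_solve(op, b):
    return np.linalg.solve(op.M, b)  # direct dense solver

@solve.register(ToyDiagonal)
@tracked
def diagonal_solve(op, b):
    return b / op.d  # elementwise solve

@solve.register(ToyBlockDiag)
@tracked
def block_diag_solve(op, b):
    out, start = [], 0
    for blk in op.blocks:  # each block is dispatched on its own type
        n = blk.shape[0]
        out.append(solve(blk, b[start:start + n]))
        start += n
    return np.concatenate(out)
```

Solving with ToyBlockDiag(ToyDense(...), ToyDiagonal(...)) leaves three entries in DISPATCH_LOG: the block-diagonal rule, then the dense rule for the first block, then the diagonal rule for the second.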

Coming Soon

Below is a list of features that we are currently working on incorporating.

Dtype allocation

Just as A.to(device) places a LinearOperator on the device of our choice, we are working on having the to() method also convert the LinearOperator to the dtype of our choice, matching the behaviour of PyTorch.

Changing algorithm defaults

Given that a good default for an algorithm in one application might not be a good default in another, we are planning to add a feature that lets users change the defaults of certain algorithms. We are exploring something similar to matplotlib.pyplot.rc.
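An rc-style mechanism could look roughly like the following sketch. The registry, the names rc, rc_context and resolve, and the default values are all hypothetical placeholders, not CoLA's API or its actual defaults. The design mirrors matplotlib, where pyplot.rc changes settings globally and matplotlib.rc_context changes them only inside a with block.

```python
import contextlib
import copy

# Hypothetical per-algorithm defaults; the values are placeholders.
_DEFAULTS = {
    "cg": {"tol": 1e-6, "max_iters": 1000},
    "lanczos": {"tol": 1e-6, "max_iters": 100},
}

def rc(algorithm, **overrides):
    """Permanently override defaults, in the spirit of matplotlib.pyplot.rc."""
    unknown = set(overrides) - set(_DEFAULTS[algorithm])
    if unknown:
        raise KeyError(f"unknown parameters for {algorithm}: {sorted(unknown)}")
    _DEFAULTS[algorithm].update(overrides)

@contextlib.contextmanager
def rc_context(algorithm, **overrides):
    """Temporarily override defaults, restoring the previous values on exit."""
    saved = copy.deepcopy(_DEFAULTS[algorithm])
    try:
        rc(algorithm, **overrides)
        yield
    finally:
        _DEFAULTS[algorithm] = saved

def resolve(algorithm, **kwargs):
    """Merge explicitly passed arguments over the current defaults."""
    return {**_DEFAULTS[algorithm], **kwargs}
```

A Gaussian-process workflow could then wrap its training loop in rc_context("cg", tol=1.0), while a Neural PDE workflow calls rc("cg", tol=1e-7) once. Arguments passed explicitly to a solver would still take precedence.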