## Sharp Bits
As you have seen in the previous sections, CoLA has many useful features for large-scale linear algebra, such as compatibility with PyTorch and JAX and automatic selection of an efficient algorithm for a problem through dispatch. However, these features come with some unintuitive consequences, which we now explain.
### Balancing JAX and PyTorch differences
#### Dynamic slicing
In CoLA we write code that runs in both JAX and PyTorch, and therefore we must adhere to several rules particular to each framework. For example, when using `jit` or `vmap` in JAX we cannot return variable-sized arrays. This has the negative consequence that iterative algorithms like Lanczos or Arnoldi have memory requirements proportional to `max_iters`: even when the algorithm terminates much earlier than `max_iters`, we cannot return `array[:idx]`, where `idx` stands for the current iteration.
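The following minimal sketch illustrates the constraint (the update step and `MAX_ITERS` are placeholders, not CoLA's actual Lanczos code):

```python
import jax
import jax.numpy as jnp

MAX_ITERS = 100  # placeholder playing the role of CoLA's max_iters

@jax.jit
def power_iteration(A, x0):
    # Under jit, output shapes must be static, so we allocate the full
    # (MAX_ITERS, n) history up front instead of growing it dynamically.
    def body(x, _):
        x = A @ x
        x = x / jnp.linalg.norm(x)
        return x, x

    _, history = jax.lax.scan(body, x0, None, length=MAX_ITERS)
    # Returning history[:idx] for a traced idx would raise an error,
    # so callers always receive the full fixed-size buffer.
    return history

out = power_iteration(jnp.eye(4), jnp.ones(4))
print(out.shape)  # (100, 4), even if the iteration converged immediately
```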
#### GPU allocation
Additionally, PyTorch and JAX take completely different strategies for allocating arrays on a GPU. PyTorch expects users to explicitly move an array via `array.to(device)` or `array.cuda()`; by default, the array is created on the CPU. In contrast, JAX implicitly allocates any array on a GPU whenever one is available. In principle this is not a huge problem for CoLA, as any `LinearOperator` reports that it is on the device of the parameters / arrays that originated it. However, the situation becomes more involved when we combine `LinearOperator`s that might not be on the same device. For example, imagine that you have `D = cola.ops.Diagonal(torch.ones(N).cuda())` and `L = cola.ops.Tridiagonal(torch.ones(N-1), torch.ones(N), torch.ones(N-1))`, and that you want to combine both as `A = 3 * D + L`. This would not throw any error, but as soon as you compute `A @ v` you would get a device allocation mismatch.
Fortunately, in CoLA we can easily circumvent this issue by calling `A.to(device)`, which moves each of the parameters of the `LinearOperator`, namely the `3` and the four `ones` arrays, to the device of our choice.
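The following sketch reproduces the mismatch and the fix (it assumes a CUDA-enabled PyTorch build):

```python
import torch
import cola

N = 1000
D = cola.ops.Diagonal(torch.ones(N).cuda())  # lives on the GPU
L = cola.ops.Tridiagonal(torch.ones(N - 1), torch.ones(N), torch.ones(N - 1))  # CPU
A = 3 * D + L  # combining them raises no error yet

v = torch.ones(N).cuda()
# A @ v  # would fail here with a device allocation mismatch

A = A.to(torch.device("cuda"))  # moves the 3 and all four arrays to the GPU
print((A @ v).shape)  # now runs entirely on the GPU
```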
### Iterative algorithms
#### Algorithm’s parameters and defaults
CoLA has several iterative algorithms, such as CG, GMRES, Lanczos, Arnoldi, LOBPCG, and SLQ. Each algorithm has its own set of parameters, and a parameter with the same name might not have the same interpretation in each of them. For example, in the context of CG, the tolerance parameter `tol` controls how small the residual of the linear system, \(||A x_t - b||\), must be. In GMRES, however, `tol` is the stopping criterion used in the inner call to Arnoldi, and therefore, in that context, `tol` indicates how good the Arnoldi decomposition of the operator is. It is thus important to read each algorithm's API to understand the meaning of each parameter.
Moreover, we have selected a set of defaults for each algorithm that have shown good performance across the applications where we have used CoLA. Yet a good default for an algorithm in one application might not translate into a good default for another. For example, when training Neural PDEs it is common to set the tolerance of CG to `tol=1e-7`. In contrast, when using CG to train Gaussian processes the tolerances used are on the order of `tol=1`, and asking for a lower tolerance does not result in any clear performance benefit.
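For instance, a tolerance can be passed to CG when solving a linear system. This sketch assumes the `cola.linalg.inv` / `CG` interface, which may differ slightly across CoLA versions:

```python
import torch
import cola
from cola.linalg import CG  # assumed import path; check your version's API

N = 500
A = cola.ops.Diagonal(torch.linspace(1.0, 10.0, N))  # SPD, so CG applies
b = torch.randn(N)

# Solve A x = b with an application-appropriate tolerance.
x = cola.linalg.inv(A, alg=CG(tol=1e-7)) @ b
print(torch.linalg.norm(A @ x - b))  # residual norm, small per the requested tol
```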
#### Relative tolerances
Currently CoLA uses absolute tolerances as stopping criteria. This is not ideal, as a small absolute tolerance on a large problem might be an unreasonable criterion to meet, in which case the algorithm is effectively guaranteed to always run for its full `max_iters`. We are in the process of changing the tolerances to a relative criterion.
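To see why this matters, here is a schematic comparison (not CoLA's internals) for a right-hand side with a large norm:

```python
import torch

b = 1e6 * torch.ones(1000)     # a right-hand side with a large norm
b_norm = torch.linalg.norm(b)
r_norm = 1e-7 * b_norm         # a residual that is tiny *relative* to b

tol = 1e-6
print(bool(r_norm < tol))           # absolute criterion: False, never met
print(bool(r_norm < tol * b_norm))  # relative criterion: True
```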
#### Conditioning of linear systems
The number of iterations that an iterative solver like CG or GMRES takes usually scales with the condition number of the underlying matrix. This means that even for two matrices of the same size, if one has a bad condition number and the other does not, the solver may take many iterations to converge for the former and only a few for the latter.
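For CG applied to a symmetric positive definite matrix, this dependence is quantified by the classical convergence bound (a standard result, stated here for context):

\[
\|x_t - x_\star\|_A \leq 2 \left( \frac{\sqrt{\kappa} - 1}{\sqrt{\kappa} + 1} \right)^t \|x_0 - x_\star\|_A,
\qquad \kappa = \frac{\lambda_{\max}(A)}{\lambda_{\min}(A)},
\]

so the number of iterations needed to reach a fixed accuracy grows roughly like \(\sqrt{\kappa}\).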
### Tracking dispatch
Right now it is not immediately clear which dispatch rules or algorithms are being called on a complex operator that combines different structures at different sizes. For example, when decomposing such an operator, one part might call a dense solver while the remaining parts call an iterative solver. We already have a preliminary tracker that logs which dispatch rules are called, and we are improving this tool.
### Coming Soon
Below is a list of features that we are currently working on incorporating.

#### Dtype allocation

Just as we call `A.to(device)` to place the `LinearOperator` on the device of our choice, we are also currently working on having the `to()` method place the `LinearOperator` in the `dtype` of our choice, in order to match the behaviour of `to()` in PyTorch.
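For reference, this is the PyTorch behaviour we want to match, where a single `to()` method handles both devices and dtypes (the `LinearOperator` line is hypothetical until the feature lands):

```python
import torch

x = torch.ones(3)                  # float32 on the CPU by default
print(x.to(torch.float64).dtype)   # torch.float64
print(x.to("cpu").device)          # the same method also moves devices

# Planned CoLA analogue (hypothetical):
# A.to(torch.float64)
```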
#### Changing algorithm’s defaults
Given that a good default for an algorithm in a certain application might not be a good default for another, we are planning to add a feature that allows the user to change the defaults of certain algorithms. We are thinking of exploring something similar to `matplotlib.pyplot.rc`.
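For reference, the matplotlib pattern we have in mind lets users override library-wide defaults in one place; any CoLA equivalent shown below is purely hypothetical:

```python
import matplotlib.pyplot as plt

# matplotlib: override a library-wide default once, globally.
plt.rc("lines", linewidth=2.0)

# A hypothetical CoLA analogue might look like:
# cola.rc("cg", tol=1e-7, max_iters=500)
```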