{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Sharp Bits\n", "As you have seen in the previous sections, CoLA has many useful features for large-scale linear algebra, such as compatibility with both PyTorch and JAX and automatic selection of an efficient algorithm through dispatch. However, these features come with some unintuitive consequences, which we now explain.\n", "\n", "## Balancing JAX and PyTorch differences\n", "\n", "### Dynamic slicing\n", "In CoLA we write code that runs in both `JAX` and `PyTorch`, so we must adhere to the rules of each framework. For example, when using `jit` or `vmap` in `JAX` we cannot return variable-sized arrays. As a consequence, iterative algorithms like Lanczos or Arnoldi have memory requirements proportional to `max_iters`: even when the algorithm terminates much earlier than `max_iters`, we cannot return `array[:idx]`, where `idx` stands for the current iteration.\n", "\n", "### GPU allocation\n", "Additionally, `PyTorch` and `JAX` take completely different approaches to placing arrays on a GPU. `PyTorch` expects users to move an array explicitly via `array.to(device)` or `array.cuda()`; by default, the `array` is created on the CPU. In contrast, `JAX` implicitly places any `array` on a GPU if one is available. In principle this is not a huge problem for CoLA, as any `LinearOperator` reports that it is on the device of the `parameters` / `arrays` it was built from. However, the situation becomes more involved when we combine `LinearOperators` that might not be on the same device. For example, imagine that you have `D = cola.ops.Diagonal(torch.ones(N).cuda())`, `L = cola.ops.Tridiagonal(torch.ones(N-1), torch.ones(N), torch.ones(N-1))` and that you want to combine both as `A = 3 * D + L`. 
This would not throw any error, but as soon as you do `A @ v` you would get a device mismatch. Fortunately, in CoLA we can easily circumvent this issue by using `A.to(device)`, as that moves each of the parameters of the `LinearOperator`, namely the `3` and the four `ones` `arrays`, to the device of our choice.\n", "\n", "## Iterative algorithms\n", "\n", "### Algorithm's parameters and defaults\n", "CoLA has several iterative algorithms, such as CG, GMRES, Lanczos, Arnoldi, LOBPCG and SLQ. Each algorithm has its own set of parameters, and a parameter with the same name might not have the same interpretation across algorithms. For example, in CG the tolerance parameter `tol` bounds the size of the residual of the linear system, $||A x_t - b||$. In GMRES, however, `tol` is the stopping criterion for the inner call to Arnoldi and therefore indicates how accurate the Arnoldi decomposition of the operator is. It is thus important to read the [algorithms API](https://cola.readthedocs.io/en/latest/package/cola.algorithms.html) to understand the role of each parameter. \n", "\n", "Moreover, we have selected a set of defaults for each algorithm that have shown good performance across the applications where we have used CoLA. Yet a good default for an algorithm in one application might not translate into a good default for another. For example, when training Neural PDEs it is common to set the tolerance of CG to `tol=1e-7`. In contrast, when using CG to train Gaussian processes the tolerances used are on the order of `tol=1`, and asking for a lower tolerance does not result in any clear performance benefit.\n", "\n", "### Relative tolerances\n", "Currently CoLA uses absolute tolerances as stopping criteria. 
This is not ideal, as a small tolerance in a large problem might be an unreasonable stopping criterion to meet, in which case the algorithm will always run for its full `max_iters`. We are in the process of changing the tolerances to a relative criterion.\n", "\n", "### Conditioning of linear systems\n", "The number of iterations that an iterative solver like CG or GMRES takes to converge usually grows with the condition number of the underlying matrix. This means that for two matrices of the same size, the solver may take many iterations on an ill-conditioned one and just a few on a well-conditioned one. \n", "\n", "## Tracking dispatch\n", "Right now it is not immediately clear which dispatch rules or algorithms are being called on a complex operator that combines different structures of different sizes. For example, when decomposing the operator, one part might call a dense solver while the remaining parts call an iterative solver. We already have a preliminary tracker that logs which dispatch rules are called, but we are improving this tool.\n", "\n", "## Coming Soon\n", "Below is a list of features that we are currently working on incorporating.\n", "### Dtype allocation\n", "Just as `A.to(device)` places the `LinearOperator` on the `device` of our choice, we are working on having the `to()` method also cast the `LinearOperator` to the `dtype` of our choice, matching the behaviour of `PyTorch`. \n", "\n", "### Changing algorithm's defaults\n", "Given that a good default for an algorithm in one application might not be a good default for another, we are planning to add a feature that lets the user change the defaults of certain algorithms. We are thinking of exploring something similar to `matplotlib.pyplot.rc`." 
] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 4 }