{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Computing the eigenspectrum of the Hessian of a Neural Network" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this notebook we will consider computing the eigenvalues of the Hessian of the loss function for a ResNet18." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Note: you may need to restart the kernel to use updated packages.\n", "Note: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ "%pip install -q timm\n", "%pip install -q detectors" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First let's load a pretrained resnet18 model on CIFAR10 and verify that it is loaded correctly." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Files already downloaded and verified\n", "Files already downloaded and verified\n", "Test accuracy: 94.98%\n" ] } ], "source": [ "import torch\n", "import torchvision\n", "import torchvision.transforms as transforms\n", "import os\n", "import detectors\n", "import timm\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "# Load CIFAR10 dataset\n", "transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])\n", "trainset = torchvision.datasets.CIFAR10(root=os.path.expanduser('~/datasets'), train=True, download=True, transform=transform)\n", "trainloader = torch.utils.data.DataLoader(trainset, batch_size=50, shuffle=True)\n", "testset = torchvision.datasets.CIFAR10(root=os.path.expanduser('~/datasets'), train=False, download=True, transform=transform)\n", "testloader = torch.utils.data.DataLoader(testset, batch_size=50, shuffle=False)\n", "criterion = torch.nn.CrossEntropyLoss()\n", "\n", "# Load pretrained ResNet18 model 
and verify the results\n", "model = timm.create_model(\"resnet18_cifar10\", pretrained=True).to(device).eval()\n", "with torch.no_grad():\n", "    correct = sum((model(images.to(device)).argmax(1) == labels.to(device)).sum().item() for images, labels in testloader)\n", "accuracy = 100 * correct / len(testset)\n", "print(f\"Test accuracy: {accuracy:.2f}%\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next we will define a function that computes the loss explicitly as a function of a single flat parameter vector, so that we can compute the Hessian of this function." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import torch.func as tf\n", "from torch.utils._pytree import tree_flatten, tree_unflatten\n", "\n", "# make stateless model\n", "def flatten_params(params):\n", "    # concatenate all parameters into one vector and remember their shapes\n", "    shapes = [p.shape for p in params]\n", "    flat_params = torch.cat([p.flatten() for p in params])\n", "    return flat_params, shapes\n", "\n", "\n", "def unflatten_params(flat_params, shapes):\n", "    # slice the flat vector back into tensors with the original shapes\n", "    params = []\n", "    i = 0\n", "    for shape in shapes:\n", "        size = shape.numel()  # number of entries in this parameter\n", "        params.append(flat_params[i:i + size].view(shape))\n", "        i += size\n", "    return params\n", "\n", "flat_p, shape = flatten_params(list(model.parameters()))\n", "flat_p = flat_p.detach().requires_grad_(True)\n", "\n", "def stateless_model(flatparams, x):\n", "    # run the model with the given flat parameters instead of its own\n", "    params = unflatten_params(flatparams, shape)\n", "    names = list(n for n, _ in model.named_parameters())\n", "    nps = {n: p for n, p in zip(names, params)}\n", "    return tf.functional_call(model, nps, x)\n", "\n", "def flat_loss(X,y,params):\n", "    return criterion(stateless_model(params, X).reshape(X.shape[0],-1), y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we will define the Hessian of this loss function. Because of memory constraints we cannot evaluate the loss on the entire dataset at once, so instead we loop over the batches in the dataloader and average the per-batch Hessian-vector products. 
For this we will create a new linear operator, `BatchedHessian`." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import cola\n", "from functools import partial\n", "\n", "class BatchedHessian(cola.ops.LinearOperator):\n", "    def __init__(self, loss, params, dataloader):\n", "        self.loss = loss\n", "        self.params = params\n", "        self.dataloader = dataloader\n", "        super().__init__(dtype=params.dtype, shape=(params.numel(), params.numel()),\n", "                         annotations={cola.SelfAdjoint}) # mark it as self-adjoint\n", "\n", "    def _matmat(self, V):\n", "        # accumulate the per-batch Hessian-vector products, then average\n", "        HV = torch.zeros_like(V)\n", "        with torch.no_grad():\n", "            n = 0\n", "            for X,y in self.dataloader:\n", "                with torch.enable_grad():\n", "                    H = cola.ops.Hessian(partial(self.loss, X.to(self.device), y.to(self.device)), self.params)\n", "                    out = H@V\n", "                    n +=1\n", "                HV += out\n", "        return HV/n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "H = BatchedHessian(flat_loss, flat_p, testloader)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As the shape printed below shows, the matrix is very large, because the model has more than 11 million parameters." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "size of Hessian: (11173962, 11173962)\n" ] } ], "source": [ "print(f\"size of Hessian: {H.shape}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we will use Lanczos to compute the 10 largest eigenvalues. We set the maximum number of iterations to $30$ so that the computation takes only ~15 minutes; with more time we could compute more eigenvalues." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:root:Non keyed randn used. 
To be deprecated soon.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ec7c068136e447a29d8a3cb697f12c07", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Running body_fun: 0%| | 0/100 [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "plt.plot(eigs.cpu().data.numpy())\n", "plt.yscale('log')\n", "plt.ylabel('eigenvalues')\n", "plt.xlabel('index')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "cola", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }