"Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
"- This notebook provides tips for loading larger pretrained or finetuned models when GPU (or CPU) memory is limited\n",
"- Specifically, it focuses on cases where you saved the model using `torch.save(model.state_dict(), \"model.pth\")` (for example, in chapters 5-7) and want to load it in a new session later for continued pretraining or additional finetuning\n",
"- While the example uses an LLM, the methods explained in this notebook are general and apply to loading any PyTorch model, not just LLMs\n",
"- Here, we use the \"large\" GPT-2 model to make things more interesting (you may use the \"gpt2-small (124M)\" to lower the memory requirements and execution time of this notebook)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "tMuhCYaVI0w7"
},
"outputs": [],
"source": [
"from previous_chapters import GPTModel\n",
"# If the `previous_chapters.py` file is not available locally,\n",
"# you can import it from the `llms-from-scratch` PyPI package.\n",
"# For details, see: https://github.com/rasbt/LLMs-from-scratch/tree/main/pkg\n",
"# E.g.,\n",
"# from llms_from_scratch.ch04 import GPTModel\n",
"- Notice that the memory is 2x as large as in the previous session\n",
"- This is because we have the same model in memory twice, for a short period of time:\n",
" - The first time via `model.to(device)`\n",
" - The second time via the code line `model.load_state_dict(torch.load(\"model.pth\", map_location=device, weights_only=True))`; the loaded weights are eventually copied into the model and the `state_dict` is discarded, but for a brief moment both the main model and the loaded `state_dict` are in memory\n",
"- The remaining sections focus on addressing this\n",
"- But first, let's test the model and reset the GPU memory\n"
"- So, as far as peak memory is concerned, it doesn't make a difference whether we instantiate the model on the device first and then use `map_location=device`, or load the weights into CPU memory first (`map_location=\"cpu\"`) and then move the model to the device"
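The two orderings can be sketched as follows (a minimal sketch using a small stand-in `nn.Linear` model and a hypothetical `model.pth` path instead of the book's `GPTModel` checkpoint):

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Small stand-in model; the notebook uses GPTModel(BASE_CONFIG) instead
model = nn.Linear(4, 4)
torch.save(model.state_dict(), "model.pth")

# Option A: move the model to the device first, then load the
# checkpoint directly onto that device
model.to(device)
model.load_state_dict(
    torch.load("model.pth", map_location=device, weights_only=True)
)

# Option B: load the checkpoint into CPU memory first, then move
# the model to the device
state_dict = torch.load("model.pth", map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.to(device)

# Either way, the model and the loaded state_dict briefly coexist,
# so the peak memory footprint is roughly the same
```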
],
"metadata": {
"id": "UGjBD6GASS_y"
}
},
{
"cell_type": "markdown",
"metadata": {
"id": "RdPnW3iLLrjX"
},
"source": [
" \n",
"## 4. Loading weights sequentially"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FYqtUON602TD"
},
"source": [
"- One workaround for the problem of having the model weights in GPU memory twice, as highlighted in the previous section, is to load the model sequentially\n",
"- Below, we:\n",
" - first load the model into GPU memory\n",
" - then load the model weights into CPU memory\n",
" - and finally copy each parameter one by one into GPU memory\n"
"# Sequentially copy weights to the model's parameters\n",
"with torch.no_grad():\n",
" for name, param in model.named_parameters():\n",
" if name in state_dict:\n",
" param.copy_(state_dict[name].to(device))\n",
" else:\n",
" print(f\"Warning: {name} not found in state_dict.\")\n",
"\n",
"print_memory_usage()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Pn9xD_xL1ZzM"
},
"source": [
"- As we can see above, the memory usage is much lower than before\n",
"- Notice that the memory increases from 6.4 to 6.7 GB because initially, we only have the model in memory, and then we have the model plus 1 parameter tensor in memory (we temporarily move each parameter tensor to the GPU with `.to(device)` so we can copy it into the corresponding model parameter)\n",
"- Overall, this is a significant improvement\n",
"- Again, let's briefly test the model and then reset the GPU memory for the next section"
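Resetting the memory between sections can be sketched like this (a minimal sketch; the `nn.Linear` stands in for the loaded GPT model):

```python
import gc
import torch
import torch.nn as nn

model = nn.Linear(8, 8)  # stand-in for the loaded model

# Drop the last reference and force a garbage-collection pass so the
# parameter tensors can actually be freed
del model
gc.collect()

# Hand cached GPU blocks back to the driver and reset the peak-memory
# counter used by torch.cuda.max_memory_allocated()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
```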
"- In the previous section, we reduced GPU memory use by loading the weights (`state_dict`) into CPU memory before copying them one by one into the model\n",
"- However, what do we do if we have limited CPU memory?\n",
"- This section uses PyTorch's so-called `\"meta\"` device approach to load a model on machines with large GPU memory but small CPU memory\n",
"- But first, let's define a convenience function to monitor CPU memory"
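One stdlib-only way to sketch such a helper on Unix is via the `resource` module (an assumption for illustration; the notebook's actual implementation may differ, e.g., it could use `psutil` instead):

```python
import resource

def peak_cpu_memory_gb():
    # ru_maxrss is the peak resident set size of the current process;
    # on Linux it is reported in kilobytes (on macOS, in bytes)
    peak_kb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    return peak_kb / 1024**2

print(f"-> Maximum CPU memory allocated: {peak_cpu_memory_gb():.1f} GB")
```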
"print(f\"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UWrmnCML5oKy"
},
"source": [
"- Now, suppose we have a machine with low CPU memory but large GPU memory\n",
"- We can trade off CPU memory and GPU memory usage by introducing PyTorch's so-called \"meta\" device\n",
"- PyTorch's meta device is a special device type that allows you to create tensors without allocating actual memory for their data, effectively creating \"meta\" tensors\n",
"- This is useful for tasks like model analysis or architecture definition, where you need tensor shapes and types without the overhead of memory allocation"
"print(f\"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VpnCABp75-VQ"
},
"source": [
"- As we can see above, by creating the model on the meta-device and loading the weights directly into GPU memory, we effectively reduced the CPU memory requirements\n",
"- One might ask: \"Is the sequential weight loading still necessary then, and how does that compare to the original approach?\"\n",
"- Let's check the simple PyTorch weight loading approach for comparison (from the first weight loading section in this notebook):"
"print(f\"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NKAjxbX86xnb"
},
"source": [
"- As we can see above, the \"simple\" weight loading without the meta device uses more memory\n",
"- In other words, if you have a machine with limited CPU memory, you can use the meta device approach to directly load the model weights into GPU memory to reduce peak CPU memory usage"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jvDVFpcaRISr"
},
"source": [
" \n",
"## 6. Using `mmap=True` (recommended)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "w3H5gPygRISr"
},
"source": [
"- As an intermediate or advanced `torch.load` user, you may wonder how these approaches compare to the `mmap=True` setting in PyTorch\n",
"- The `mmap=True` setting in PyTorch enables memory-mapped file I/O, which allows the tensor to access data directly from disk storage, thus reducing memory usage by not loading the entire file into RAM if RAM is limited\n",
"- Also, see the helpful comment by [mikaylagawarecki](https://github.com/rasbt/LLMs-from-scratch/issues/402)\n",
"- At first glance, it may look less efficient than the sequential approaches above:"