Update memory efficient loading nb

This commit is contained in:
rasbt
2025-12-20 18:35:13 -06:00
parent 695ecb61ce
commit 2b9a67c00d

@@ -43,7 +43,9 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {
"id": "Ji9LlnMlRISm"
},
"source": [ "source": [
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/memory-efficient-loading/memory-efficient-loading.webp\" width=\"800px\">" "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/memory-efficient-loading/memory-efficient-loading.webp\" width=\"800px\">"
] ]
@@ -56,14 +58,14 @@
"base_uri": "https://localhost:8080/" "base_uri": "https://localhost:8080/"
}, },
"id": "SxQzFoS-IXdY", "id": "SxQzFoS-IXdY",
"outputId": "b28ebfbd-9036-4696-d95a-7f96fdf29919" "outputId": "9f8fd57a-91e7-489d-d86e-656df536c604"
}, },
"outputs": [ "outputs": [
{ {
"name": "stdout",
"output_type": "stream", "output_type": "stream",
"name": "stdout",
"text": [ "text": [
"torch version: 2.6.0\n" "torch version: 2.9.0+cu126\n"
] ]
} }
], ],
@@ -204,12 +206,12 @@
"base_uri": "https://localhost:8080/" "base_uri": "https://localhost:8080/"
}, },
"id": "GK3NEA3eJv3f", "id": "GK3NEA3eJv3f",
"outputId": "60573d6e-c603-45e7-8283-b1e92e2a0013" "outputId": "434b51ca-7c8b-44dd-8a84-41ab48a290ff"
}, },
"outputs": [ "outputs": [
{ {
"name": "stdout",
"output_type": "stream", "output_type": "stream",
"name": "stdout",
"text": [ "text": [
"Maximum GPU memory allocated: 6.4 GB\n" "Maximum GPU memory allocated: 6.4 GB\n"
] ]
@@ -292,12 +294,12 @@
"base_uri": "https://localhost:8080/" "base_uri": "https://localhost:8080/"
}, },
"id": "SqmTzztqKnTs", "id": "SqmTzztqKnTs",
"outputId": "1198afb9-2d97-4b6a-9bdb-41551f25749d" "outputId": "218332da-8b66-4169-d876-8d72c68691fc"
}, },
"outputs": [ "outputs": [
{ {
"name": "stdout",
"output_type": "stream", "output_type": "stream",
"name": "stdout",
"text": [ "text": [
"Maximum GPU memory allocated: 0.0 GB\n" "Maximum GPU memory allocated: 0.0 GB\n"
] ]
@@ -315,7 +317,7 @@
}, },
"source": [ "source": [
"&nbsp;\n", "&nbsp;\n",
"## 3. Weight loading" "## 3. Basic weight loading"
] ]
}, },
{ {
@@ -336,12 +338,12 @@
"base_uri": "https://localhost:8080/" "base_uri": "https://localhost:8080/"
}, },
"id": "wCrQNbSJJO9w", "id": "wCrQNbSJJO9w",
"outputId": "9b203868-a8ef-4011-fc2b-611cc0d10994" "outputId": "2623b399-bce6-4506-ec0b-c3c94729b80f"
}, },
"outputs": [ "outputs": [
{ {
"name": "stdout",
"output_type": "stream", "output_type": "stream",
"name": "stdout",
"text": [ "text": [
"Maximum GPU memory allocated: 12.8 GB\n" "Maximum GPU memory allocated: 12.8 GB\n"
] ]
@@ -386,12 +388,12 @@
"base_uri": "https://localhost:8080/" "base_uri": "https://localhost:8080/"
}, },
"id": "DvlUn-nmmbuj", "id": "DvlUn-nmmbuj",
"outputId": "11d3ab68-f570-4c1e-c631-fe5547026799" "outputId": "7a9afbde-826f-4fb2-874d-feb6e8724834"
}, },
"outputs": [ "outputs": [
{ {
"name": "stdout",
"output_type": "stream", "output_type": "stream",
"name": "stdout",
"text": [ "text": [
"Maximum GPU memory allocated: 0.0 GB\n" "Maximum GPU memory allocated: 0.0 GB\n"
] ]
@@ -409,6 +411,64 @@
"cleanup()" "cleanup()"
] ]
}, },
{
"cell_type": "markdown",
"source": [
"- Let's test another common pattern that is very popular in practice:"
],
"metadata": {
"id": "IQ531-IuRuzD"
}
},
{
"cell_type": "code",
"source": [
"start_memory_tracking()\n",
"\n",
"model = GPTModel(BASE_CONFIG)\n",
"\n",
"model.load_state_dict(\n",
" torch.load(\"model.pth\", map_location=\"cpu\", weights_only=True)\n",
")\n",
"model.to(device)\n",
"model.eval();\n",
"\n",
"print_memory_usage()"
],
"metadata": {
"id": "2m54kzX5RxLX"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Test if the model works (no need to track memory here)\n",
"test_input = torch.tensor([[1, 2, 3]]).to(device)\n",
"model.eval()\n",
"\n",
"with torch.no_grad():\n",
" model(test_input)\n",
"\n",
"del model, test_input, state_dict, param\n",
"cleanup()"
],
"metadata": {
"id": "XWvQTRN4R2CM"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"- So, as peak memory is concerned, it doesn't make a difference whether we instantiate the model on the device first and then use `map_location=\"device\"` or load the weights into CPU memory first (`map_location=\"cpu\"`) and then move it to the device"
],
"metadata": {
"id": "UGjBD6GASS_y"
}
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
@@ -434,7 +494,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": null,
"metadata": { "metadata": {
"colab": { "colab": {
"base_uri": "https://localhost:8080/" "base_uri": "https://localhost:8080/"
@@ -486,7 +546,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 11, "execution_count": null,
"metadata": { "metadata": {
"colab": { "colab": {
"base_uri": "https://localhost:8080/" "base_uri": "https://localhost:8080/"
@@ -539,7 +599,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 12, "execution_count": null,
"metadata": { "metadata": {
"id": "BrcWy0q-3Bbe" "id": "BrcWy0q-3Bbe"
}, },
@@ -590,7 +650,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 13, "execution_count": null,
"metadata": { "metadata": {
"colab": { "colab": {
"base_uri": "https://localhost:8080/" "base_uri": "https://localhost:8080/"
@@ -648,7 +708,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 14, "execution_count": null,
"metadata": { "metadata": {
"colab": { "colab": {
"base_uri": "https://localhost:8080/" "base_uri": "https://localhost:8080/"
@@ -707,7 +767,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 15, "execution_count": null,
"metadata": { "metadata": {
"colab": { "colab": {
"base_uri": "https://localhost:8080/" "base_uri": "https://localhost:8080/"
@@ -754,7 +814,9 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {
"id": "jvDVFpcaRISr"
},
"source": [ "source": [
"&nbsp;\n", "&nbsp;\n",
"## 6. Using `mmap=True` (recommmended)" "## 6. Using `mmap=True` (recommmended)"
@@ -762,7 +824,9 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {
"id": "w3H5gPygRISr"
},
"source": [ "source": [
"- As an intermediate or advanced `torch.load` user, you may wonder how these approaches compare to the `mmap=True` setting in PyTorch\n", "- As an intermediate or advanced `torch.load` user, you may wonder how these approaches compare to the `mmap=True` setting in PyTorch\n",
"- The `mmap=True` setting in PyTorch enables memory-mapped file I/O, which allows the tensor to access data directly from disk storage, thus reducing memory usage by not loading the entire file into RAM if RAM is limited\n", "- The `mmap=True` setting in PyTorch enables memory-mapped file I/O, which allows the tensor to access data directly from disk storage, thus reducing memory usage by not loading the entire file into RAM if RAM is limited\n",
@@ -772,7 +836,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 37, "execution_count": null,
"metadata": { "metadata": {
"colab": { "colab": {
"base_uri": "https://localhost:8080/" "base_uri": "https://localhost:8080/"
@@ -808,7 +872,9 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {
"id": "pGC0rBv4RISr"
},
"source": [ "source": [
"- The reason why the CPU RAM usage is so high is that there's enough CPU RAM available on this machine\n", "- The reason why the CPU RAM usage is so high is that there's enough CPU RAM available on this machine\n",
"- However, if you were to run this on a machine with limited CPU RAM, the `mmap` approach would use less memory" "- However, if you were to run this on a machine with limited CPU RAM, the `mmap` approach would use less memory"
@@ -816,7 +882,9 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {
"id": "fd11QM8pRISr"
},
"source": [ "source": [
"&nbsp;\n", "&nbsp;\n",
"## 7. Other methods" "## 7. Other methods"
@@ -824,7 +892,9 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {
"id": "0U2Y6eo8RISr"
},
"source": [ "source": [
"- This notebook is focused on simple, built-in methods for loading weights in PyTorch\n", "- This notebook is focused on simple, built-in methods for loading weights in PyTorch\n",
"- The recommended approach for limited CPU memory cases is the `mmap=True` approach explained enough\n", "- The recommended approach for limited CPU memory cases is the `mmap=True` approach explained enough\n",
@@ -833,7 +903,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 13, "execution_count": null,
"metadata": { "metadata": {
"id": "2CgPEZUIb00w" "id": "2CgPEZUIb00w"
}, },
@@ -855,7 +925,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 16, "execution_count": null,
"metadata": { "metadata": {
"colab": { "colab": {
"base_uri": "https://localhost:8080/" "base_uri": "https://localhost:8080/"
@@ -908,12 +978,11 @@
"metadata": { "metadata": {
"accelerator": "GPU", "accelerator": "GPU",
"colab": { "colab": {
"gpuType": "L4", "gpuType": "T4",
"provenance": [] "provenance": []
}, },
"kernelspec": { "kernelspec": {
"display_name": "Python 3 (ipykernel)", "display_name": "Python 3",
"language": "python",
"name": "python3" "name": "python3"
}, },
"language_info": { "language_info": {
@@ -930,5 +999,5 @@
} }
}, },
"nbformat": 4, "nbformat": 4,
"nbformat_minor": 4 "nbformat_minor": 0
} }