mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Update memory efficient loading nb
This commit is contained in:
@@ -43,7 +43,9 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"id": "Ji9LlnMlRISm"
|
||||||
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/memory-efficient-loading/memory-efficient-loading.webp\" width=\"800px\">"
|
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/memory-efficient-loading/memory-efficient-loading.webp\" width=\"800px\">"
|
||||||
]
|
]
|
||||||
@@ -56,14 +58,14 @@
|
|||||||
"base_uri": "https://localhost:8080/"
|
"base_uri": "https://localhost:8080/"
|
||||||
},
|
},
|
||||||
"id": "SxQzFoS-IXdY",
|
"id": "SxQzFoS-IXdY",
|
||||||
"outputId": "b28ebfbd-9036-4696-d95a-7f96fdf29919"
|
"outputId": "9f8fd57a-91e7-489d-d86e-656df536c604"
|
||||||
},
|
},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
|
"name": "stdout",
|
||||||
"text": [
|
"text": [
|
||||||
"torch version: 2.6.0\n"
|
"torch version: 2.9.0+cu126\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@@ -204,12 +206,12 @@
|
|||||||
"base_uri": "https://localhost:8080/"
|
"base_uri": "https://localhost:8080/"
|
||||||
},
|
},
|
||||||
"id": "GK3NEA3eJv3f",
|
"id": "GK3NEA3eJv3f",
|
||||||
"outputId": "60573d6e-c603-45e7-8283-b1e92e2a0013"
|
"outputId": "434b51ca-7c8b-44dd-8a84-41ab48a290ff"
|
||||||
},
|
},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
|
"name": "stdout",
|
||||||
"text": [
|
"text": [
|
||||||
"Maximum GPU memory allocated: 6.4 GB\n"
|
"Maximum GPU memory allocated: 6.4 GB\n"
|
||||||
]
|
]
|
||||||
@@ -292,12 +294,12 @@
|
|||||||
"base_uri": "https://localhost:8080/"
|
"base_uri": "https://localhost:8080/"
|
||||||
},
|
},
|
||||||
"id": "SqmTzztqKnTs",
|
"id": "SqmTzztqKnTs",
|
||||||
"outputId": "1198afb9-2d97-4b6a-9bdb-41551f25749d"
|
"outputId": "218332da-8b66-4169-d876-8d72c68691fc"
|
||||||
},
|
},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
|
"name": "stdout",
|
||||||
"text": [
|
"text": [
|
||||||
"Maximum GPU memory allocated: 0.0 GB\n"
|
"Maximum GPU memory allocated: 0.0 GB\n"
|
||||||
]
|
]
|
||||||
@@ -315,7 +317,7 @@
|
|||||||
},
|
},
|
||||||
"source": [
|
"source": [
|
||||||
" \n",
|
" \n",
|
||||||
"## 3. Weight loading"
|
"## 3. Basic weight loading"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -336,12 +338,12 @@
|
|||||||
"base_uri": "https://localhost:8080/"
|
"base_uri": "https://localhost:8080/"
|
||||||
},
|
},
|
||||||
"id": "wCrQNbSJJO9w",
|
"id": "wCrQNbSJJO9w",
|
||||||
"outputId": "9b203868-a8ef-4011-fc2b-611cc0d10994"
|
"outputId": "2623b399-bce6-4506-ec0b-c3c94729b80f"
|
||||||
},
|
},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
|
"name": "stdout",
|
||||||
"text": [
|
"text": [
|
||||||
"Maximum GPU memory allocated: 12.8 GB\n"
|
"Maximum GPU memory allocated: 12.8 GB\n"
|
||||||
]
|
]
|
||||||
@@ -386,12 +388,12 @@
|
|||||||
"base_uri": "https://localhost:8080/"
|
"base_uri": "https://localhost:8080/"
|
||||||
},
|
},
|
||||||
"id": "DvlUn-nmmbuj",
|
"id": "DvlUn-nmmbuj",
|
||||||
"outputId": "11d3ab68-f570-4c1e-c631-fe5547026799"
|
"outputId": "7a9afbde-826f-4fb2-874d-feb6e8724834"
|
||||||
},
|
},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
|
"name": "stdout",
|
||||||
"text": [
|
"text": [
|
||||||
"Maximum GPU memory allocated: 0.0 GB\n"
|
"Maximum GPU memory allocated: 0.0 GB\n"
|
||||||
]
|
]
|
||||||
@@ -409,6 +411,64 @@
|
|||||||
"cleanup()"
|
"cleanup()"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"- Let's test another common pattern that is very popular in practice:"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"id": "IQ531-IuRuzD"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"source": [
|
||||||
|
"start_memory_tracking()\n",
|
||||||
|
"\n",
|
||||||
|
"model = GPTModel(BASE_CONFIG)\n",
|
||||||
|
"\n",
|
||||||
|
"model.load_state_dict(\n",
|
||||||
|
" torch.load(\"model.pth\", map_location=\"cpu\", weights_only=True)\n",
|
||||||
|
")\n",
|
||||||
|
"model.to(device)\n",
|
||||||
|
"model.eval();\n",
|
||||||
|
"\n",
|
||||||
|
"print_memory_usage()"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"id": "2m54kzX5RxLX"
|
||||||
|
},
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"source": [
|
||||||
|
"# Test if the model works (no need to track memory here)\n",
|
||||||
|
"test_input = torch.tensor([[1, 2, 3]]).to(device)\n",
|
||||||
|
"model.eval()\n",
|
||||||
|
"\n",
|
||||||
|
"with torch.no_grad():\n",
|
||||||
|
" model(test_input)\n",
|
||||||
|
"\n",
|
||||||
|
"del model, test_input\n",
|
||||||
|
"cleanup()"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"id": "XWvQTRN4R2CM"
|
||||||
|
},
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"- So, as far as peak memory is concerned, it doesn't make a difference whether we instantiate the model on the device first and then use `map_location=device`, or load the weights into CPU memory first (`map_location=\"cpu\"`) and then move the model to the device"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"id": "UGjBD6GASS_y"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
@@ -434,7 +494,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 10,
|
"execution_count": null,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"colab": {
|
"colab": {
|
||||||
"base_uri": "https://localhost:8080/"
|
"base_uri": "https://localhost:8080/"
|
||||||
@@ -486,7 +546,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 11,
|
"execution_count": null,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"colab": {
|
"colab": {
|
||||||
"base_uri": "https://localhost:8080/"
|
"base_uri": "https://localhost:8080/"
|
||||||
@@ -539,7 +599,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 12,
|
"execution_count": null,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "BrcWy0q-3Bbe"
|
"id": "BrcWy0q-3Bbe"
|
||||||
},
|
},
|
||||||
@@ -590,7 +650,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 13,
|
"execution_count": null,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"colab": {
|
"colab": {
|
||||||
"base_uri": "https://localhost:8080/"
|
"base_uri": "https://localhost:8080/"
|
||||||
@@ -648,7 +708,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 14,
|
"execution_count": null,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"colab": {
|
"colab": {
|
||||||
"base_uri": "https://localhost:8080/"
|
"base_uri": "https://localhost:8080/"
|
||||||
@@ -707,7 +767,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 15,
|
"execution_count": null,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"colab": {
|
"colab": {
|
||||||
"base_uri": "https://localhost:8080/"
|
"base_uri": "https://localhost:8080/"
|
||||||
@@ -754,7 +814,9 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"id": "jvDVFpcaRISr"
|
||||||
|
},
|
||||||
"source": [
|
"source": [
|
||||||
" \n",
|
" \n",
|
||||||
"## 6. Using `mmap=True` (recommended)"
|
"## 6. Using `mmap=True` (recommended)"
|
||||||
@@ -762,7 +824,9 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"id": "w3H5gPygRISr"
|
||||||
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"- As an intermediate or advanced `torch.load` user, you may wonder how these approaches compare to the `mmap=True` setting in PyTorch\n",
|
"- As an intermediate or advanced `torch.load` user, you may wonder how these approaches compare to the `mmap=True` setting in PyTorch\n",
|
||||||
"- The `mmap=True` setting in PyTorch enables memory-mapped file I/O, which allows the tensor to access data directly from disk storage, thus reducing memory usage by not loading the entire file into RAM if RAM is limited\n",
|
"- The `mmap=True` setting in PyTorch enables memory-mapped file I/O, which allows the tensor to access data directly from disk storage, thus reducing memory usage by not loading the entire file into RAM if RAM is limited\n",
|
||||||
@@ -772,7 +836,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 37,
|
"execution_count": null,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"colab": {
|
"colab": {
|
||||||
"base_uri": "https://localhost:8080/"
|
"base_uri": "https://localhost:8080/"
|
||||||
@@ -808,7 +872,9 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"id": "pGC0rBv4RISr"
|
||||||
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"- The reason why the CPU RAM usage is so high is that there's enough CPU RAM available on this machine\n",
|
"- The reason why the CPU RAM usage is so high is that there's enough CPU RAM available on this machine\n",
|
||||||
"- However, if you were to run this on a machine with limited CPU RAM, the `mmap` approach would use less memory"
|
"- However, if you were to run this on a machine with limited CPU RAM, the `mmap` approach would use less memory"
|
||||||
@@ -816,7 +882,9 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"id": "fd11QM8pRISr"
|
||||||
|
},
|
||||||
"source": [
|
"source": [
|
||||||
" \n",
|
" \n",
|
||||||
"## 7. Other methods"
|
"## 7. Other methods"
|
||||||
@@ -824,7 +892,9 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"id": "0U2Y6eo8RISr"
|
||||||
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"- This notebook is focused on simple, built-in methods for loading weights in PyTorch\n",
|
"- This notebook is focused on simple, built-in methods for loading weights in PyTorch\n",
|
||||||
"- The recommended approach for limited CPU memory cases is the `mmap=True` approach explained above\n",
|
"- The recommended approach for limited CPU memory cases is the `mmap=True` approach explained above\n",
|
||||||
@@ -833,7 +903,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 13,
|
"execution_count": null,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "2CgPEZUIb00w"
|
"id": "2CgPEZUIb00w"
|
||||||
},
|
},
|
||||||
@@ -855,7 +925,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 16,
|
"execution_count": null,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"colab": {
|
"colab": {
|
||||||
"base_uri": "https://localhost:8080/"
|
"base_uri": "https://localhost:8080/"
|
||||||
@@ -908,12 +978,11 @@
|
|||||||
"metadata": {
|
"metadata": {
|
||||||
"accelerator": "GPU",
|
"accelerator": "GPU",
|
||||||
"colab": {
|
"colab": {
|
||||||
"gpuType": "L4",
|
"gpuType": "T4",
|
||||||
"provenance": []
|
"provenance": []
|
||||||
},
|
},
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "Python 3 (ipykernel)",
|
"display_name": "Python 3",
|
||||||
"language": "python",
|
|
||||||
"name": "python3"
|
"name": "python3"
|
||||||
},
|
},
|
||||||
"language_info": {
|
"language_info": {
|
||||||
@@ -930,5 +999,5 @@
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
"nbformat_minor": 4
|
"nbformat_minor": 0
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user