From 67e068021066af217b67cc7801d3e6658576cb24 Mon Sep 17 00:00:00 2001
From: Sebastian Raschka
Date: Sun, 6 Apr 2025 09:33:36 -0500
Subject: [PATCH] Disable mask saving as weight in Llama 3 model (#604)

* Disable mask saving as weight

* update pixi

* update pixi
---
 .github/workflows/basic-tests-pixi.yml                | 1 +
 ch05/07_gpt_to_llama/standalone-llama32-mem-opt.ipynb | 5 ++++-
 ch05/07_gpt_to_llama/standalone-llama32.ipynb         | 6 +++---
 3 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/.github/workflows/basic-tests-pixi.yml b/.github/workflows/basic-tests-pixi.yml
index 185c4af..e661151 100644
--- a/.github/workflows/basic-tests-pixi.yml
+++ b/.github/workflows/basic-tests-pixi.yml
@@ -42,6 +42,7 @@ jobs:
       - name: List installed packages
         run: |
           pixi list --environment tests
+          pixi run --environment tests pip install "huggingface-hub>=0.30.0,<1.0"

       - name: Test Selected Python Scripts
         shell: pixi run --environment tests bash -e {0}
diff --git a/ch05/07_gpt_to_llama/standalone-llama32-mem-opt.ipynb b/ch05/07_gpt_to_llama/standalone-llama32-mem-opt.ipynb
index 96aee54..a98f06d 100644
--- a/ch05/07_gpt_to_llama/standalone-llama32-mem-opt.ipynb
+++ b/ch05/07_gpt_to_llama/standalone-llama32-mem-opt.ipynb
@@ -368,7 +368,10 @@
     "        self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
     "\n",
     "        # Reusuable utilities\n",
-    "        self.register_buffer(\"mask\", torch.triu(torch.ones(cfg[\"context_length\"], cfg[\"context_length\"]), diagonal=1).bool())\n",
+    "        self.register_buffer(\n",
+    "            \"mask\", torch.triu(torch.ones(cfg[\"context_length\"], cfg[\"context_length\"]), diagonal=1).bool(),\n",
+    "            persistent=False\n",
+    "        )\n",
     "        cfg[\"rope_base\"] = rescale_theta(\n",
     "            cfg[\"rope_base\"],\n",
     "            cfg[\"orig_context_length\"],\n",
diff --git a/ch05/07_gpt_to_llama/standalone-llama32.ipynb b/ch05/07_gpt_to_llama/standalone-llama32.ipynb
index 4da06ff..9275955 100644
--- a/ch05/07_gpt_to_llama/standalone-llama32.ipynb
+++ b/ch05/07_gpt_to_llama/standalone-llama32.ipynb
@@ -266,10 +266,10 @@
     "\n",
     "        # Fetch buffers using SharedBuffers\n",
     "        mask, cos, sin = SharedBuffers.get_buffers(context_length, self.head_dim, rope_base, rope_config, dtype)\n",
-    "        self.register_buffer(\"mask\", mask)\n",
+    "        self.register_buffer(\"mask\", mask, persistent=False)\n",
     "\n",
-    "        self.register_buffer(\"cos\", cos)\n",
-    "        self.register_buffer(\"sin\", sin)\n",
+    "        self.register_buffer(\"cos\", cos, persistent=False)\n",
+    "        self.register_buffer(\"sin\", sin, persistent=False)\n",
     "\n",
     "    def forward(self, x):\n",
     "        b, num_tokens, d_in = x.shape\n",
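
For context: register_buffer(..., persistent=False) keeps a tensor registered on the
module (it still moves along with .to(device) and dtype casts) but excludes it from
state_dict(), so the causal mask and the RoPE cos/sin tables are no longer written
into checkpoints as if they were weights. A minimal standalone sketch of the effect
(illustrative only, not part of the patch; the CausalDemo module and its sizes are
made up):

    import torch
    import torch.nn as nn

    class CausalDemo(nn.Module):
        def __init__(self, context_length=4):
            super().__init__()
            mask = torch.triu(torch.ones(context_length, context_length), diagonal=1).bool()
            # Default (persistent=True): the buffer appears in state_dict
            # and therefore ends up in every saved checkpoint.
            self.register_buffer("mask_saved", mask)
            # persistent=False: still a buffer (moves with .to()/.cuda()),
            # but omitted from state_dict, keeping checkpoints smaller.
            self.register_buffer("mask_transient", mask.clone(), persistent=False)

    model = CausalDemo()
    print(list(model.state_dict().keys()))  # ['mask_saved'] -- no 'mask_transient'

One caveat worth knowing: checkpoints saved before this change still contain the
mask (and cos/sin) entries, so loading them into the updated model with
load_state_dict(..., strict=True) reports them as unexpected keys; strict=False
ignores them.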