diff --git a/.github/workflows/basic-tests-pixi.yml b/.github/workflows/basic-tests-pixi.yml
index 185c4af..e661151 100644
--- a/.github/workflows/basic-tests-pixi.yml
+++ b/.github/workflows/basic-tests-pixi.yml
@@ -42,6 +42,7 @@ jobs:
       - name: List installed packages
         run: |
           pixi list --environment tests
+          pixi run --environment tests pip install "huggingface-hub>=0.30.0,<1.0"
 
       - name: Test Selected Python Scripts
         shell: pixi run --environment tests bash -e {0}
diff --git a/ch05/07_gpt_to_llama/standalone-llama32-mem-opt.ipynb b/ch05/07_gpt_to_llama/standalone-llama32-mem-opt.ipynb
index 96aee54..a98f06d 100644
--- a/ch05/07_gpt_to_llama/standalone-llama32-mem-opt.ipynb
+++ b/ch05/07_gpt_to_llama/standalone-llama32-mem-opt.ipynb
@@ -368,7 +368,10 @@
     "        self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
     "\n",
     "        # Reusuable utilities\n",
-    "        self.register_buffer(\"mask\", torch.triu(torch.ones(cfg[\"context_length\"], cfg[\"context_length\"]), diagonal=1).bool())\n",
+    "        self.register_buffer(\n",
+    "            \"mask\", torch.triu(torch.ones(cfg[\"context_length\"], cfg[\"context_length\"]), diagonal=1).bool(),\n",
+    "            persistent=False\n",
+    "        )\n",
     "        cfg[\"rope_base\"] = rescale_theta(\n",
     "            cfg[\"rope_base\"],\n",
     "            cfg[\"orig_context_length\"],\n",
diff --git a/ch05/07_gpt_to_llama/standalone-llama32.ipynb b/ch05/07_gpt_to_llama/standalone-llama32.ipynb
index 4da06ff..9275955 100644
--- a/ch05/07_gpt_to_llama/standalone-llama32.ipynb
+++ b/ch05/07_gpt_to_llama/standalone-llama32.ipynb
@@ -266,10 +266,10 @@
     "\n",
     "        # Fetch buffers using SharedBuffers\n",
     "        mask, cos, sin = SharedBuffers.get_buffers(context_length, self.head_dim, rope_base, rope_config, dtype)\n",
-    "        self.register_buffer(\"mask\", mask)\n",
+    "        self.register_buffer(\"mask\", mask, persistent=False)\n",
     "\n",
-    "        self.register_buffer(\"cos\", cos)\n",
-    "        self.register_buffer(\"sin\", sin)\n",
+    "        self.register_buffer(\"cos\", cos, persistent=False)\n",
+    "        self.register_buffer(\"sin\", sin, persistent=False)\n",
     "\n",
     "    def forward(self, x):\n",
     "        b, num_tokens, d_in = x.shape\n",