Disable mask saving as weight in Llama 3 model (#604)

* Disable mask saving as weight

* update pixi

* update pixi
This commit is contained in:
Sebastian Raschka
2025-04-06 09:33:36 -05:00
committed by GitHub
parent f1434652f2
commit 67e0680210
3 changed files with 8 additions and 4 deletions

View File

@@ -42,6 +42,7 @@ jobs:
- name: List installed packages
run: |
pixi list --environment tests
pixi run --environment tests pip install "huggingface-hub>=0.30.0,<1.0"
- name: Test Selected Python Scripts
shell: pixi run --environment tests bash -e {0}

View File

@@ -368,7 +368,10 @@
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
"\n",
"    # Reusable utilities\n",
" self.register_buffer(\"mask\", torch.triu(torch.ones(cfg[\"context_length\"], cfg[\"context_length\"]), diagonal=1).bool())\n",
" self.register_buffer(\n",
" \"mask\", torch.triu(torch.ones(cfg[\"context_length\"], cfg[\"context_length\"]), diagonal=1).bool(),\n",
" persistent=False\n",
" )\n",
" cfg[\"rope_base\"] = rescale_theta(\n",
" cfg[\"rope_base\"],\n",
" cfg[\"orig_context_length\"],\n",

View File

@@ -266,10 +266,10 @@
"\n",
" # Fetch buffers using SharedBuffers\n",
" mask, cos, sin = SharedBuffers.get_buffers(context_length, self.head_dim, rope_base, rope_config, dtype)\n",
" self.register_buffer(\"mask\", mask)\n",
" self.register_buffer(\"mask\", mask, persistent=False)\n",
"\n",
" self.register_buffer(\"cos\", cos)\n",
" self.register_buffer(\"sin\", sin)\n",
" self.register_buffer(\"cos\", cos, persistent=False)\n",
" self.register_buffer(\"sin\", sin, persistent=False)\n",
"\n",
" def forward(self, x):\n",
" b, num_tokens, d_in = x.shape\n",