mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Disable mask saving as weight in Llama 3 model (#604)
* Disable mask saving as weight * update pixi * update pixi
This commit is contained in:
committed by
GitHub
parent
f1434652f2
commit
67e0680210
1
.github/workflows/basic-tests-pixi.yml
vendored
1
.github/workflows/basic-tests-pixi.yml
vendored
@@ -42,6 +42,7 @@ jobs:
|
||||
- name: List installed packages
|
||||
run: |
|
||||
pixi list --environment tests
|
||||
pixi run --environment tests pip install "huggingface-hub>=0.30.0,<1.0"
|
||||
|
||||
- name: Test Selected Python Scripts
|
||||
shell: pixi run --environment tests bash -e {0}
|
||||
|
||||
@@ -368,7 +368,10 @@
|
||||
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
||||
"\n",
|
||||
" # Reusuable utilities\n",
|
||||
" self.register_buffer(\"mask\", torch.triu(torch.ones(cfg[\"context_length\"], cfg[\"context_length\"]), diagonal=1).bool())\n",
|
||||
" self.register_buffer(\n",
|
||||
" \"mask\", torch.triu(torch.ones(cfg[\"context_length\"], cfg[\"context_length\"]), diagonal=1).bool(),\n",
|
||||
" persistent=False\n",
|
||||
" )\n",
|
||||
" cfg[\"rope_base\"] = rescale_theta(\n",
|
||||
" cfg[\"rope_base\"],\n",
|
||||
" cfg[\"orig_context_length\"],\n",
|
||||
|
||||
@@ -266,10 +266,10 @@
|
||||
"\n",
|
||||
" # Fetch buffers using SharedBuffers\n",
|
||||
" mask, cos, sin = SharedBuffers.get_buffers(context_length, self.head_dim, rope_base, rope_config, dtype)\n",
|
||||
" self.register_buffer(\"mask\", mask)\n",
|
||||
" self.register_buffer(\"mask\", mask, persistent=False)\n",
|
||||
"\n",
|
||||
" self.register_buffer(\"cos\", cos)\n",
|
||||
" self.register_buffer(\"sin\", sin)\n",
|
||||
" self.register_buffer(\"cos\", cos, persistent=False)\n",
|
||||
" self.register_buffer(\"sin\", sin, persistent=False)\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" b, num_tokens, d_in = x.shape\n",
|
||||
|
||||
Reference in New Issue
Block a user