Mirror of https://github.com/rasbt/LLMs-from-scratch.git, synced 2026-04-10 12:33:42 +00:00
Gated DeltaNet updates (#926)
Commit 57430d2a13 (parent d7f178d28b), committed by GitHub.
@@ -132,8 +132,10 @@ However, as shown in the figure above, the "gated" in the Gated DeltaNet also re

````diff
 - `α` (decay gate) controls how fast the memory decays or resets over time,

 - `β` (update gate) controls how strongly new inputs modify the state.

-In code, a simplified version of the Gated DeltaNet depicted above (without the convolutional mixing) can be implemented as follows (the code is inspired by the [official implementation](https://github.com/huggingface/transformers/blob/0ed6d51ae8ed3f4fafca67a983b8d75bc76cd51b/src/transformers/models/qwen3_next/modular_qwen3_next.py#L835) by the Qwen3 team).
+(Note that some implementations refer to the decay gate as `gk` (gate for step k), where `exp(gk)` matches the paper's $\alpha_t$. To keep this relationship explicit, the snippet below separates the log-space gate `alpha_log` from the exponentiated decay `alpha`.)
+
+In code, a simplified version of the Gated DeltaNet depicted above (without the convolutional mixing) can be implemented as follows (the code is inspired by the [official implementation](https://github.com/huggingface/transformers/blob/0ed6d51ae8ed3f4fafca67a983b8d75bc76cd51b/src/transformers/models/qwen3_next/modular_qwen3_next.py#L835) by the Qwen3 team):

 ```python
 import torch
````
@@ -187,9 +189,10 @@ class GatedDeltaNet(nn.Module):

```diff
         ####################################################
         ### NEW: Compute delta rule gates
         beta = torch.sigmoid(self.W_beta(x))
-        alpha = -self.A_log.exp().view(1, 1, -1) * F.softplus(
+        alpha_log = -self.A_log.exp().view(1, 1, -1) * F.softplus(
             self.W_alpha(x) + self.dt_bias
         )
+        alpha = alpha_log.exp()
         gate = self.W_gate(x)
         ####################################################
```
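As a standalone sketch of what this gate computation produces, the following snippet uses illustrative shapes and randomly initialized stand-ins for the module's learned parameters (`A_log`, the output of `self.W_alpha(x)`, and `dt_bias` here are assumptions, not the chapter's actual model) to show that the exponentiated decay `alpha` always lands in the interval (0, 1]:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, num_heads = 2, 4, 3  # illustrative shapes, not the real config

# Stand-ins for the module's learned parameters
A_log = torch.zeros(num_heads)                        # so A = exp(A_log) = 1
w_alpha_out = torch.randn(batch, seq_len, num_heads)  # pretend result of self.W_alpha(x)
dt_bias = torch.zeros(num_heads)

# Log-space gate: always <= 0, since A > 0 and softplus(...) >= 0
alpha_log = -A_log.exp().view(1, 1, -1) * F.softplus(w_alpha_out + dt_bias)
alpha = alpha_log.exp()  # decay factor, guaranteed to lie in (0, 1]

print(alpha.min().item(), alpha.max().item())
```

Because `softplus` is non-negative and `A` is positive, `alpha_log` is non-positive, so exponentiating it can never produce a decay factor above 1.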
@@ -223,7 +226,7 @@ class GatedDeltaNet(nn.Module):

```diff
             b_t = beta[:, :, t]
             a_t = alpha[:, t].unsqueeze(-1).unsqueeze(-1)

-            S = S * a_t.exp()
+            S = S * a_t
             kv_mem = (S * k_t.unsqueeze(-1)).sum(dim=-2)
             delta = (v_t - kv_mem) * b_t
             S = S + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
```
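This fix matters because `alpha` is already the exponentiated decay; scaling the state by `a_t.exp()` would exponentiate it a second time. A minimal single-head sketch of one recurrence step (the shapes and random values are illustrative assumptions, not the chapter's model):

```python
import torch

torch.manual_seed(1)
d_k, d_v = 4, 4
S = torch.zeros(d_k, d_v)  # associative memory state, starts empty

k_t = torch.randn(d_k)
k_t = k_t / k_t.norm()                 # keys are typically unit-normalized
v_t = torch.randn(d_v)
b_t = torch.sigmoid(torch.randn(()))   # update gate, in (0, 1)
a_t = torch.rand(()) * 0.5 + 0.5       # decay factor, already in (0, 1)

S = S * a_t                                    # decay old memory (no extra .exp())
kv_mem = (S * k_t.unsqueeze(-1)).sum(dim=-2)   # what the memory currently predicts for k_t
delta = (v_t - kv_mem) * b_t                   # gated delta-rule error
S = S + k_t.unsqueeze(-1) * delta.unsqueeze(-2)

# Since S started at zero, reading out with k_t now returns b_t * v_t
readout = (S * k_t.unsqueeze(-1)).sum(dim=-2)
```

With an empty initial state, `kv_mem` is zero, so one update writes `b_t * v_t` into the slot addressed by `k_t`; the gate `b_t` directly scales how much of the new value is stored.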
@@ -287,7 +290,7 @@ for t in range(num_tokens):

```diff
     b_t = beta[:, :, t]
     a_t = alpha[:, t].unsqueeze(-1).unsqueeze(-1)

-    S = S * a_t.exp()
+    S = S * a_t
     kv_mem = (S * k_t.unsqueeze(-1)).sum(dim=-2)
     delta = (v_t - kv_mem) * b_t
     S = S + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
```
@@ -298,7 +301,7 @@ And the gates control how that memory changes:

```diff

 - α (`alpha`) regulates how much of the old memory to forget (decay).

-- β (`alpha`) regulates how much the current token at time step *t* updates the memory.
+- β (`beta`) regulates how much the current token at time step *t* updates the memory.

 (And the final output gate, not shown in the snippet above, is similar to gated attention; it controls how much of the output is kept.)
```
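As a toy illustration of these two roles (a scalar "memory" with the key fixed at 1; the values are made up for demonstration):

```python
import torch

S = torch.tensor(1.0)  # pretend scalar memory
v = torch.tensor(5.0)  # new value the current token wants to write

def step(S, alpha, beta):
    S = S * alpha               # α: forget part of the old memory
    return S + beta * (v - S)   # β: move part of the way toward the new value

# α = 0 wipes the memory, and β = 1 writes the new value fully
assert abs(step(S, alpha=0.0, beta=1.0) - v) < 1e-6
# α = 1, β = 0 leaves the memory untouched
assert abs(step(S, alpha=1.0, beta=0.0) - S) < 1e-6
```

Between these extremes, intermediate gate values blend forgetting and writing, which is exactly what the per-token `a_t` and `b_t` do in the snippets above.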