Gated DeltaNet updates (#926)

Sebastian Raschka
2025-12-18 20:28:53 -06:00
committed by GitHub
parent d7f178d28b
commit 57430d2a13


@@ -132,8 +132,10 @@ However, as shown in the figure above, the "gated" in the Gated DeltaNet also re
- `α` (decay gate) controls how fast the memory decays or resets over time,
- `β` (update gate) controls how strongly new inputs modify the state.
-In code, a simplified version of the Gated DeltaNet depicted above (without the convolutional mixing) can be implemented as follows (the code is inspired by the [official implementation](https://github.com/huggingface/transformers/blob/0ed6d51ae8ed3f4fafca67a983b8d75bc76cd51b/src/transformers/models/qwen3_next/modular_qwen3_next.py#L835) by the Qwen3 team).
+(Note that some implementations refer to the decay gate as `gk` (gate for step k), where `exp(gk)` matches the paper's $\alpha_t$. To keep this relationship explicit, the snippet below separates the log-space gate `alpha_log` from the exponentiated decay `alpha`.)
+In code, a simplified version of the Gated DeltaNet depicted above (without the convolutional mixing) can be implemented as follows (the code is inspired by the [official implementation](https://github.com/huggingface/transformers/blob/0ed6d51ae8ed3f4fafca67a983b8d75bc76cd51b/src/transformers/models/qwen3_next/modular_qwen3_next.py#L835) by the Qwen3 team):
```python
import torch
@@ -187,9 +189,10 @@ class GatedDeltaNet(nn.Module):
        ####################################################
        ### NEW: Compute delta rule gates
        beta = torch.sigmoid(self.W_beta(x))
-        alpha = -self.A_log.exp().view(1, 1, -1) * F.softplus(
+        alpha_log = -self.A_log.exp().view(1, 1, -1) * F.softplus(
            self.W_alpha(x) + self.dt_bias
        )
+        alpha = alpha_log.exp()
        gate = self.W_gate(x)
        ####################################################
@@ -223,7 +226,7 @@ class GatedDeltaNet(nn.Module):
            b_t = beta[:, :, t]
            a_t = alpha[:, t].unsqueeze(-1).unsqueeze(-1)
-            S = S * a_t.exp()
+            S = S * a_t
            kv_mem = (S * k_t.unsqueeze(-1)).sum(dim=-2)
            delta = (v_t - kv_mem) * b_t
            S = S + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
@@ -287,7 +290,7 @@ for t in range(num_tokens):
    b_t = beta[:, :, t]
    a_t = alpha[:, t].unsqueeze(-1).unsqueeze(-1)
-    S = S * a_t.exp()
+    S = S * a_t
    kv_mem = (S * k_t.unsqueeze(-1)).sum(dim=-2)
    delta = (v_t - kv_mem) * b_t
    S = S + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
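Taken out of the diff context, the decayed delta-rule step above can be run as a standalone sketch. The tensor shapes, the unit-norm key, and the gate values are illustrative assumptions, not taken from the post; with `beta = 1` and a unit-norm key, one step writes `v_t` into the memory exactly:

```python
import torch

torch.manual_seed(0)
B, H, D = 1, 2, 4                              # assumed batch, heads, head dim
S = torch.randn(B, H, D, D)                    # memory state (one matrix per head)
k_t = torch.randn(B, H, D)
k_t = k_t / k_t.norm(dim=-1, keepdim=True)     # unit-norm key (assumption)
v_t = torch.randn(B, H, D)
b_t = torch.ones(B, H, 1)                      # beta = 1: full write strength
a_t = torch.ones(B, H, 1, 1)                   # alpha = 1: no decay this step

S = S * a_t                                        # decay old memory
kv_mem = (S * k_t.unsqueeze(-1)).sum(dim=-2)       # read memory along k_t
delta = (v_t - kv_mem) * b_t                       # gated prediction error
S = S + k_t.unsqueeze(-1) * delta.unsqueeze(-2)    # write the correction back

# With a unit-norm key and beta = 1, re-reading with k_t now returns v_t.
readout = (S * k_t.unsqueeze(-1)).sum(dim=-2)
print(torch.allclose(readout, v_t, atol=1e-5))
```

Smaller `b_t` would write only part of the error, and `a_t < 1` would shrink the old state before the write, which is exactly what the two gates are for.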
@@ -298,7 +301,7 @@ And the gates control how that memory changes:
- α (`alpha`) regulates how much of the old memory to forget (decay).
-- β (`alpha`) regulates how much the current token at time step *t* updates the memory.
+- β (`beta`) regulates how much the current token at time step *t* updates the memory.
(And the final output gate, not shown in the snippet above, is similar to gated attention; it controls how much of the output is kept.)
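To see why the commit splits the decay into `alpha_log` and `alpha`, here is a minimal self-contained sketch of the gate computation (the layer sizes and the zero-initialized `A_log`/`dt_bias` are assumptions for illustration): since `A_log.exp()` and `softplus(...)` are both positive, `alpha_log` is always negative, so `alpha = alpha_log.exp()` lands in (0, 1) and can only shrink the state.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden, heads = 8, 2                       # assumed sizes
x = torch.randn(1, 3, hidden)              # (batch, tokens, hidden)
W_alpha = torch.nn.Linear(hidden, heads)
W_beta = torch.nn.Linear(hidden, heads)
A_log = torch.zeros(heads)                 # log of per-head decay rate (assumed init)
dt_bias = torch.zeros(heads)

beta = torch.sigmoid(W_beta(x))            # update gate, in (0, 1)

# Log-space decay: a negative rate times a positive softplus is always < 0 ...
alpha_log = -A_log.exp().view(1, 1, -1) * F.softplus(W_alpha(x) + dt_bias)
# ... so the exponentiated decay is a valid forgetting factor in (0, 1).
alpha = alpha_log.exp()
```

Keeping `alpha_log` around also makes the correspondence to the paper's notation explicit: `alpha_log` plays the role of `gk`, and `alpha` is the paper's $\alpha_t$.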