Gated DeltaNet updates (#926)

This commit is contained in:
Sebastian Raschka
2025-12-18 20:28:53 -06:00
committed by GitHub
parent d7f178d28b
commit 57430d2a13

View File

@@ -12,7 +12,7 @@ Both Qwen3-Next and Kimi Linear use a 3:1 ratio, meaning for every three transfo
## Introduction and Overview
Gated DeltaNet is a linear attention variant with inspiration from recurrent neural networks, including a gating mechanism from the [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464) paper. In a sense, Gated DeltaNet is a DeltaNet with Mamba-style gating, and DeltaNet is a linear attention mechanism.
Gated DeltaNet is a linear attention variant with inspiration from recurrent neural networks, including a gating mechanism from the [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464) paper. In a sense, Gated DeltaNet is a DeltaNet with Mamba-style gating, and DeltaNet is a linear attention mechanism.
Kimi Linear modifies the linear attention mechanism of Qwen3-Next by the Kimi Delta Attention (KDA) mechanism, which is essentially a refinement of Gated DeltaNet. Whereas Qwen3-Next applies a scalar gate (one value per attention head) to control the memory decay rate, Kimi Linear replaces it with a channel-wise gating for each feature dimension. According to the authors, this gives more control over the memory, and this, in turn, improves long-context reasoning.
@@ -96,7 +96,7 @@ class GatedMultiHeadAttention(nn.Module):
context = context.reshape(b, num_tokens, self.d_out)
####################################################
### NEW: Add gate
### NEW: Add gate
context = context * torch.sigmoid(gate)
####################################################
out = self.out_proj(context)
@@ -115,11 +115,11 @@ As we can see, after computing attention as usual, the model uses a separate gat
Now, what is Gated DeltaNet? Gated DeltaNet (short for *Gated Delta Network*) is Qwen3-Next's linear-attention layer, which is intended as an alternative to standard softmax attention. It was adopted from the [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464) paper as mentioned earlier.
Gated DeltaNet was originally proposed as an improved version of Mamba2, where it combines the gated decay mechanism of Mamba2 with a delta rule.
Gated DeltaNet was originally proposed as an improved version of Mamba2, where it combines the gated decay mechanism of Mamba2 with a delta rule.
Mamba is a state-space model (an alternative to transformers), a big topic that deserves separate coverage in the future.
The delta rule part refers to computing the difference (delta, Δ) between new and predicted values to update a hidden state that is used as a memory state (more on that later).
The delta rule part refers to computing the difference (delta, Δ) between new and predicted values to update a hidden state that is used as a memory state (more on that later).
(Side note: Readers with classic machine learning literature can think of this as similar to Hebbian learning inspired by biology: "Cells that fire together wire together." It's basically a precursor of the perceptron update rule and gradient descent-based learning, but without supervision.)
@@ -132,8 +132,10 @@ However, as shown in the figure above, the "gated" in the Gated DeltaNet also re
- `α` (decay gate) controls how fast the memory decays or resets over time,
- `β` (update gate) controls how strongly new inputs modify the state.
In code, a simplified version of the Gated DeltaNet depicted above (without the convolutional mixing) can be implemented as follows (the code is inspired by the [official implementation](https://github.com/huggingface/transformers/blob/0ed6d51ae8ed3f4fafca67a983b8d75bc76cd51b/src/transformers/models/qwen3_next/modular_qwen3_next.py#L835) by the Qwen3 team).
(Note that some implementations refer to the decay gate as `gk` (gate for step k), where `exp(gk)` matches the paper's $\alpha_t$. To keep this relationship explicit, the snippet below separates the log-space gate `alpha_log` from the exponentiated decay `alpha`.)
In code, a simplified version of the Gated DeltaNet depicted above (without the convolutional mixing) can be implemented as follows (the code is inspired by the [official implementation](https://github.com/huggingface/transformers/blob/0ed6d51ae8ed3f4fafca67a983b8d75bc76cd51b/src/transformers/models/qwen3_next/modular_qwen3_next.py#L835) by the Qwen3 team):
```python
import torch
@@ -161,7 +163,7 @@ class GatedDeltaNet(nn.Module):
### NEW: Gates for delta rule and output gating
self.W_gate = nn.Linear(d_in, d_out, bias=False)
self.W_beta = nn.Linear(d_in, d_out, bias=False)
# Note: The decay gate alpha corresponds to
# A_log + W_alpha(x) + dt_bias
self.W_alpha = nn.Linear(d_in, num_heads, bias=False)
@@ -172,7 +174,7 @@ class GatedDeltaNet(nn.Module):
# W_alpha = nn.Linear(d_in, num_heads, bias=True)
# but the bias is separate for interpretability and
# to mimic the official implementation
self.norm = nn.RMSNorm(self.head_dim, eps=1e-6)
####################################################
@@ -187,9 +189,10 @@ class GatedDeltaNet(nn.Module):
####################################################
### NEW: Compute delta rule gates
beta = torch.sigmoid(self.W_beta(x))
alpha = -self.A_log.exp().view(1, 1, -1) * F.softplus(
alpha_log = -self.A_log.exp().view(1, 1, -1) * F.softplus(
self.W_alpha(x) + self.dt_bias
)
alpha = alpha_log.exp()
gate = self.W_gate(x)
####################################################
@@ -223,7 +226,7 @@ class GatedDeltaNet(nn.Module):
b_t = beta[:, :, t]
a_t = alpha[:, t].unsqueeze(-1).unsqueeze(-1)
S = S * a_t.exp()
S = S * a_t
kv_mem = (S * k_t.unsqueeze(-1)).sum(dim=-2)
delta = (v_t - kv_mem) * b_t
S = S + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
@@ -287,7 +290,7 @@ for t in range(num_tokens):
b_t = beta[:, :, t]
a_t = alpha[:, t].unsqueeze(-1).unsqueeze(-1)
S = S * a_t.exp()
S = S * a_t
kv_mem = (S * k_t.unsqueeze(-1)).sum(dim=-2)
delta = (v_t - kv_mem) * b_t
S = S + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
@@ -298,7 +301,7 @@ And the gates control how that memory changes:
- α (`alpha`) regulates how much of the old memory to forget (decay).
- β (`alpha`) regulates how much the current token at time step *t* updates the memory.
- β (`beta`) regulates how much the current token at time step *t* updates the memory.
(And the final output gate, not shown in the snippet above, is similar to gated attention; it controls how much of the output is kept.)