Mirror of https://github.com/rasbt/LLMs-from-scratch.git
Synced 2026-04-10 12:33:42 +00:00

Gated DeltaNet updates (#926)

Committed by GitHub
Parent: d7f178d28b
Commit: 57430d2a13
@@ -12,7 +12,7 @@ Both Qwen3-Next and Kimi Linear use a 3:1 ratio, meaning for every three transfo

 ## Introduction and Overview

-Gated DeltaNet is a linear attention variant with inspiration from recurrent neural networks, including a gating mechanism from the [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464) paper. In a sense, Gated DeltaNet is a DeltaNet with Mamba-style gating, and DeltaNet is a linear attention mechanism.
+Gated DeltaNet is a linear attention variant with inspiration from recurrent neural networks, including a gating mechanism from the [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464) paper. In a sense, Gated DeltaNet is a DeltaNet with Mamba-style gating, and DeltaNet is a linear attention mechanism.

 Kimi Linear modifies the linear attention mechanism of Qwen3-Next by the Kimi Delta Attention (KDA) mechanism, which is essentially a refinement of Gated DeltaNet. Whereas Qwen3-Next applies a scalar gate (one value per attention head) to control the memory decay rate, Kimi Linear replaces it with a channel-wise gating for each feature dimension. According to the authors, this gives more control over the memory, and this, in turn, improves long-context reasoning.

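The scalar versus channel-wise distinction described in this hunk can be illustrated with a minimal plain-Python sketch for a single head. The function names, the toy 2×2 memory, and the choice to apply the channel-wise gate per row (key-feature dimension) of the memory are illustrative assumptions, not code taken from Qwen3-Next or Kimi Linear.

```python
def decay_scalar(S, alpha):
    """Scalar gate: one decay value shared by the whole head."""
    return [[alpha * s for s in row] for row in S]


def decay_channelwise(S, alphas):
    """Channel-wise gate: one decay value per feature dimension.

    Assumption: the gate scales each row (key-feature dimension) of S.
    """
    return [[a * s for s in row] for a, row in zip(alphas, S)]


# Toy memory state for one head, shape (d_k, d_v)
S = [[1.0, 2.0],
     [3.0, 4.0]]

S_scalar = decay_scalar(S, 0.5)
S_channel = decay_channelwise(S, [1.0, 0.5])

print(S_scalar)   # every entry shrinks by the same factor
print(S_channel)  # row 0 is kept as-is, row 1 decays
```

With the scalar gate the head can only forget everything at the same rate, whereas the channel-wise variant can keep some feature dimensions while letting others fade.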
@@ -96,7 +96,7 @@ class GatedMultiHeadAttention(nn.Module):
         context = context.reshape(b, num_tokens, self.d_out)

         ####################################################
-        ### NEW: Add gate
+        ### NEW: Add gate
         context = context * torch.sigmoid(gate)
         ####################################################
         out = self.out_proj(context)
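As a small illustration of the `context * torch.sigmoid(gate)` line in this hunk, the following plain-Python sketch applies the same elementwise gating to a toy vector. The helper names and the gate pre-activation values are hypothetical and chosen only to show the three regimes (closed, half open, open).

```python
import math


def sigmoid(z):
    """Logistic function, maps any real number into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))


def apply_output_gate(context, gate):
    """Elementwise context * sigmoid(gate)."""
    return [c * sigmoid(g) for c, g in zip(context, gate)]


context = [2.0, 2.0, 2.0]
gate = [-10.0, 0.0, 10.0]  # hypothetical gate pre-activations

gated = apply_output_gate(context, gate)
print([round(g, 3) for g in gated])
```

A strongly negative gate value suppresses the feature almost entirely, zero keeps half of it, and a strongly positive value passes it through nearly unchanged.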
@@ -115,11 +115,11 @@ As we can see, after computing attention as usual, the model uses a separate gat

 Now, what is Gated DeltaNet? Gated DeltaNet (short for *Gated Delta Network*) is Qwen3-Next's linear-attention layer, which is intended as an alternative to standard softmax attention. It was adopted from the [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464) paper as mentioned earlier.

-Gated DeltaNet was originally proposed as an improved version of Mamba2, where it combines the gated decay mechanism of Mamba2 with a delta rule.
+Gated DeltaNet was originally proposed as an improved version of Mamba2, where it combines the gated decay mechanism of Mamba2 with a delta rule.

 Mamba is a state-space model (an alternative to transformers), a big topic that deserves separate coverage in the future.

-The delta rule part refers to computing the difference (delta, Δ) between new and predicted values to update a hidden state that is used as a memory state (more on that later).
+The delta rule part refers to computing the difference (delta, Δ) between new and predicted values to update a hidden state that is used as a memory state (more on that later).

 (Side note: Readers with classic machine learning literature can think of this as similar to Hebbian learning inspired by biology: "Cells that fire together wire together." It's basically a precursor of the perceptron update rule and gradient descent-based learning, but without supervision.)

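To make the delta rule in this hunk more tangible, here is a minimal plain-Python sketch that contrasts a purely additive (Hebbian-style) write with a delta rule write on a toy memory matrix. The function names and values are illustrative assumptions, and the key is chosen with unit norm so that the delta rule overwrites the old value exactly.

```python
def read(S, k):
    """Recall the value stored under key k (computes k^T S)."""
    return [sum(k[i] * S[i][j] for i in range(len(S))) for j in range(len(S[0]))]


def hebbian_step(S, k, v):
    """Purely additive write: S + k v^T."""
    return [[S[i][j] + k[i] * v[j] for j in range(len(v))] for i in range(len(k))]


def delta_step(S, k, v):
    """Delta rule write: correct S by the error between v and the prediction."""
    pred = read(S, k)
    delta = [v[j] - pred[j] for j in range(len(v))]
    return [[S[i][j] + k[i] * delta[j] for j in range(len(v))] for i in range(len(k))]


k = [1.0, 0.0]                       # unit-norm key
v_old, v_new = [2.0, -1.0], [5.0, 3.0]
zeros = [[0.0, 0.0], [0.0, 0.0]]

# Write two different values under the same key
S_hebb = hebbian_step(hebbian_step(zeros, k, v_old), k, v_new)
S_delta = delta_step(delta_step(zeros, k, v_old), k, v_new)

print(read(S_hebb, k))   # the two values are summed
print(read(S_delta, k))  # the old value is replaced by the new one
```

The additive write accumulates both values under the key, while the delta rule first checks what the memory already predicts and only writes the difference.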
@@ -132,8 +132,10 @@ However, as shown in the figure above, the "gated" in the Gated DeltaNet also re
 - `α` (decay gate) controls how fast the memory decays or resets over time,
 - `β` (update gate) controls how strongly new inputs modify the state.

-In code, a simplified version of the Gated DeltaNet depicted above (without the convolutional mixing) can be implemented as follows (the code is inspired by the [official implementation](https://github.com/huggingface/transformers/blob/0ed6d51ae8ed3f4fafca67a983b8d75bc76cd51b/src/transformers/models/qwen3_next/modular_qwen3_next.py#L835) by the Qwen3 team).
-
+(Note that some implementations refer to the decay gate as `gk` (gate for step k), where `exp(gk)` matches the paper's $\alpha_t$. To keep this relationship explicit, the snippet below separates the log-space gate `alpha_log` from the exponentiated decay `alpha`.)
+
+In code, a simplified version of the Gated DeltaNet depicted above (without the convolutional mixing) can be implemented as follows (the code is inspired by the [official implementation](https://github.com/huggingface/transformers/blob/0ed6d51ae8ed3f4fafca67a983b8d75bc76cd51b/src/transformers/models/qwen3_next/modular_qwen3_next.py#L835) by the Qwen3 team):
+
 ```python
 import torch
@@ -161,7 +163,7 @@ class GatedDeltaNet(nn.Module):
         ### NEW: Gates for delta rule and output gating
         self.W_gate = nn.Linear(d_in, d_out, bias=False)
         self.W_beta = nn.Linear(d_in, d_out, bias=False)
-
+
         # Note: The decay gate alpha corresponds to
         # A_log + W_alpha(x) + dt_bias
         self.W_alpha = nn.Linear(d_in, num_heads, bias=False)
@@ -172,7 +174,7 @@ class GatedDeltaNet(nn.Module):
         # W_alpha = nn.Linear(d_in, num_heads, bias=True)
         # but the bias is separate for interpretability and
         # to mimic the official implementation
-
+
         self.norm = nn.RMSNorm(self.head_dim, eps=1e-6)
         ####################################################

@@ -187,9 +189,10 @@ class GatedDeltaNet(nn.Module):
         ####################################################
         ### NEW: Compute delta rule gates
         beta = torch.sigmoid(self.W_beta(x))
-        alpha = -self.A_log.exp().view(1, 1, -1) * F.softplus(
+        alpha_log = -self.A_log.exp().view(1, 1, -1) * F.softplus(
             self.W_alpha(x) + self.dt_bias
         )
+        alpha = alpha_log.exp()
         gate = self.W_gate(x)
         ####################################################

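The relationship between `alpha_log` and `alpha` in this hunk can be checked with a small scalar sketch in plain Python. It assumes a single head, uses plain floats in place of `self.A_log` and `W_alpha(x) + dt_bias`, and the helper names are hypothetical. It also compares the state update written as `S * a_t.exp()` (with the log-space gate) against `S * a_t` (with the exponentiated gate), which appear as the removed and added lines elsewhere in this commit.

```python
import math


def softplus(z):
    """softplus(z) = log(1 + exp(z)), always positive."""
    return math.log1p(math.exp(z))


def decay_gate(a_log, pre_activation):
    """Return (alpha_log, alpha) for one head, using plain floats."""
    alpha_log = -math.exp(a_log) * softplus(pre_activation)  # always negative
    alpha = math.exp(alpha_log)                               # decay in (0, 1)
    return alpha_log, alpha


s_old = 3.0  # one entry of the memory state S
alphas = []
for pre in (-4.0, 0.0, 4.0):
    alpha_log, alpha = decay_gate(a_log=0.0, pre_activation=pre)
    before = s_old * math.exp(alpha_log)  # log-space gate, exponentiated in the update
    after = s_old * alpha                 # gate exponentiated up front
    assert math.isclose(before, after)
    alphas.append(alpha)
    print(f"pre={pre:+.1f}  alpha_log={alpha_log:.4f}  alpha={alpha:.4f}")
```

Because `softplus` is positive and `exp(a_log)` is positive, `alpha_log` is always negative, so `alpha` stays strictly between 0 and 1 and can only shrink the memory. Both ways of writing the update give the same result; the renamed variables mainly make it explicit which quantity lives in log space.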
@@ -223,7 +226,7 @@ class GatedDeltaNet(nn.Module):
             b_t = beta[:, :, t]
             a_t = alpha[:, t].unsqueeze(-1).unsqueeze(-1)

-            S = S * a_t.exp()
+            S = S * a_t
             kv_mem = (S * k_t.unsqueeze(-1)).sum(dim=-2)
             delta = (v_t - kv_mem) * b_t
             S = S + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
@@ -287,7 +290,7 @@ for t in range(num_tokens):
     b_t = beta[:, :, t]
     a_t = alpha[:, t].unsqueeze(-1).unsqueeze(-1)

-    S = S * a_t.exp()
+    S = S * a_t
     kv_mem = (S * k_t.unsqueeze(-1)).sum(dim=-2)
     delta = (v_t - kv_mem) * b_t
     S = S + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
@@ -298,7 +301,7 @@ And the gates control how that memory changes:

 - α (`alpha`) regulates how much of the old memory to forget (decay).

-- β (`alpha`) regulates how much the current token at time step *t* updates the memory.
+- β (`beta`) regulates how much the current token at time step *t* updates the memory.

 (And the final output gate, not shown in the snippet above, is similar to gated attention; it controls how much of the output is kept.)

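The roles of α and β listed in this hunk can be traced through a minimal single-head sketch in plain Python that mirrors the four update lines of the loop body (decay, read, delta, write). The function names, the toy keys and values, and the chosen gate values are illustrative assumptions; batching, multiple heads, and normalization are left out.

```python
def gated_delta_step(S, k, v, alpha, beta):
    """One gated delta rule step for a single head.

    S: memory as a list of rows, shape (d_k, d_v); k: key; v: value;
    alpha: decay factor in (0, 1]; beta: update strength in [0, 1].
    """
    d_k, d_v = len(S), len(S[0])
    # 1) Decay the old memory (alpha)
    S = [[alpha * s for s in row] for row in S]
    # 2) Read what the memory currently predicts for this key
    kv_mem = [sum(S[i][j] * k[i] for i in range(d_k)) for j in range(d_v)]
    # 3) Scale the prediction error by the update gate (beta)
    delta = [(v[j] - kv_mem[j]) * beta for j in range(d_v)]
    # 4) Write the correction back as an outer product of k and delta
    return [[S[i][j] + k[i] * delta[j] for j in range(d_v)] for i in range(d_k)]


def read(S, q):
    """Query the memory: computes q^T S."""
    return [sum(S[i][j] * q[i] for i in range(len(S))) for j in range(len(S[0]))]


k1, v1 = [1.0, 0.0], [2.0, -1.0]
k2, v2 = [0.0, 1.0], [4.0, 4.0]

S = [[0.0, 0.0], [0.0, 0.0]]
S = gated_delta_step(S, k1, v1, alpha=1.0, beta=1.0)  # store v1 under k1
S = gated_delta_step(S, k2, v2, alpha=0.5, beta=1.0)  # decay, then store v2 under k2

print(read(S, k1))  # v1 has faded to half its size
print(read(S, k2))  # v2 was written at full strength
```

In this toy run, the second step uses `alpha=0.5`, so the entry written in the first step is recalled at half strength, while `beta=1.0` lets the new token be stored completely.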