diff --git a/ch04/08_deltanet/README.md b/ch04/08_deltanet/README.md
index 257e15f..9741fb7 100644
--- a/ch04/08_deltanet/README.md
+++ b/ch04/08_deltanet/README.md
@@ -12,7 +12,7 @@ Both Qwen3-Next and Kimi Linear use a 3:1 ratio, meaning for every three transfo
 
 ## Introduction and Overview
 
-Gated DeltaNet is a linear attention variant with inspiration from recurrent neural networks, including a gating mechanism from the [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464) paper. In a sense, Gated DeltaNet is a DeltaNet with Mamba-style gating, and DeltaNet is a linear attention mechanism.
+Gated DeltaNet is a linear attention variant with inspiration from recurrent neural networks, including a gating mechanism from the [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464) paper. In a sense, Gated DeltaNet is a DeltaNet with Mamba-style gating, and DeltaNet is a linear attention mechanism. Kimi Linear replaces the linear attention mechanism of Qwen3-Next with the Kimi Delta Attention (KDA) mechanism, which is essentially a refinement of Gated DeltaNet. Whereas Qwen3-Next applies a scalar gate (one value per attention head) to control the memory decay rate, Kimi Linear uses channel-wise gating with one gate value per feature dimension. According to the authors, this gives finer control over the memory, which in turn improves long-context reasoning.
 
@@ -96,7 +96,7 @@ class GatedMultiHeadAttention(nn.Module):
         context = context.reshape(b, num_tokens, self.d_out)
 
         ####################################################
-        ### NEW: Add gate 
+        ### NEW: Add gate
         context = context * torch.sigmoid(gate)
         ####################################################
         out = self.out_proj(context)
@@ -115,11 +115,11 @@ As we can see, after computing attention as usual, the model uses a separate gat
 
 Now, what is Gated DeltaNet? Gated DeltaNet (short for *Gated Delta Network*) is Qwen3-Next's linear-attention layer, which is intended as an alternative to standard softmax attention. It was adopted from the [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464) paper as mentioned earlier.
 
-Gated DeltaNet was originally proposed as an improved version of Mamba2, where it combines the gated decay mechanism of Mamba2 with a delta rule. 
+Gated DeltaNet was originally proposed as an improved version of Mamba2, where it combines the gated decay mechanism of Mamba2 with a delta rule.
 
 Mamba is a state-space model (an alternative to transformers), a big topic that deserves separate coverage in the future.
 
-The delta rule part refers to computing the difference (delta, Δ) between new and predicted values to update a hidden state that is used as a memory state (more on that later). 
+The delta rule part refers to computing the difference (delta, Δ) between new and predicted values to update a hidden state that is used as a memory state (more on that later).
 
 (Side note: Readers familiar with the classic machine learning literature can think of this as similar to Hebbian learning, which is inspired by biology: "Cells that fire together wire together." It's basically a precursor of the perceptron update rule and gradient descent-based learning, but without supervision.)
 
@@ -132,8 +132,10 @@ However, as shown in the figure above, the "gated" in the Gated DeltaNet also re
 
 - `α` (decay gate) controls how fast the memory decays or resets over time,
 - `β` (update gate) controls how strongly new inputs modify the state.
 
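+To build intuition for how these two gates interact, the following is a minimal, illustrative sketch of a single delta-rule update step (the scalar `alpha` and `beta` values are arbitrary toy numbers here; in the actual layer shown below, they are computed from the input):
+
+```python
+import torch
+
+torch.manual_seed(123)
+
+d_k, d_v = 4, 4
+S = torch.zeros(d_k, d_v)      # memory state (a single head)
+k = torch.randn(d_k)           # current key
+v = torch.randn(d_v)           # current value
+alpha, beta = 0.9, 0.5         # toy gate values in (0, 1)
+
+S = S * alpha                  # decay gate: forget part of the old memory
+v_pred = S.T @ k               # what the memory currently predicts for this key
+delta = (v - v_pred) * beta    # update gate: scale the prediction error
+S = S + torch.outer(k, delta)  # write the scaled correction back into memory
+```
+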
-In code, a simplified version of the Gated DeltaNet depicted above (without the convolutional mixing) can be implemented as follows (the code is inspired by the [official implementation](https://github.com/huggingface/transformers/blob/0ed6d51ae8ed3f4fafca67a983b8d75bc76cd51b/src/transformers/models/qwen3_next/modular_qwen3_next.py#L835) by the Qwen3 team):
+In code, a simplified version of the Gated DeltaNet depicted above (without the convolutional mixing) can be implemented as follows (the code is inspired by the [official implementation](https://github.com/huggingface/transformers/blob/0ed6d51ae8ed3f4fafca67a983b8d75bc76cd51b/src/transformers/models/qwen3_next/modular_qwen3_next.py#L835) by the Qwen3 team).
+
+(Note that some implementations refer to the decay gate as `gk` (gate for step k), where `exp(gk)` matches the paper's $\alpha_t$. To keep this relationship explicit, the snippet below separates the log-space gate `alpha_log` from the exponentiated decay `alpha`.)
 
 ```python
 import torch
@@ -161,7 +163,7 @@ class GatedDeltaNet(nn.Module):
         ### NEW: Gates for delta rule and output gating
         self.W_gate = nn.Linear(d_in, d_out, bias=False)
         self.W_beta = nn.Linear(d_in, d_out, bias=False)
-        
+
         # Note: The decay gate alpha corresponds to
         #       A_log + W_alpha(x) + dt_bias
         self.W_alpha = nn.Linear(d_in, num_heads, bias=False)
@@ -172,7 +174,7 @@ class GatedDeltaNet(nn.Module):
         # W_alpha = nn.Linear(d_in, num_heads, bias=True)
         # but the bias is separate for interpretability and
         # to mimic the official implementation
-        
+
         self.norm = nn.RMSNorm(self.head_dim, eps=1e-6)
         ####################################################
 
@@ -187,9 +189,10 @@ class GatedDeltaNet(nn.Module):
         ####################################################
         ### NEW: Compute delta rule gates
         beta = torch.sigmoid(self.W_beta(x))
-        alpha = -self.A_log.exp().view(1, 1, -1) * F.softplus(
+        alpha_log = -self.A_log.exp().view(1, 1, -1) * F.softplus(
             self.W_alpha(x) + self.dt_bias
         )
+        alpha = alpha_log.exp()
         gate = self.W_gate(x)
         ####################################################
 
@@ -223,7 +226,7 @@ class GatedDeltaNet(nn.Module):
             b_t = beta[:, :, t]
             a_t = alpha[:, t].unsqueeze(-1).unsqueeze(-1)
 
-            S = S * a_t.exp()
+            S = S * a_t
             kv_mem = (S * k_t.unsqueeze(-1)).sum(dim=-2)
             delta = (v_t - kv_mem) * b_t
             S = S + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
@@ -287,7 +290,7 @@ for t in range(num_tokens):
     b_t = beta[:, :, t]
     a_t = alpha[:, t].unsqueeze(-1).unsqueeze(-1)
 
-    S = S * a_t.exp()
+    S = S * a_t
     kv_mem = (S * k_t.unsqueeze(-1)).sum(dim=-2)
     delta = (v_t - kv_mem) * b_t
     S = S + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
@@ -298,7 +301,7 @@ And the gates control how that memory changes:
 
 - α (`alpha`) regulates how much of the old memory to forget (decay).
 
-- β (`alpha`) regulates how much the current token at time step *t* updates the memory.
+- β (`beta`) regulates how much the current token at time step *t* updates the memory.
 
 (And the final output gate, not shown in the snippet above, is similar to gated attention; it controls how much of the output is kept.)
 
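+To make this output gate concrete, here is a minimal, self-contained sketch of sigmoid gating on random tensors (the shapes are arbitrary, and the random `gate` tensor merely stands in for the learned `W_gate(x)` projection used in the module above):
+
+```python
+import torch
+
+torch.manual_seed(123)
+context = torch.randn(2, 5, 8)           # (batch, num_tokens, d_out)
+gate = torch.randn(2, 5, 8)              # stand-in for the learned W_gate(x)
+out = context * torch.sigmoid(gate)      # sigmoid maps the gate into (0, 1),
+                                         # so each value is only partially kept
+```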