# mirror of https://github.com/rasbt/LLMs-from-scratch.git
# synced 2026-04-10 12:33:42 +00:00
# 426 lines | 16 KiB | Python
"""Qwen3.5 helper blocks copied from Hugging Face Transformers.

Source file:
transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py

License: Apache License Version 2.0
License URL: https://github.com/huggingface/transformers/blob/main/LICENSE
"""
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
# Notebook shims for optional fast kernels in transformers
|
|
causal_conv1d_fn = None
|
|
causal_conv1d_update = None
|
|
chunk_gated_delta_rule = None
|
|
fused_recurrent_gated_delta_rule = None
|
|
FusedRMSNormGated = None
|
|
ACT2FN = {"silu": F.silu}
|
|
is_fast_path_available = False
|
|
|
|
|
|
class _NotebookLogger:
|
|
def __init__(self):
|
|
self._seen = set()
|
|
|
|
def warning_once(self, msg):
|
|
if msg in self._seen:
|
|
return
|
|
self._seen.add(msg)
|
|
print(msg)
|
|
|
|
|
|
logger = _NotebookLogger()
|
|
|
|
|
|
# Placeholder types for copied annotations
|
|
class Qwen3_5Config:
|
|
pass
|
|
|
|
|
|
class Qwen3_5DynamicCache:
|
|
pass
|
|
|
|
|
|
# Copied verbatim from:
|
|
# transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py
|
|
class Qwen3_5RMSNormGated(nn.Module):
|
|
def __init__(self, hidden_size, eps=1e-6, **kwargs):
|
|
super().__init__()
|
|
self.weight = nn.Parameter(torch.ones(hidden_size))
|
|
self.variance_epsilon = eps
|
|
|
|
def forward(self, hidden_states, gate=None):
|
|
input_dtype = hidden_states.dtype
|
|
hidden_states = hidden_states.to(torch.float32)
|
|
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
|
# Norm before gate
|
|
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
|
hidden_states = self.weight * hidden_states.to(input_dtype)
|
|
hidden_states = hidden_states * F.silu(gate.to(torch.float32))
|
|
|
|
return hidden_states.to(input_dtype)
|
|
|
|
|
|
# Copied verbatim from:
|
|
# transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py
|
|
def apply_mask_to_padding_states(hidden_states, attention_mask):
|
|
"""
|
|
Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
|
|
"""
|
|
# NOTE: attention mask is a 2D boolean tensor
|
|
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
|
|
dtype = hidden_states.dtype
|
|
hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
|
|
|
|
return hidden_states
|
|
|
|
|
|
# Copied verbatim from:
|
|
# transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py
|
|
def torch_causal_conv1d_update(
|
|
hidden_states,
|
|
conv_state,
|
|
weight,
|
|
bias=None,
|
|
activation=None,
|
|
):
|
|
_, hidden_size, seq_len = hidden_states.shape
|
|
state_len = conv_state.shape[-1]
|
|
|
|
hidden_states_new = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype)
|
|
conv_state.copy_(hidden_states_new[:, :, -state_len:])
|
|
out = F.conv1d(hidden_states_new, weight.unsqueeze(1), bias, padding=0, groups=hidden_size)
|
|
out = F.silu(out[:, :, -seq_len:])
|
|
out = out.to(hidden_states.dtype)
|
|
return out
|
|
|
|
|
|
# Copied verbatim from:
|
|
# transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py
|
|
def l2norm(x, dim=-1, eps=1e-6):
|
|
"""This function is intended to align with the l2norm implementation in the FLA library."""
|
|
inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
|
|
return x * inv_norm
|
|
|
|
|
|
# Copied verbatim from:
|
|
# transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py
|
|
def torch_chunk_gated_delta_rule(
|
|
query,
|
|
key,
|
|
value,
|
|
g,
|
|
beta,
|
|
chunk_size=64,
|
|
initial_state=None,
|
|
output_final_state=False,
|
|
use_qk_l2norm_in_kernel=False,
|
|
):
|
|
initial_dtype = query.dtype
|
|
if use_qk_l2norm_in_kernel:
|
|
query = l2norm(query, dim=-1, eps=1e-6)
|
|
key = l2norm(key, dim=-1, eps=1e-6)
|
|
query, key, value, beta, g = [
|
|
x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
|
|
]
|
|
|
|
batch_size, num_heads, sequence_length, k_head_dim = key.shape
|
|
v_head_dim = value.shape[-1]
|
|
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
|
|
query = F.pad(query, (0, 0, 0, pad_size))
|
|
key = F.pad(key, (0, 0, 0, pad_size))
|
|
value = F.pad(value, (0, 0, 0, pad_size))
|
|
beta = F.pad(beta, (0, pad_size))
|
|
g = F.pad(g, (0, pad_size))
|
|
total_sequence_length = sequence_length + pad_size
|
|
scale = 1 / (query.shape[-1] ** 0.5)
|
|
query = query * scale
|
|
|
|
v_beta = value * beta.unsqueeze(-1)
|
|
k_beta = key * beta.unsqueeze(-1)
|
|
# reshape to chunks
|
|
query, key, value, k_beta, v_beta = [
|
|
x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta)
|
|
]
|
|
g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
|
|
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)
|
|
|
|
# chunk decay
|
|
g = g.cumsum(dim=-1)
|
|
decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril()
|
|
attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
|
|
for i in range(1, chunk_size):
|
|
row = attn[..., i, :i].clone()
|
|
sub = attn[..., :i, :i].clone()
|
|
attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
|
|
attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
|
|
value = attn @ v_beta
|
|
k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
|
|
last_recurrent_state = (
|
|
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
|
|
if initial_state is None
|
|
else initial_state.to(value)
|
|
)
|
|
core_attn_out = torch.zeros_like(value)
|
|
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)
|
|
|
|
# for each chunk
|
|
for i in range(0, total_sequence_length // chunk_size):
|
|
q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
|
|
attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
|
|
v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
|
|
v_new = v_i - v_prime
|
|
attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
|
|
core_attn_out[:, :, i] = attn_inter + attn @ v_new
|
|
last_recurrent_state = (
|
|
last_recurrent_state * g[:, :, i, -1, None, None].exp()
|
|
+ (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new
|
|
)
|
|
|
|
if not output_final_state:
|
|
last_recurrent_state = None
|
|
core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
|
|
core_attn_out = core_attn_out[:, :, :sequence_length]
|
|
core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
|
|
return core_attn_out, last_recurrent_state
|
|
|
|
|
|
# Copied verbatim from:
|
|
# transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py
|
|
def torch_recurrent_gated_delta_rule(
|
|
query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False
|
|
):
|
|
initial_dtype = query.dtype
|
|
if use_qk_l2norm_in_kernel:
|
|
query = l2norm(query, dim=-1, eps=1e-6)
|
|
key = l2norm(key, dim=-1, eps=1e-6)
|
|
query, key, value, beta, g = [
|
|
x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
|
|
]
|
|
|
|
batch_size, num_heads, sequence_length, k_head_dim = key.shape
|
|
v_head_dim = value.shape[-1]
|
|
scale = 1 / (query.shape[-1] ** 0.5)
|
|
query = query * scale
|
|
|
|
core_attn_out = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim).to(value)
|
|
last_recurrent_state = (
|
|
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
|
|
if initial_state is None
|
|
else initial_state.to(value)
|
|
)
|
|
|
|
for i in range(sequence_length):
|
|
q_t = query[:, :, i]
|
|
k_t = key[:, :, i]
|
|
v_t = value[:, :, i]
|
|
g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1)
|
|
beta_t = beta[:, :, i].unsqueeze(-1)
|
|
|
|
last_recurrent_state = last_recurrent_state * g_t
|
|
kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
|
|
delta = (v_t - kv_mem) * beta_t
|
|
last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
|
|
core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)
|
|
|
|
if not output_final_state:
|
|
last_recurrent_state = None
|
|
core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
|
|
return core_attn_out, last_recurrent_state
|
|
|
|
|
|
# Copied from:
|
|
# transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py
|
|
# Minimal change: enforce config dtype at the end to avoid bf16/fp32 matmul mismatch
|
|
# in a mixed notebook implementation
|
|
class Qwen3_5GatedDeltaNet(nn.Module):
|
|
def __init__(self, config, layer_idx):
|
|
super().__init__()
|
|
self.hidden_size = config.hidden_size
|
|
self.num_v_heads = config.linear_num_value_heads
|
|
self.num_k_heads = config.linear_num_key_heads
|
|
self.head_k_dim = config.linear_key_head_dim
|
|
self.head_v_dim = config.linear_value_head_dim
|
|
self.key_dim = self.head_k_dim * self.num_k_heads
|
|
self.value_dim = self.head_v_dim * self.num_v_heads
|
|
|
|
self.conv_kernel_size = config.linear_conv_kernel_dim
|
|
self.layer_idx = layer_idx
|
|
self.activation = config.hidden_act
|
|
self.act = ACT2FN[config.hidden_act]
|
|
self.layer_norm_epsilon = config.rms_norm_eps
|
|
|
|
# QKV
|
|
self.conv_dim = self.key_dim * 2 + self.value_dim
|
|
self.conv1d = nn.Conv1d(
|
|
in_channels=self.conv_dim,
|
|
out_channels=self.conv_dim,
|
|
bias=False,
|
|
kernel_size=self.conv_kernel_size,
|
|
groups=self.conv_dim,
|
|
padding=self.conv_kernel_size - 1,
|
|
)
|
|
|
|
# time step projection (discretization)
|
|
# instantiate once and copy inv_dt in init_weights of PretrainedModel
|
|
self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads))
|
|
|
|
A = torch.empty(self.num_v_heads).uniform_(0, 16)
|
|
self.A_log = nn.Parameter(torch.log(A))
|
|
|
|
self.norm = (
|
|
Qwen3_5RMSNormGated(self.head_v_dim, eps=self.layer_norm_epsilon)
|
|
if FusedRMSNormGated is None
|
|
else FusedRMSNormGated(
|
|
self.head_v_dim,
|
|
eps=self.layer_norm_epsilon,
|
|
activation=self.activation,
|
|
device=torch.cuda.current_device(),
|
|
dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(),
|
|
)
|
|
)
|
|
|
|
self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
|
|
|
|
self.causal_conv1d_fn = causal_conv1d_fn
|
|
self.causal_conv1d_update = causal_conv1d_update or torch_causal_conv1d_update
|
|
self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule
|
|
self.recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule or torch_recurrent_gated_delta_rule
|
|
|
|
if not is_fast_path_available:
|
|
logger.warning_once(
|
|
"The fast path is not available because one of the required library is not installed. Falling back to "
|
|
"torch implementation. To install follow https://github.com/fla-org/flash-linear-attention#installation and"
|
|
" https://github.com/Dao-AILab/causal-conv1d"
|
|
)
|
|
|
|
self.in_proj_qkv = nn.Linear(self.hidden_size, self.key_dim * 2 + self.value_dim, bias=False)
|
|
self.in_proj_z = nn.Linear(self.hidden_size, self.value_dim, bias=False)
|
|
self.in_proj_b = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)
|
|
self.in_proj_a = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)
|
|
|
|
# Notebook adaptation for dtype consistency.
|
|
if config.dtype is not None:
|
|
self.to(dtype=config.dtype)
|
|
|
|
def forward(
|
|
self,
|
|
hidden_states,
|
|
cache_params=None,
|
|
cache_position=None,
|
|
attention_mask=None,
|
|
):
|
|
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
|
|
|
|
# Set up dimensions for reshapes later
|
|
batch_size, seq_len, _ = hidden_states.shape
|
|
|
|
use_precomputed_states = (
|
|
cache_params is not None
|
|
and cache_params.has_previous_state
|
|
and seq_len == 1
|
|
and cache_position is not None
|
|
)
|
|
|
|
# getting projected states from cache if it exists
|
|
if cache_params is not None:
|
|
conv_state = cache_params.conv_states[self.layer_idx]
|
|
recurrent_state = cache_params.recurrent_states[self.layer_idx]
|
|
|
|
mixed_qkv = self.in_proj_qkv(hidden_states)
|
|
mixed_qkv = mixed_qkv.transpose(1, 2)
|
|
|
|
z = self.in_proj_z(hidden_states)
|
|
z = z.reshape(batch_size, seq_len, -1, self.head_v_dim)
|
|
|
|
b = self.in_proj_b(hidden_states)
|
|
a = self.in_proj_a(hidden_states)
|
|
|
|
if use_precomputed_states:
|
|
# 2. Convolution sequence transformation
|
|
# NOTE: the conv state is updated in `causal_conv1d_update`
|
|
mixed_qkv = self.causal_conv1d_update(
|
|
mixed_qkv,
|
|
conv_state,
|
|
self.conv1d.weight.squeeze(1),
|
|
self.conv1d.bias,
|
|
self.activation,
|
|
)
|
|
else:
|
|
if cache_params is not None:
|
|
conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))
|
|
cache_params.conv_states[self.layer_idx] = conv_state
|
|
if self.causal_conv1d_fn is not None:
|
|
mixed_qkv = self.causal_conv1d_fn(
|
|
x=mixed_qkv,
|
|
weight=self.conv1d.weight.squeeze(1),
|
|
bias=self.conv1d.bias,
|
|
activation=self.activation,
|
|
seq_idx=None,
|
|
)
|
|
else:
|
|
mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
|
|
|
|
mixed_qkv = mixed_qkv.transpose(1, 2)
|
|
query, key, value = torch.split(
|
|
mixed_qkv,
|
|
[
|
|
self.key_dim,
|
|
self.key_dim,
|
|
self.value_dim,
|
|
],
|
|
dim=-1,
|
|
)
|
|
|
|
query = query.reshape(batch_size, seq_len, -1, self.head_k_dim)
|
|
key = key.reshape(batch_size, seq_len, -1, self.head_k_dim)
|
|
value = value.reshape(batch_size, seq_len, -1, self.head_v_dim)
|
|
|
|
beta = b.sigmoid()
|
|
# If the model is loaded in fp16, without the .float() here, A might be -inf
|
|
g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
|
|
if self.num_v_heads // self.num_k_heads > 1:
|
|
query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
|
|
key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
|
|
|
|
if not use_precomputed_states:
|
|
core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
|
|
query,
|
|
key,
|
|
value,
|
|
g=g,
|
|
beta=beta,
|
|
initial_state=None,
|
|
output_final_state=cache_params is not None,
|
|
use_qk_l2norm_in_kernel=True,
|
|
)
|
|
|
|
else:
|
|
core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule(
|
|
query,
|
|
key,
|
|
value,
|
|
g=g,
|
|
beta=beta,
|
|
initial_state=recurrent_state,
|
|
output_final_state=cache_params is not None,
|
|
use_qk_l2norm_in_kernel=True,
|
|
)
|
|
|
|
# Update cache
|
|
if cache_params is not None:
|
|
cache_params.recurrent_states[self.layer_idx] = last_recurrent_state
|
|
|
|
# reshape input data into 2D tensor
|
|
core_attn_out = core_attn_out.reshape(-1, self.head_v_dim)
|
|
z = z.reshape(-1, self.head_v_dim)
|
|
core_attn_out = self.norm(core_attn_out, z)
|
|
core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1)
|
|
|
|
output = self.out_proj(core_attn_out)
|
|
return output
|