diff --git a/ch05/16_qwen3.5/qwen3_5_transformers.py b/ch05/16_qwen3.5/qwen3_5_transformers.py index a961db6..2a4bed2 100644 --- a/ch05/16_qwen3.5/qwen3_5_transformers.py +++ b/ch05/16_qwen3.5/qwen3_5_transformers.py @@ -1,7 +1,7 @@ """Qwen3.5 helper blocks copied from Hugging Face Transformers Source file: -transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py +https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_5/modeling_qwen3_5.py License: Apache License Version 2.0 License URL: https://github.com/huggingface/transformers/blob/main/LICENSE @@ -45,8 +45,6 @@ class Qwen3_5DynamicCache: pass -# Copied verbatim from: -# transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py class Qwen3_5RMSNormGated(nn.Module): def __init__(self, hidden_size, eps=1e-6, **kwargs): super().__init__() @@ -65,8 +63,6 @@ class Qwen3_5RMSNormGated(nn.Module): return hidden_states.to(input_dtype) -# Copied verbatim from: -# transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py def apply_mask_to_padding_states(hidden_states, attention_mask): """ Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 @@ -79,8 +75,6 @@ def apply_mask_to_padding_states(hidden_states, attention_mask): return hidden_states -# Copied verbatim from: -# transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py def torch_causal_conv1d_update( hidden_states, conv_state, @@ -99,16 +93,12 @@ def torch_causal_conv1d_update( return out -# Copied verbatim from: -# transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py def l2norm(x, dim=-1, eps=1e-6): """This function is intended to align with the l2norm implementation in the FLA library.""" inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) return x * inv_norm -# Copied verbatim from: -# transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py def torch_chunk_gated_delta_rule( query, key, @@ -189,8 +179,6 @@ def torch_chunk_gated_delta_rule( return core_attn_out, last_recurrent_state -# Copied verbatim from: -# transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py def torch_recurrent_gated_delta_rule( query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False ): @@ -233,8 +221,6 @@ def torch_recurrent_gated_delta_rule( return core_attn_out, last_recurrent_state -# Copied from: -# transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py # Minimal change: enforce config dtype at the end to avoid bf16/fp32 matmul mismatch # in a mixed notebook implementation class Qwen3_5GatedDeltaNet(nn.Module):