mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Use full HF url
This commit is contained in:
committed by
GitHub
parent
7892ec9435
commit
ae8eebf0d7
@@ -1,7 +1,7 @@
|
|||||||
"""Qwen3.5 helper blocks copied from Hugging Face Transformers
|
"""Qwen3.5 helper blocks copied from Hugging Face Transformers
|
||||||
|
|
||||||
Source file:
|
Source file:
|
||||||
transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py
|
https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_5/modeling_qwen3_5.py
|
||||||
|
|
||||||
License: Apache License Version 2.0
|
License: Apache License Version 2.0
|
||||||
License URL: https://github.com/huggingface/transformers/blob/main/LICENSE
|
License URL: https://github.com/huggingface/transformers/blob/main/LICENSE
|
||||||
@@ -45,8 +45,6 @@ class Qwen3_5DynamicCache:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
# Copied verbatim from:
|
|
||||||
# transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py
|
|
||||||
class Qwen3_5RMSNormGated(nn.Module):
|
class Qwen3_5RMSNormGated(nn.Module):
|
||||||
def __init__(self, hidden_size, eps=1e-6, **kwargs):
|
def __init__(self, hidden_size, eps=1e-6, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -65,8 +63,6 @@ class Qwen3_5RMSNormGated(nn.Module):
|
|||||||
return hidden_states.to(input_dtype)
|
return hidden_states.to(input_dtype)
|
||||||
|
|
||||||
|
|
||||||
# Copied verbatim from:
|
|
||||||
# transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py
|
|
||||||
def apply_mask_to_padding_states(hidden_states, attention_mask):
|
def apply_mask_to_padding_states(hidden_states, attention_mask):
|
||||||
"""
|
"""
|
||||||
Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
|
Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
|
||||||
@@ -79,8 +75,6 @@ def apply_mask_to_padding_states(hidden_states, attention_mask):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
# Copied verbatim from:
|
|
||||||
# transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py
|
|
||||||
def torch_causal_conv1d_update(
|
def torch_causal_conv1d_update(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
conv_state,
|
conv_state,
|
||||||
@@ -99,16 +93,12 @@ def torch_causal_conv1d_update(
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
# Copied verbatim from:
|
|
||||||
# transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py
|
|
||||||
def l2norm(x, dim=-1, eps=1e-6):
|
def l2norm(x, dim=-1, eps=1e-6):
|
||||||
"""This function is intended to align with the l2norm implementation in the FLA library."""
|
"""This function is intended to align with the l2norm implementation in the FLA library."""
|
||||||
inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
|
inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
|
||||||
return x * inv_norm
|
return x * inv_norm
|
||||||
|
|
||||||
|
|
||||||
# Copied verbatim from:
|
|
||||||
# transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py
|
|
||||||
def torch_chunk_gated_delta_rule(
|
def torch_chunk_gated_delta_rule(
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
@@ -189,8 +179,6 @@ def torch_chunk_gated_delta_rule(
|
|||||||
return core_attn_out, last_recurrent_state
|
return core_attn_out, last_recurrent_state
|
||||||
|
|
||||||
|
|
||||||
# Copied verbatim from:
|
|
||||||
# transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py
|
|
||||||
def torch_recurrent_gated_delta_rule(
|
def torch_recurrent_gated_delta_rule(
|
||||||
query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False
|
query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False
|
||||||
):
|
):
|
||||||
@@ -233,8 +221,6 @@ def torch_recurrent_gated_delta_rule(
|
|||||||
return core_attn_out, last_recurrent_state
|
return core_attn_out, last_recurrent_state
|
||||||
|
|
||||||
|
|
||||||
# Copied from:
|
|
||||||
# transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py
|
|
||||||
# Minimal change: enforce config dtype at the end to avoid bf16/fp32 matmul mismatch
|
# Minimal change: enforce config dtype at the end to avoid bf16/fp32 matmul mismatch
|
||||||
# in a mixed notebook implementation
|
# in a mixed notebook implementation
|
||||||
class Qwen3_5GatedDeltaNet(nn.Module):
|
class Qwen3_5GatedDeltaNet(nn.Module):
|
||||||
|
|||||||
Reference in New Issue
Block a user