Mirror of https://github.com/rasbt/LLMs-from-scratch.git, synced 2026-04-10 12:33:42 +00:00.
Qwen3 Coder Flash & MoE from Scratch (#760)
* Qwen3 Coder Flash & MoE from Scratch * update * refinements * updates * update * update * update
This commit is contained in:
committed by
GitHub
parent
d6213a398a
commit
71ef67be46
@@ -28,3 +28,27 @@ def generate_text_simple(model, idx, max_new_tokens, context_size=None, use_cach
|
||||
idx = torch.cat([idx, next_idx], dim=1)
|
||||
|
||||
return idx
|
||||
|
||||
|
||||
def generate_text_simple_stream(model, token_ids, max_new_tokens, eos_token_id=None, context_size=None):
    """Greedily generate up to `max_new_tokens` tokens, yielding each one as produced.

    Args:
        model: Autoregressive model exposing `cfg["n_layers"]`, `reset_kv_cache()`,
            and a forward pass accepting a `cache` keyword argument.
        token_ids: (batch, seq_len) tensor of prompt token ids used to prime the cache.
        max_new_tokens: Maximum number of new tokens to generate.
        eos_token_id: If given, generation stops (before yielding) once every
            sequence in the batch produced this token.
        context_size: Unused; kept for signature compatibility with
            `generate_text_simple` — the KV cache tracks history instead.

    Yields:
        (batch, 1) tensor with the next greedily chosen token id.
    """
    model.eval()

    with torch.no_grad():
        cache = KVCache(n_layers=model.cfg["n_layers"])
        model.reset_kv_cache()

        # Prime the cache with the initial context
        logits = model(token_ids, cache=cache)

        for _ in range(max_new_tokens):
            # Greedy decoding: pick the highest-logit token at the last position.
            next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True)

            if eos_token_id is not None and torch.all(next_token == eos_token_id):
                break

            yield next_token

            # NOTE: the original also concatenated `next_token` onto a local
            # `token_ids` here, but that accumulated tensor was never read again
            # (the cache carries the history), so the O(T^2) copying is dropped.

            # Feed only the new token to the model; cache handles history
            logits = model(next_token, cache=cache)
||||
@@ -29,7 +29,7 @@ class Qwen3Model(nn.Module):
|
||||
self.final_norm = RMSNorm(cfg["emb_dim"])
|
||||
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
|
||||
|
||||
# Reusuable utilities
|
||||
# Reusable utilities
|
||||
if cfg["head_dim"] is None:
|
||||
head_dim = cfg["emb_dim"] // cfg["n_heads"]
|
||||
else:
|
||||
@@ -94,7 +94,10 @@ class TransformerBlock(nn.Module):
|
||||
qk_norm=cfg["qk_norm"],
|
||||
dtype=cfg["dtype"]
|
||||
)
|
||||
self.ff = FeedForward(cfg)
|
||||
if "num_experts" in cfg and cfg["num_experts"] > 0:
|
||||
self.ff = MoEFeedForward(cfg)
|
||||
else:
|
||||
self.ff = FeedForward(cfg)
|
||||
self.norm1 = RMSNorm(cfg["emb_dim"], eps=1e-6)
|
||||
self.norm2 = RMSNorm(cfg["emb_dim"], eps=1e-6)
|
||||
|
||||
@@ -128,6 +131,46 @@ class FeedForward(nn.Module):
|
||||
return self.fc3(x)
|
||||
|
||||
|
||||
class MoEFeedForward(nn.Module):
    """Mixture-of-Experts feed-forward layer with top-k token routing.

    Each expert is a SwiGLU-style MLP (fc1/fc2 project up, fc3 projects back
    down); a linear gate scores the experts per token, and the outputs of the
    top-k experts are blended with their softmax-normalized gate probabilities.
    """

    def __init__(self, cfg):
        super().__init__()
        self.num_experts_per_tok = cfg["num_experts_per_tok"]
        self.num_experts = cfg["num_experts"]
        self.gate = nn.Linear(cfg["emb_dim"], cfg["num_experts"], bias=False, dtype=cfg["dtype"])

        # Experts live on the meta device to reduce memory pressure and are only
        # loaded when used (trades compute for memory).
        meta_device = torch.device("meta")

        def expert_bank(in_dim, out_dim):
            # One bias-free linear layer per expert.
            return nn.ModuleList(
                nn.Linear(in_dim, out_dim, bias=False, dtype=cfg["dtype"], device=meta_device)
                for _ in range(cfg["num_experts"])
            )

        self.fc1 = expert_bank(cfg["emb_dim"], cfg["moe_intermediate_size"])
        self.fc2 = expert_bank(cfg["emb_dim"], cfg["moe_intermediate_size"])
        self.fc3 = expert_bank(cfg["moe_intermediate_size"], cfg["emb_dim"])

    def forward(self, x):
        scores = self.gate(x)  # (b, seq_len, num_experts)
        topk_scores, topk_indices = torch.topk(scores, self.num_experts_per_tok, dim=-1)
        topk_probs = torch.softmax(topk_scores, dim=-1)

        # Dense evaluation: every expert processes every token (simple, but
        # O(num_experts); sparse routing would run only the selected experts).
        per_expert = [
            self.fc3[e](torch.nn.functional.silu(self.fc1[e](x)) * self.fc2[e](x))
            for e in range(self.num_experts)
        ]
        expert_outputs = torch.stack(per_expert, dim=-2)  # (b, t, num_experts, emb_dim)

        # Scatter the top-k probabilities back into a dense gating tensor;
        # experts that were not selected keep probability zero. Top-k indices
        # are distinct per token, so a single scatter_ is exact.
        gating_probs = torch.zeros_like(scores)
        gating_probs.scatter_(dim=-1, index=topk_indices, src=topk_probs)
        gating_probs = gating_probs.unsqueeze(-1)  # (b, t, num_experts, 1)

        # Weighted sum over experts
        return (gating_probs * expert_outputs).sum(dim=-2)
||||
class GroupedQueryAttention(nn.Module):
|
||||
def __init__(
|
||||
self, d_in, num_heads, num_kv_groups, head_dim=None, qk_norm=False, dtype=None
|
||||
|
||||
Reference in New Issue
Block a user