mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Qwen3 Coder Flash & MoE from Scratch (#760)
* Qwen3 Coder Flash & MoE from Scratch * update * refinements * updates * update * update * update
This commit is contained in:
committed by
GitHub
parent
145322ded8
commit
f92b40e4ab
@@ -28,3 +28,27 @@ def generate_text_simple(model, idx, max_new_tokens, context_size=None, use_cach
|
||||
idx = torch.cat([idx, next_idx], dim=1)
|
||||
|
||||
return idx
|
||||
|
||||
|
||||
def generate_text_simple_stream(model, token_ids, max_new_tokens, eos_token_id=None, context_size=None):
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
cache = KVCache(n_layers=model.cfg["n_layers"])
|
||||
model.reset_kv_cache()
|
||||
|
||||
# Prime the cache with the initial context
|
||||
logits = model(token_ids, cache=cache)
|
||||
|
||||
for _ in range(max_new_tokens):
|
||||
next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True)
|
||||
|
||||
if eos_token_id is not None and torch.all(next_token == eos_token_id):
|
||||
break
|
||||
|
||||
yield next_token
|
||||
|
||||
token_ids = torch.cat([token_ids, next_token], dim=1)
|
||||
|
||||
# Feed only the new token to the model; cache handles history
|
||||
logits = model(next_token, cache=cache)
|
||||
|
||||
Reference in New Issue
Block a user