diff --git a/.gitignore b/.gitignore
index 2785ab7..03abb4c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -144,6 +144,7 @@ ch0?/0?_user_interface/.files
 *.lock
 
 # Temporary and OS-related files
+chainlit.md
 Untitled.ipynb
 .DS_Store
 
diff --git a/ch05/11_qwen3/qwen3-chat-interface/qwen3-chat-interface-multiturn.py b/ch05/11_qwen3/qwen3-chat-interface/qwen3-chat-interface-multiturn.py
index 064e970..be5f3ba 100644
--- a/ch05/11_qwen3/qwen3-chat-interface/qwen3-chat-interface-multiturn.py
+++ b/ch05/11_qwen3/qwen3-chat-interface/qwen3-chat-interface-multiturn.py
@@ -143,23 +143,15 @@ async def main(message: chainlit.Message):
     await out_msg.send()
 
     # 3) Stream generation
-    with torch.no_grad():
-        for tok in generate_text_simple_stream(
-            model=MODEL,
-            token_ids=input_ids_tensor,
-            max_new_tokens=MAX_NEW_TOKENS,
-            # eos_token_id=TOKENIZER.eos_token_id
-        ):
-            # Normalize token to int
-            if torch.is_tensor(tok):
-                tok = int(tok.view(-1)[0].item())
-            else:
-                tok = int(tok)
-
-            piece = TOKENIZER.decode([tok])
-            if piece in ("<|endoftext|>", "<|im_end|>"):
-                break
-            await out_msg.stream_token(piece)
+    for tok in generate_text_simple_stream(
+        model=MODEL,
+        token_ids=input_ids_tensor,
+        max_new_tokens=MAX_NEW_TOKENS,
+        eos_token_id=TOKENIZER.eos_token_id
+    ):
+        token_id = tok.squeeze(0)
+        piece = TOKENIZER.decode(token_id.tolist())
+        await out_msg.stream_token(piece)
 
     # 4) Finalize the streamed message
     await out_msg.update()
diff --git a/ch05/11_qwen3/qwen3-chat-interface/qwen3-chat-interface.py b/ch05/11_qwen3/qwen3-chat-interface/qwen3-chat-interface.py
index 2eac62b..926a817 100644
--- a/ch05/11_qwen3/qwen3-chat-interface/qwen3-chat-interface.py
+++ b/ch05/11_qwen3/qwen3-chat-interface/qwen3-chat-interface.py
@@ -123,23 +123,15 @@ async def main(message: chainlit.Message):
     await out_msg.send()
 
     # 3) Stream generation
-    with torch.no_grad():
-        for tok in generate_text_simple_stream(
-            model=MODEL,
-            token_ids=input_ids_tensor,
-            max_new_tokens=MAX_NEW_TOKENS,
-        ):
-            # Normalize token to int
-            if torch.is_tensor(tok):
-                tok = int(tok.view(-1)[0].item())
-            else:
-                tok = int(tok)
-
-            piece = TOKENIZER.decode([tok])
-            if piece in ("<|endoftext|>", "<|im_end|>"):
-                break
-
-            await out_msg.stream_token(piece)
+    for tok in generate_text_simple_stream(
+        model=MODEL,
+        token_ids=input_ids_tensor,
+        max_new_tokens=MAX_NEW_TOKENS,
+        eos_token_id=TOKENIZER.eos_token_id
+    ):
+        token_id = tok.squeeze(0)
+        piece = TOKENIZER.decode(token_id.tolist())
+        await out_msg.stream_token(piece)
 
     # 4) Finalize the streamed message
     await out_msg.update()