mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
simplify
This commit is contained in:
@@ -143,23 +143,15 @@ async def main(message: chainlit.Message):
|
||||
await out_msg.send()
|
||||
|
||||
# 3) Stream generation
|
||||
with torch.no_grad():
|
||||
for tok in generate_text_simple_stream(
|
||||
model=MODEL,
|
||||
token_ids=input_ids_tensor,
|
||||
max_new_tokens=MAX_NEW_TOKENS,
|
||||
# eos_token_id=TOKENIZER.eos_token_id
|
||||
):
|
||||
# Normalize token to int
|
||||
if torch.is_tensor(tok):
|
||||
tok = int(tok.view(-1)[0].item())
|
||||
else:
|
||||
tok = int(tok)
|
||||
|
||||
piece = TOKENIZER.decode([tok])
|
||||
if piece in ("<|endoftext|>", "<|im_end|>"):
|
||||
break
|
||||
await out_msg.stream_token(piece)
|
||||
for tok in generate_text_simple_stream(
|
||||
model=MODEL,
|
||||
token_ids=input_ids_tensor,
|
||||
max_new_tokens=MAX_NEW_TOKENS,
|
||||
eos_token_id=TOKENIZER.eos_token_id
|
||||
):
|
||||
token_id = tok.squeeze(0)
|
||||
piece = TOKENIZER.decode(token_id.tolist())
|
||||
await out_msg.stream_token(piece)
|
||||
|
||||
# 4) Finalize the streamed message
|
||||
await out_msg.update()
|
||||
|
||||
@@ -123,23 +123,15 @@ async def main(message: chainlit.Message):
|
||||
await out_msg.send()
|
||||
|
||||
# 3) Stream generation
|
||||
with torch.no_grad():
|
||||
for tok in generate_text_simple_stream(
|
||||
model=MODEL,
|
||||
token_ids=input_ids_tensor,
|
||||
max_new_tokens=MAX_NEW_TOKENS,
|
||||
):
|
||||
# Normalize token to int
|
||||
if torch.is_tensor(tok):
|
||||
tok = int(tok.view(-1)[0].item())
|
||||
else:
|
||||
tok = int(tok)
|
||||
|
||||
piece = TOKENIZER.decode([tok])
|
||||
if piece in ("<|endoftext|>", "<|im_end|>"):
|
||||
break
|
||||
|
||||
await out_msg.stream_token(piece)
|
||||
for tok in generate_text_simple_stream(
|
||||
model=MODEL,
|
||||
token_ids=input_ids_tensor,
|
||||
max_new_tokens=MAX_NEW_TOKENS,
|
||||
eos_token_id=TOKENIZER.eos_token_id
|
||||
):
|
||||
token_id = tok.squeeze(0)
|
||||
piece = TOKENIZER.decode(token_id.tolist())
|
||||
await out_msg.stream_token(piece)
|
||||
|
||||
# 4) Finalize the streamed message
|
||||
await out_msg.update()
|
||||
|
||||
Reference in New Issue
Block a user