This commit is contained in:
rasbt
2025-09-01 22:15:47 -05:00
parent 643f800a94
commit 9ea2c57c5f
3 changed files with 19 additions and 34 deletions

1
.gitignore vendored
View File

@@ -144,6 +144,7 @@ ch0?/0?_user_interface/.files
*.lock
# Temporary and OS-related files
chainlit.md
Untitled.ipynb
.DS_Store

View File

@@ -143,23 +143,15 @@ async def main(message: chainlit.Message):
await out_msg.send()
# 3) Stream generation
with torch.no_grad():
for tok in generate_text_simple_stream(
model=MODEL,
token_ids=input_ids_tensor,
max_new_tokens=MAX_NEW_TOKENS,
# eos_token_id=TOKENIZER.eos_token_id
):
# Normalize token to int
if torch.is_tensor(tok):
tok = int(tok.view(-1)[0].item())
else:
tok = int(tok)
piece = TOKENIZER.decode([tok])
if piece in ("<|endoftext|>", "<|im_end|>"):
break
await out_msg.stream_token(piece)
for tok in generate_text_simple_stream(
model=MODEL,
token_ids=input_ids_tensor,
max_new_tokens=MAX_NEW_TOKENS,
eos_token_id=TOKENIZER.eos_token_id
):
token_id = tok.squeeze(0)
piece = TOKENIZER.decode(token_id.tolist())
await out_msg.stream_token(piece)
# 4) Finalize the streamed message
await out_msg.update()

View File

@@ -123,23 +123,15 @@ async def main(message: chainlit.Message):
await out_msg.send()
# 3) Stream generation
with torch.no_grad():
for tok in generate_text_simple_stream(
model=MODEL,
token_ids=input_ids_tensor,
max_new_tokens=MAX_NEW_TOKENS,
):
# Normalize token to int
if torch.is_tensor(tok):
tok = int(tok.view(-1)[0].item())
else:
tok = int(tok)
piece = TOKENIZER.decode([tok])
if piece in ("<|endoftext|>", "<|im_end|>"):
break
await out_msg.stream_token(piece)
for tok in generate_text_simple_stream(
model=MODEL,
token_ids=input_ids_tensor,
max_new_tokens=MAX_NEW_TOKENS,
eos_token_id=TOKENIZER.eos_token_id
):
token_id = tok.squeeze(0)
piece = TOKENIZER.decode(token_id.tolist())
await out_msg.stream_token(piece)
# 4) Finalize the streamed message
await out_msg.update()