mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Set up basic test gh worklows (#79)
* Set up basic test gh worklows * update file paths * env check * add env check * Update requirements.txt * simplify * upd
This commit is contained in:
committed by
GitHub
parent
9d6da22ebb
commit
ca96abac8a
33
ch04/01_main-chapter-code/tests.py
Normal file
33
ch04/01_main-chapter-code/tests.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from gpt import main
|
||||
|
||||
expected = """
|
||||
==================================================
|
||||
IN
|
||||
==================================================
|
||||
|
||||
Input text: Hello, I am
|
||||
Encoded input text: [15496, 11, 314, 716]
|
||||
encoded_tensor.shape: torch.Size([1, 4])
|
||||
|
||||
|
||||
==================================================
|
||||
OUT
|
||||
==================================================
|
||||
|
||||
Output: tensor([[15496, 11, 314, 716, 27018, 24086, 47843, 30961, 42348, 7267,
|
||||
49706, 43231, 47062, 34657]])
|
||||
Output length: 14
|
||||
Output text: Hello, I am Featureiman Byeswickattribute argue logger Normandy Compton analogous
|
||||
"""
|
||||
|
||||
|
||||
def test_main(capsys):
|
||||
main()
|
||||
captured = capsys.readouterr()
|
||||
|
||||
# Normalize line endings and strip trailing whitespace from each line
|
||||
normalized_expected = '\n'.join(line.rstrip() for line in expected.splitlines())
|
||||
normalized_output = '\n'.join(line.rstrip() for line in captured.out.splitlines())
|
||||
|
||||
# Compare normalized strings
|
||||
assert normalized_output == normalized_expected
|
||||
Reference in New Issue
Block a user