Add GPTModelFast (#584)

* Add GPTModelFast

* update
This commit is contained in:
Sebastian Raschka
2025-03-27 14:00:25 -05:00
committed by GitHub
parent c9271ac427
commit e07a7abdd5
7 changed files with 204 additions and 61 deletions

View File

@@ -4,7 +4,7 @@
# Code: https://github.com/rasbt/LLMs-from-scratch
from llms_from_scratch.ch03 import MultiHeadAttention
from llms_from_scratch.ch03 import MultiHeadAttention, PyTorchMultiHeadAttention
import torch
@@ -14,7 +14,15 @@ def test_mha():
d_in = 256
d_out = 16
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
mha = MultiHeadAttention(d_in, d_out, context_length, dropout=0.0, num_heads=2)
batch = torch.rand(8, 6, d_in)
context_vecs = mha(batch)
context_vecs.shape == torch.Size([8, 6, d_out])
# Test bonus class
mha = PyTorchMultiHeadAttention(d_in, d_out, num_heads=2)
batch = torch.rand(8, 6, d_in)
context_vecs = mha(batch)