mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
committed by
GitHub
parent
c9271ac427
commit
e07a7abdd5
@@ -4,7 +4,7 @@
|
||||
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||
|
||||
|
||||
from llms_from_scratch.ch03 import MultiHeadAttention
|
||||
from llms_from_scratch.ch03 import MultiHeadAttention, PyTorchMultiHeadAttention
|
||||
import torch
|
||||
|
||||
|
||||
@@ -14,7 +14,15 @@ def test_mha():
|
||||
d_in = 256
|
||||
d_out = 16
|
||||
|
||||
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
|
||||
mha = MultiHeadAttention(d_in, d_out, context_length, dropout=0.0, num_heads=2)
|
||||
|
||||
batch = torch.rand(8, 6, d_in)
|
||||
context_vecs = mha(batch)
|
||||
|
||||
context_vecs.shape == torch.Size([8, 6, d_out])
|
||||
|
||||
# Test bonus class
|
||||
mha = PyTorchMultiHeadAttention(d_in, d_out, num_heads=2)
|
||||
|
||||
batch = torch.rand(8, 6, d_in)
|
||||
context_vecs = mha(batch)
|
||||
|
||||
Reference in New Issue
Block a user