mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
committed by
GitHub
parent
c9271ac427
commit
e07a7abdd5
@@ -149,3 +149,50 @@ class MultiHeadAttention(nn.Module):
|
||||
context_vec = self.out_proj(context_vec) # optional projection
|
||||
|
||||
return context_vec
|
||||
|
||||
|
||||
######################
# Bonus
######################
class PyTorchMultiHeadAttention(nn.Module):
    """Multi-head causal self-attention using PyTorch's fused
    ``scaled_dot_product_attention`` kernel.

    Equivalent to a from-scratch multi-head attention module, but delegates
    the attention computation to the optimized built-in op.

    Args:
        d_in: Input embedding dimension.
        d_out: Output embedding dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Attention-weight dropout probability (active only in
            training mode).
        qkv_bias: Whether the combined QKV projection layer uses a bias term.
    """

    def __init__(self, d_in, d_out, num_heads, dropout=0.0, qkv_bias=False):
        super().__init__()

        assert d_out % num_heads == 0, "embed_dim is indivisible by num_heads"

        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.d_out = d_out

        # Single fused linear layer produces queries, keys, and values at once.
        self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)
        self.proj = nn.Linear(d_out, d_out)
        self.dropout = dropout

    def forward(self, x):
        """Apply causal multi-head self-attention.

        Args:
            x: Tensor of shape ``(batch, num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(batch, num_tokens, d_out)``.
        """
        batch_size, num_tokens, embed_dim = x.shape

        # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)
        qkv = self.qkv(x)

        # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)
        qkv = qkv.view(batch_size, num_tokens, 3, self.num_heads, self.head_dim)

        # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)

        # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)
        queries, keys, values = qkv

        # Dropout is only applied during training; disabled for eval/inference.
        use_dropout = 0. if not self.training else self.dropout

        # is_causal=True builds the causal (lower-triangular) mask internally,
        # so no explicit attn_mask buffer is needed.
        context_vec = nn.functional.scaled_dot_product_attention(
            queries, keys, values, attn_mask=None, dropout_p=use_dropout, is_causal=True)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)

        context_vec = self.proj(context_vec)

        return context_vec
Reference in New Issue
Block a user