mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
committed by
GitHub
parent
e316cafd9f
commit
9d6da22ebb
@@ -36,7 +36,7 @@ class GPTDatasetV1(Dataset):
|
||||
return self.input_ids[idx], self.target_ids[idx]
|
||||
|
||||
|
||||
def create_dataloader_v1(txt, batch_size=4, max_length=256,
|
||||
def create_dataloader_v1(txt, batch_size=4, max_length=256,
|
||||
stride=128, shuffle=True, drop_last=True):
|
||||
# Initialize the tokenizer
|
||||
tokenizer = tiktoken.get_encoding("gpt2")
|
||||
@@ -80,7 +80,7 @@ class MultiHeadAttention(nn.Module):
|
||||
|
||||
# We implicitly split the matrix by adding a `num_heads` dimension
|
||||
# Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
|
||||
keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
|
||||
keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
|
||||
values = values.view(b, num_tokens, self.num_heads, self.head_dim)
|
||||
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
|
||||
|
||||
@@ -102,7 +102,7 @@ class MultiHeadAttention(nn.Module):
|
||||
attn_weights = self.dropout(attn_weights)
|
||||
|
||||
# Shape: (b, num_tokens, num_heads, head_dim)
|
||||
context_vec = (attn_weights @ values).transpose(1, 2)
|
||||
context_vec = (attn_weights @ values).transpose(1, 2)
|
||||
|
||||
# Combine heads, where self.d_out = self.num_heads * self.head_dim
|
||||
context_vec = context_vec.reshape(b, num_tokens, self.d_out)
|
||||
@@ -135,7 +135,7 @@ class GELU(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
return 0.5 * x * (1 + torch.tanh(
|
||||
torch.sqrt(torch.tensor(2.0 / torch.pi)) *
|
||||
torch.sqrt(torch.tensor(2.0 / torch.pi)) *
|
||||
(x + 0.044715 * torch.pow(x, 3))
|
||||
))
|
||||
|
||||
@@ -161,7 +161,7 @@ class TransformerBlock(nn.Module):
|
||||
d_in=cfg["emb_dim"],
|
||||
d_out=cfg["emb_dim"],
|
||||
block_size=cfg["ctx_len"],
|
||||
num_heads=cfg["n_heads"],
|
||||
num_heads=cfg["n_heads"],
|
||||
dropout=cfg["drop_rate"],
|
||||
qkv_bias=cfg["qkv_bias"])
|
||||
self.ff = FeedForward(cfg)
|
||||
@@ -227,7 +227,7 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):
|
||||
|
||||
# Focus only on the last time step
|
||||
# (batch, n_token, vocab_size) becomes (batch, vocab_size)
|
||||
logits = logits[:, -1, :]
|
||||
logits = logits[:, -1, :]
|
||||
|
||||
# Get the idx of the vocab entry with the highest logits value
|
||||
idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch, 1)
|
||||
@@ -315,4 +315,4 @@ def text_to_token_ids(text, tokenizer):
|
||||
|
||||
def token_ids_to_text(token_ids, tokenizer):
|
||||
flat = token_ids.squeeze(0) # remove batch dimension
|
||||
return tokenizer.decode(flat.tolist())
|
||||
return tokenizer.decode(flat.tolist())
|
||||
|
||||
Reference in New Issue
Block a user