mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
add toggle for qkv_bias
This commit is contained in:
@@ -62,7 +62,7 @@
|
||||
" return self.input_ids[idx], self.target_ids[idx]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def create_dataloader(txt, batch_size=4, max_length=256, stride=128):\n",
|
||||
"def create_dataloader(txt, batch_size=4, max_length=256, stride=128, shuffle=True):\n",
|
||||
" # Initialize the tokenizer\n",
|
||||
" tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
|
||||
"\n",
|
||||
@@ -70,7 +70,7 @@
|
||||
" dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)\n",
|
||||
"\n",
|
||||
" # Create dataloader\n",
|
||||
" dataloader = DataLoader(dataset, batch_size=batch_size)\n",
|
||||
" dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)\n",
|
||||
"\n",
|
||||
" return dataloader\n",
|
||||
"\n",
|
||||
@@ -155,12 +155,12 @@
|
||||
"source": [
|
||||
"class CausalSelfAttention(nn.Module):\n",
|
||||
"\n",
|
||||
" def __init__(self, d_in, d_out, block_size, dropout):\n",
|
||||
" def __init__(self, d_in, d_out, block_size, dropout, qkv_bias=False):\n",
|
||||
" super().__init__()\n",
|
||||
" self.d_out = d_out\n",
|
||||
" self.W_query = nn.Linear(d_in, d_out, bias=False)\n",
|
||||
" self.W_key = nn.Linear(d_in, d_out, bias=False)\n",
|
||||
" self.W_value = nn.Linear(d_in, d_out, bias=False)\n",
|
||||
" self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
|
||||
" self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
|
||||
" self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
|
||||
" self.dropout = nn.Dropout(dropout) # New\n",
|
||||
" self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1)) # New\n",
|
||||
"\n",
|
||||
@@ -181,10 +181,10 @@
|
||||
"\n",
|
||||
"\n",
|
||||
"class MultiHeadAttentionWrapper(nn.Module):\n",
|
||||
" def __init__(self, d_in, d_out, block_size, dropout, num_heads):\n",
|
||||
" def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n",
|
||||
" super().__init__()\n",
|
||||
" self.heads = nn.ModuleList(\n",
|
||||
" [CausalSelfAttention(d_in, d_out, block_size, dropout) \n",
|
||||
" [CausalSelfAttention(d_in, d_out, block_size, dropout, qkv_bias) \n",
|
||||
" for _ in range(num_heads)]\n",
|
||||
" )\n",
|
||||
" self.out_proj = nn.Linear(d_out*num_heads, d_out*num_heads)\n",
|
||||
@@ -241,7 +241,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class MultiHeadAttention(nn.Module):\n",
|
||||
" def __init__(self, d_in, d_out, block_size, dropout, num_heads):\n",
|
||||
" def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n",
|
||||
" super().__init__()\n",
|
||||
" assert d_out % num_heads == 0, \"d_out must be divisible by n_heads\"\n",
|
||||
"\n",
|
||||
@@ -249,9 +249,9 @@
|
||||
" self.num_heads = num_heads\n",
|
||||
" self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim\n",
|
||||
"\n",
|
||||
" self.W_query = nn.Linear(d_in, d_out, bias=False)\n",
|
||||
" self.W_key = nn.Linear(d_in, d_out, bias=False)\n",
|
||||
" self.W_value = nn.Linear(d_in, d_out, bias=False)\n",
|
||||
" self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
|
||||
" self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
|
||||
" self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
|
||||
" self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs\n",
|
||||
" self.dropout = nn.Dropout(dropout)\n",
|
||||
" self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))\n",
|
||||
|
||||
Reference in New Issue
Block a user