mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
add toggle for qkv_bias
This commit is contained in:
@@ -971,12 +971,12 @@
|
||||
"source": [
|
||||
"class SelfAttention_v2(nn.Module):\n",
|
||||
"\n",
|
||||
" def __init__(self, d_in, d_out):\n",
|
||||
" def __init__(self, d_in, d_out, qkv_bias=False):\n",
|
||||
" super().__init__()\n",
|
||||
" self.d_out = d_out\n",
|
||||
" self.W_query = nn.Linear(d_in, d_out, bias=False)\n",
|
||||
" self.W_key = nn.Linear(d_in, d_out, bias=False)\n",
|
||||
" self.W_value = nn.Linear(d_in, d_out, bias=False)\n",
|
||||
" self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
|
||||
" self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
|
||||
" self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" keys = self.W_key(x)\n",
|
||||
@@ -1397,12 +1397,12 @@
|
||||
"source": [
|
||||
"class CausalAttention(nn.Module):\n",
|
||||
"\n",
|
||||
" def __init__(self, d_in, d_out, block_size, dropout):\n",
|
||||
" def __init__(self, d_in, d_out, block_size, dropout, qkv_bias=False):\n",
|
||||
" super().__init__()\n",
|
||||
" self.d_out = d_out\n",
|
||||
" self.W_query = nn.Linear(d_in, d_out, bias=False)\n",
|
||||
" self.W_key = nn.Linear(d_in, d_out, bias=False)\n",
|
||||
" self.W_value = nn.Linear(d_in, d_out, bias=False)\n",
|
||||
" self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
|
||||
" self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
|
||||
" self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
|
||||
" self.dropout = nn.Dropout(dropout) # New\n",
|
||||
" self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1)) # New\n",
|
||||
"\n",
|
||||
@@ -1504,10 +1504,10 @@
|
||||
"source": [
|
||||
"class MultiHeadAttentionWrapper(nn.Module):\n",
|
||||
"\n",
|
||||
" def __init__(self, d_in, d_out, block_size, dropout, num_heads):\n",
|
||||
" def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n",
|
||||
" super().__init__()\n",
|
||||
" self.heads = nn.ModuleList(\n",
|
||||
" [CausalAttention(d_in, d_out, block_size, dropout) \n",
|
||||
" [CausalAttention(d_in, d_out, block_size, dropout, qkv_bias) \n",
|
||||
" for _ in range(num_heads)]\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
@@ -1623,7 +1623,7 @@
|
||||
],
|
||||
"source": [
|
||||
"class MultiHeadAttention(nn.Module):\n",
|
||||
" def __init__(self, d_in, d_out, block_size, dropout, num_heads):\n",
|
||||
" def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n",
|
||||
" super().__init__()\n",
|
||||
" assert d_out % num_heads == 0, \"d_out must be divisible by n_heads\"\n",
|
||||
"\n",
|
||||
@@ -1631,9 +1631,9 @@
|
||||
" self.num_heads = num_heads\n",
|
||||
" self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim\n",
|
||||
"\n",
|
||||
" self.W_query = nn.Linear(d_in, d_out, bias=False)\n",
|
||||
" self.W_key = nn.Linear(d_in, d_out, bias=False)\n",
|
||||
" self.W_value = nn.Linear(d_in, d_out, bias=False)\n",
|
||||
" self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
|
||||
" self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
|
||||
" self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
|
||||
" self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs\n",
|
||||
" self.dropout = nn.Dropout(dropout)\n",
|
||||
" self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))\n",
|
||||
|
||||
Reference in New Issue
Block a user