add toggle for qkv_bias

This commit is contained in:
rasbt
2024-01-17 07:50:57 -06:00
parent 0074c98968
commit 92896d817c
4 changed files with 114 additions and 108 deletions

View File

@@ -971,12 +971,12 @@
 "source": [
 "class SelfAttention_v2(nn.Module):\n",
 "\n",
-" def __init__(self, d_in, d_out):\n",
+" def __init__(self, d_in, d_out, qkv_bias=False):\n",
 " super().__init__()\n",
 " self.d_out = d_out\n",
-" self.W_query = nn.Linear(d_in, d_out, bias=False)\n",
-" self.W_key = nn.Linear(d_in, d_out, bias=False)\n",
-" self.W_value = nn.Linear(d_in, d_out, bias=False)\n",
+" self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
+" self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
+" self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
 "\n",
 " def forward(self, x):\n",
 " keys = self.W_key(x)\n",
@@ -1397,12 +1397,12 @@
 "source": [
 "class CausalAttention(nn.Module):\n",
 "\n",
-" def __init__(self, d_in, d_out, block_size, dropout):\n",
+" def __init__(self, d_in, d_out, block_size, dropout, qkv_bias=False):\n",
 " super().__init__()\n",
 " self.d_out = d_out\n",
-" self.W_query = nn.Linear(d_in, d_out, bias=False)\n",
-" self.W_key = nn.Linear(d_in, d_out, bias=False)\n",
-" self.W_value = nn.Linear(d_in, d_out, bias=False)\n",
+" self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
+" self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
+" self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
 " self.dropout = nn.Dropout(dropout) # New\n",
 " self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1)) # New\n",
 "\n",
@@ -1504,10 +1504,10 @@
 "source": [
 "class MultiHeadAttentionWrapper(nn.Module):\n",
 "\n",
-" def __init__(self, d_in, d_out, block_size, dropout, num_heads):\n",
+" def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n",
 " super().__init__()\n",
 " self.heads = nn.ModuleList(\n",
-" [CausalAttention(d_in, d_out, block_size, dropout) \n",
+" [CausalAttention(d_in, d_out, block_size, dropout, qkv_bias) \n",
 " for _ in range(num_heads)]\n",
 " )\n",
 "\n",
@@ -1623,7 +1623,7 @@
 ],
 "source": [
 "class MultiHeadAttention(nn.Module):\n",
-" def __init__(self, d_in, d_out, block_size, dropout, num_heads):\n",
+" def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n",
 " super().__init__()\n",
 " assert d_out % num_heads == 0, \"d_out must be divisible by n_heads\"\n",
 "\n",
@@ -1631,9 +1631,9 @@
 " self.num_heads = num_heads\n",
 " self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim\n",
 "\n",
-" self.W_query = nn.Linear(d_in, d_out, bias=False)\n",
-" self.W_key = nn.Linear(d_in, d_out, bias=False)\n",
-" self.W_value = nn.Linear(d_in, d_out, bias=False)\n",
+" self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
+" self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
+" self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
 " self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs\n",
 " self.dropout = nn.Dropout(dropout)\n",
 " self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))\n",