add toggle for qkv_bias

This commit is contained in:
rasbt
2024-01-17 07:50:57 -06:00
parent 0074c98968
commit 92896d817c
4 changed files with 114 additions and 108 deletions

View File

@@ -971,12 +971,12 @@
"source": [
"class SelfAttention_v2(nn.Module):\n",
"\n",
" def __init__(self, d_in, d_out):\n",
" def __init__(self, d_in, d_out, qkv_bias=False):\n",
" super().__init__()\n",
" self.d_out = d_out\n",
" self.W_query = nn.Linear(d_in, d_out, bias=False)\n",
" self.W_key = nn.Linear(d_in, d_out, bias=False)\n",
" self.W_value = nn.Linear(d_in, d_out, bias=False)\n",
" self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
"\n",
" def forward(self, x):\n",
" keys = self.W_key(x)\n",
@@ -1397,12 +1397,12 @@
"source": [
"class CausalAttention(nn.Module):\n",
"\n",
" def __init__(self, d_in, d_out, block_size, dropout):\n",
" def __init__(self, d_in, d_out, block_size, dropout, qkv_bias=False):\n",
" super().__init__()\n",
" self.d_out = d_out\n",
" self.W_query = nn.Linear(d_in, d_out, bias=False)\n",
" self.W_key = nn.Linear(d_in, d_out, bias=False)\n",
" self.W_value = nn.Linear(d_in, d_out, bias=False)\n",
" self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.dropout = nn.Dropout(dropout) # New\n",
" self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1)) # New\n",
"\n",
@@ -1504,10 +1504,10 @@
"source": [
"class MultiHeadAttentionWrapper(nn.Module):\n",
"\n",
" def __init__(self, d_in, d_out, block_size, dropout, num_heads):\n",
" def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n",
" super().__init__()\n",
" self.heads = nn.ModuleList(\n",
" [CausalAttention(d_in, d_out, block_size, dropout) \n",
" [CausalAttention(d_in, d_out, block_size, dropout, qkv_bias) \n",
" for _ in range(num_heads)]\n",
" )\n",
"\n",
@@ -1623,7 +1623,7 @@
],
"source": [
"class MultiHeadAttention(nn.Module):\n",
" def __init__(self, d_in, d_out, block_size, dropout, num_heads):\n",
" def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n",
" super().__init__()\n",
" assert d_out % num_heads == 0, \"d_out must be divisible by n_heads\"\n",
"\n",
@@ -1631,9 +1631,9 @@
" self.num_heads = num_heads\n",
" self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim\n",
"\n",
" self.W_query = nn.Linear(d_in, d_out, bias=False)\n",
" self.W_key = nn.Linear(d_in, d_out, bias=False)\n",
" self.W_value = nn.Linear(d_in, d_out, bias=False)\n",
" self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs\n",
" self.dropout = nn.Dropout(dropout)\n",
" self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))\n",

View File

@@ -62,7 +62,7 @@
" return self.input_ids[idx], self.target_ids[idx]\n",
"\n",
"\n",
"def create_dataloader(txt, batch_size=4, max_length=256, stride=128):\n",
"def create_dataloader(txt, batch_size=4, max_length=256, stride=128, shuffle=True):\n",
" # Initialize the tokenizer\n",
" tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
"\n",
@@ -70,7 +70,7 @@
" dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)\n",
"\n",
" # Create dataloader\n",
" dataloader = DataLoader(dataset, batch_size=batch_size)\n",
" dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)\n",
"\n",
" return dataloader\n",
"\n",
@@ -155,12 +155,12 @@
"source": [
"class CausalSelfAttention(nn.Module):\n",
"\n",
" def __init__(self, d_in, d_out, block_size, dropout):\n",
" def __init__(self, d_in, d_out, block_size, dropout, qkv_bias=False):\n",
" super().__init__()\n",
" self.d_out = d_out\n",
" self.W_query = nn.Linear(d_in, d_out, bias=False)\n",
" self.W_key = nn.Linear(d_in, d_out, bias=False)\n",
" self.W_value = nn.Linear(d_in, d_out, bias=False)\n",
" self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.dropout = nn.Dropout(dropout) # New\n",
" self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1)) # New\n",
"\n",
@@ -181,10 +181,10 @@
"\n",
"\n",
"class MultiHeadAttentionWrapper(nn.Module):\n",
" def __init__(self, d_in, d_out, block_size, dropout, num_heads):\n",
" def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n",
" super().__init__()\n",
" self.heads = nn.ModuleList(\n",
" [CausalSelfAttention(d_in, d_out, block_size, dropout) \n",
" [CausalSelfAttention(d_in, d_out, block_size, dropout, qkv_bias) \n",
" for _ in range(num_heads)]\n",
" )\n",
" self.out_proj = nn.Linear(d_out*num_heads, d_out*num_heads)\n",
@@ -241,7 +241,7 @@
"outputs": [],
"source": [
"class MultiHeadAttention(nn.Module):\n",
" def __init__(self, d_in, d_out, block_size, dropout, num_heads):\n",
" def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n",
" super().__init__()\n",
" assert d_out % num_heads == 0, \"d_out must be divisible by n_heads\"\n",
"\n",
@@ -249,9 +249,9 @@
" self.num_heads = num_heads\n",
" self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim\n",
"\n",
" self.W_query = nn.Linear(d_in, d_out, bias=False)\n",
" self.W_key = nn.Linear(d_in, d_out, bias=False)\n",
" self.W_value = nn.Linear(d_in, d_out, bias=False)\n",
" self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs\n",
" self.dropout = nn.Dropout(dropout)\n",
" self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))\n",