small readability updates

This commit is contained in:
rasbt
2024-01-14 11:58:42 -06:00
parent c79499572f
commit a7b4880179
3 changed files with 26 additions and 25 deletions

View File

@@ -61,7 +61,7 @@
" values = x @ self.W_value\n",
" \n",
" attn_scores = queries @ keys.T # omega\n",
" attn_weights = torch.softmax(attn_scores / self.d_out**0.5, dim=-1)\n",
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
"\n",
" context_vec = attn_weights @ values\n",
" return context_vec\n",
@@ -92,7 +92,7 @@
" values = self.W_value(x)\n",
" \n",
" attn_scores = queries @ keys.T\n",
" attn_weights = torch.softmax(attn_scores / self.d_out**0.5, dim=1)\n",
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)\n",
"\n",
" context_vec = attn_weights @ values\n",
" return context_vec\n",