add shape information for clarity

This commit is contained in:
rasbt
2024-02-08 20:16:54 -06:00
parent 3a5fc79b38
commit 5d1d8ce511
3 changed files with 12 additions and 11 deletions

View File

@@ -168,7 +168,7 @@ class TransformerBlock(nn.Module):
# Shortcut connection for attention block
shortcut = x
x = self.norm1(x)
x = self.att(x)
x = self.att(x) # Shape [batch_size, num_tokens, emb_size]
x = self.drop_resid(x)
x = x + shortcut # Add the original input back
@@ -200,7 +200,7 @@ class GPTModel(nn.Module):
batch_size, seq_len = in_idx.shape
tok_embeds = self.tok_emb(in_idx)
pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
x = tok_embeds + pos_embeds
x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size]
x = self.trf_blocks(x)
x = self.final_norm(x)
logits = self.out_head(x)