Improve KV cache code for torch.compile (#705)

* Improve KV cache code for torch.compile

* cleanup

* cleanup
This commit is contained in:
Sebastian Raschka
2025-06-23 18:08:49 -05:00
committed by GitHub
parent 6522be94be
commit 81eda38d3b
8 changed files with 593 additions and 315 deletions

View File

@@ -27,7 +27,7 @@ class MultiHeadAttention(nn.Module):
self.dropout = nn.Dropout(dropout)
self.register_buffer(
"mask",
torch.triu(torch.ones(context_length, context_length),diagonal=1),
torch.triu(torch.ones(context_length, context_length), diagonal=1),
persistent=False
)