n_heads × d_head -> d_head × d_head in DeltaNet (#903)

Clarified the explanation of the memory size calculation for `KV_cache_DeltaNet` and updated the quadratic term from `n_heads × d_head` to `d_head × d_head`.
This commit is contained in:
Sebastian Raschka
2025-11-05 18:28:37 -06:00
committed by GitHub
parent 488bef7e3f
commit bcc73f731d

View File

@@ -331,7 +331,7 @@ For the simplified DeltaNet version implemented above, we have:
KV_cache_DeltaNet = batch_size × n_heads × d_head × d_head × bytes
```
Note that the `KV_cache_DeltaNet` memory size doesn't have a context length (`n_tokens`) dependency. Also, we have only the memory state S that we store instead of separate keys and values, hence `2 × bytes` becomes just `bytes`. However, note that we now have a quadratic `n_heads × d_head` in here. This comes from the state :
Note that the `KV_cache_DeltaNet` memory size doesn't have a context length (`n_tokens`) dependency. Also, we have only the memory state S that we store instead of separate keys and values, hence `2 × bytes` becomes just `bytes`. However, note that we now have a quadratic `d_head × d_head` in here. This comes from the state :
```
S = x.new_zeros(b, self.num_heads, self.head_dim, self.head_dim)