mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
n_heads × d_head -> d_head × d_head in DeltaNet (#903)
Clarified the explanation of the memory size calculation for `KV_cache_DeltaNet` and updated the quadratic term from `n_heads × d_head` to `d_head × d_head`.
This commit is contained in:
committed by GitHub
parent 488bef7e3f
commit bcc73f731d
@@ -331,7 +331,7 @@ For the simplified DeltaNet version implemented above, we have:

```
KV_cache_DeltaNet = batch_size × n_heads × d_head × d_head × bytes
```

-Note that the `KV_cache_DeltaNet` memory size doesn't have a context length (`n_tokens`) dependency. Also, we store only the memory state S instead of separate keys and values, hence `2 × bytes` becomes just `bytes`. However, note that we now have a quadratic `n_heads × d_head` term here. This comes from the state:
+Note that the `KV_cache_DeltaNet` memory size doesn't have a context length (`n_tokens`) dependency. Also, we store only the memory state S instead of separate keys and values, hence `2 × bytes` becomes just `bytes`. However, note that we now have a quadratic `d_head × d_head` term here. This comes from the state:

```python
S = x.new_zeros(b, self.num_heads, self.head_dim, self.head_dim)
```
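The corrected formula can be checked with a small sketch that compares the DeltaNet state size against a standard per-layer KV cache. The helper names and the concrete dimensions below are illustrative assumptions, not values from the commit:

```python
def kv_cache_standard(batch_size, n_heads, d_head, n_tokens, bytes_per_elem=2):
    # Keys and values are stored separately (factor of 2), and the
    # size grows linearly with the context length n_tokens.
    return batch_size * n_heads * d_head * n_tokens * 2 * bytes_per_elem


def kv_cache_deltanet(batch_size, n_heads, d_head, bytes_per_elem=2):
    # Only the state S of shape (b, n_heads, d_head, d_head) is kept:
    # no n_tokens dependency and no factor of 2, but the term is
    # quadratic in d_head (d_head × d_head, not n_heads × d_head).
    return batch_size * n_heads * d_head * d_head * bytes_per_elem


if __name__ == "__main__":
    b, h, d = 1, 16, 128  # assumed batch size, heads, head dim
    for n_tokens in (1_000, 100_000):
        std = kv_cache_standard(b, h, d, n_tokens)
        dn = kv_cache_deltanet(b, h, d)
        print(f"n_tokens={n_tokens:>7}: standard={std:,} B, deltanet={dn:,} B")
```

As the loop shows, the standard cache scales with `n_tokens` while the DeltaNet state stays constant, which is the trade-off the changed sentence describes.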