mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Add LoRA scaling (#823)
This commit is contained in:
committed by
GitHub
parent
fc101b710e
commit
8f3e5b024d
File diff suppressed because one or more lines are too long
@@ -14,9 +14,11 @@ class LoRALayer(torch.nn.Module):
|
|||||||
torch.nn.init.kaiming_uniform_(self.A, a=math.sqrt(5)) # similar to standard weight initialization
|
torch.nn.init.kaiming_uniform_(self.A, a=math.sqrt(5)) # similar to standard weight initialization
|
||||||
self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))
|
self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))
|
||||||
self.alpha = alpha
|
self.alpha = alpha
|
||||||
|
self.rank = rank
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.alpha * (x @ self.A @ self.B)
|
|
||||||
|
x = (self.alpha / self.rank) * (x @ self.A @ self.B)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user