Add LoRA scaling (#823)

This commit is contained in:
Sebastian Raschka
2025-09-14 11:57:55 -05:00
committed by GitHub
parent fc101b710e
commit 8f3e5b024d
2 changed files with 44 additions and 38 deletions

View File

@@ -14,9 +14,11 @@ class LoRALayer(torch.nn.Module):
torch.nn.init.kaiming_uniform_(self.A, a=math.sqrt(5)) # similar to standard weight initialization
self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))
self.alpha = alpha
self.rank = rank
def forward(self, x):
    """Apply the low-rank LoRA update to *x*: scale * (x @ A @ B).

    NOTE(review): this span is a rendered diff hunk, not plain source —
    both the removed and the added line are present below.
    """
    # Removed in this commit: scaled the low-rank product by alpha only.
    x = self.alpha * (x @ self.A @ self.B)
    # Added in this commit: standard LoRA scaling alpha / rank, so the
    # update magnitude stays comparable when the rank hyperparameter changes.
    x = (self.alpha / self.rank) * (x @ self.A @ self.B)
    return x