update lora experiments
@@ -46,6 +46,21 @@ class LinearWithLoRA(torch.nn.Module):
         return self.linear(x) + self.lora(x)


+# This LoRA code is equivalent to LinearWithLoRA
+class LinearWithLoRAMerged(torch.nn.Module):
+    def __init__(self, linear, rank, alpha):
+        super().__init__()
+        self.linear = linear
+        self.lora = LoRALayer(
+            linear.in_features, linear.out_features, rank, alpha
+        )
+
+    def forward(self, x):
+        lora = self.lora.A @ self.lora.B
+        combined_weight = self.linear.weight + self.lora.alpha * lora.T
+        return torch.nn.functional.linear(x, combined_weight, self.linear.bias)
+
+
 class SpamDataset(Dataset):
     def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256, no_padding=False):
         self.data = pd.read_csv(csv_file)
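A quick numerical check of the equivalence claimed in the comment above; this is a minimal sketch that assumes the repo's LoRALayer (factor A of shape in_features x rank, factor B of shape rank x out_features, output scaled by alpha). Because B is zero-initialized there, it is re-randomized below so the comparison is not trivially true:

    import torch

    torch.manual_seed(123)
    base = torch.nn.Linear(16, 32)
    layer_plain = LinearWithLoRA(base, rank=4, alpha=8)
    layer_merged = LinearWithLoRAMerged(base, rank=4, alpha=8)

    with torch.no_grad():
        # Give both wrappers identical, nonzero LoRA factors
        layer_merged.lora.B.normal_()
        layer_plain.lora.A.copy_(layer_merged.lora.A)
        layer_plain.lora.B.copy_(layer_merged.lora.B)

    x = torch.randn(2, 16)
    # x @ (W + alpha * (A @ B).T).T + b  ==  (x @ W.T + b) + alpha * (x @ A @ B)
    assert torch.allclose(layer_plain(x), layer_merged(x), atol=1e-5)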
@@ -295,11 +310,14 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
     return train_losses, val_losses, train_accs, val_accs, examples_seen


-def replace_linear_with_lora(model, rank, alpha):
+def replace_linear_with_lora(model, rank, alpha, alternative=False):
     for name, module in model.named_children():
         if isinstance(module, torch.nn.Linear):
-            # Replace the Linear layer with LinearWithLoRA
-            setattr(model, name, LinearWithLoRA(module, rank, alpha))
+            # Replace the Linear layer with a LoRA wrapper
+            if alternative:
+                setattr(model, name, LinearWithLoRAMerged(module, rank, alpha))
+            else:
+                setattr(model, name, LinearWithLoRA(module, rank, alpha))
         else:
             # Recursively apply the same function to child modules
-            replace_linear_with_lora(module, rank, alpha)
+            replace_linear_with_lora(module, rank, alpha, alternative=alternative)
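A minimal usage sketch of the updated function on a toy model (module names as defined in this file; freezing the base weights first mirrors the usual LoRA setup, and the freshly created LoRA factors stay trainable by default):

    import torch

    model = torch.nn.Sequential(
        torch.nn.Linear(8, 16),
        torch.nn.ReLU(),
        torch.nn.Linear(16, 2),
    )
    for param in model.parameters():
        param.requires_grad = False  # freeze the pretrained weights

    replace_linear_with_lora(model, rank=4, alpha=8, alternative=True)

    # Only the LoRA factors remain trainable:
    # (8*4 + 4*16) + (16*4 + 4*2) = 168 parameters
    print(sum(p.numel() for p in model.parameters() if p.requires_grad))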
@@ -330,7 +348,7 @@ if __name__ == "__main__":
         type=str,
         default="last_block",
         help=(
-            "Which layers to train. Options: 'all', 'last_block', 'last_two_blocks', 'last_layer', 'lora'."
+            "Which layers to train. Options: 'all', 'last_block', 'last_two_blocks', 'last_layer', 'lora', 'lora_alternative'."
         )
     )
     parser.add_argument(
@@ -474,8 +492,12 @@ if __name__ == "__main__":
     elif args.trainable_layers == "all":
         for param in model.parameters():
             param.requires_grad = True
-    elif args.trainable_layers == "lora":
-        replace_linear_with_lora(model, rank=args.lora_rank, alpha=args.lora_alpha)
+    elif args.trainable_layers in ("lora", "lora_alternative"):
+        if args.trainable_layers == "lora_alternative":
+            alternative = True
+        else:
+            alternative = False
+        replace_linear_with_lora(model, rank=args.lora_rank, alpha=args.lora_alpha, alternative=alternative)
     else:
         raise ValueError("Invalid --trainable_layers argument.")

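For reference, the new dispatch reduced to a runnable sketch; only the one relevant flag is parsed here (the script's other arguments are omitted), and the four-line if/else collapses to a single boolean expression:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--trainable_layers", type=str, default="last_block")
    args = parser.parse_args(["--trainable_layers", "lora_alternative"])

    # Equivalent to the if/else added in the hunk above
    alternative = args.trainable_layers == "lora_alternative"
    print(alternative)  # prints: True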