update lora experiments

rasbt
2024-08-16 08:57:46 -05:00
parent b2858a91c5
commit e7cb2ebd8d
2 changed files with 51 additions and 25 deletions


@@ -46,6 +46,21 @@ class LinearWithLoRA(torch.nn.Module):
         return self.linear(x) + self.lora(x)
 
 
+# This LoRA code is equivalent to LinearWithLoRA
+class LinearWithLoRAMerged(torch.nn.Module):
+    def __init__(self, linear, rank, alpha):
+        super().__init__()
+        self.linear = linear
+        self.lora = LoRALayer(
+            linear.in_features, linear.out_features, rank, alpha
+        )
+
+    def forward(self, x):
+        lora = self.lora.A @ self.lora.B
+        combined_weight = self.linear.weight + self.lora.alpha*lora.T
+        return torch.nn.functional.linear(x, combined_weight, self.linear.bias)
+
+
 class SpamDataset(Dataset):
     def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256, no_padding=False):
         self.data = pd.read_csv(csv_file)
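The comment above claims the merged forward pass is equivalent to LinearWithLoRA; a minimal numerical check is sketched below. The LoRALayer definition here is an assumption modeled on the layer referenced in this file (A of shape in_features x rank, B of shape rank x out_features, update scaled by alpha), except that B is initialized randomly rather than with the usual zeros so the comparison is non-trivial.

import torch

class LoRALayer(torch.nn.Module):
    # Assumed definition; B is random here (not zeros) so both layers
    # produce a non-trivial LoRA update to compare.
    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        self.A = torch.nn.Parameter(torch.randn(in_dim, rank))
        self.B = torch.nn.Parameter(torch.randn(rank, out_dim))
        self.alpha = alpha

    def forward(self, x):
        return self.alpha * (x @ self.A @ self.B)

class LinearWithLoRA(torch.nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)

    def forward(self, x):
        # W x + alpha * (x A B): LoRA update applied as a separate branch
        return self.linear(x) + self.lora(x)

class LinearWithLoRAMerged(torch.nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)

    def forward(self, x):
        # (W + alpha * (A B)^T) x: same result, weights merged first
        lora = self.lora.A @ self.lora.B
        combined_weight = self.linear.weight + self.lora.alpha * lora.T
        return torch.nn.functional.linear(x, combined_weight, self.linear.bias)

torch.manual_seed(123)
linear = torch.nn.Linear(4, 3)
layer1 = LinearWithLoRA(linear, rank=2, alpha=4)
layer2 = LinearWithLoRAMerged(linear, rank=2, alpha=4)
layer2.lora.load_state_dict(layer1.lora.state_dict())  # share the same A and B

x = torch.randn(5, 4)
print(torch.allclose(layer1(x), layer2(x), atol=1e-5))  # True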
@@ -295,11 +310,14 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
     return train_losses, val_losses, train_accs, val_accs, examples_seen
 
 
-def replace_linear_with_lora(model, rank, alpha):
+def replace_linear_with_lora(model, rank, alpha, alternative=False):
     for name, module in model.named_children():
         if isinstance(module, torch.nn.Linear):
             # Replace the Linear layer with LinearWithLoRA
-            setattr(model, name, LinearWithLoRA(module, rank, alpha))
+            if alternative:
+                setattr(model, name, LinearWithLoRAMerged(module, rank, alpha))
+            else:
+                setattr(model, name, LinearWithLoRA(module, rank, alpha))
         else:
             # Recursively apply the same function to child modules,
             # forwarding the alternative flag to nested Linear layers
-            replace_linear_with_lora(module, rank, alpha)
+            replace_linear_with_lora(module, rank, alpha, alternative)
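A hedged usage sketch, assuming LoRALayer, LinearWithLoRAMerged, and replace_linear_with_lora from this file are in scope (the toy model below is hypothetical, not the GPT model this script fine-tunes): freeze the base weights first, then let the recursion swap every torch.nn.Linear, including nested ones, for a LoRA wrapper.

import torch

model = torch.nn.Sequential(
    torch.nn.Linear(8, 16),
    torch.nn.ReLU(),
    torch.nn.Sequential(torch.nn.Linear(16, 2)),  # nested module, reached via recursion
)
for param in model.parameters():
    param.requires_grad = False  # freeze the pretrained weights

replace_linear_with_lora(model, rank=4, alpha=8, alternative=True)

# Only the newly added LoRA matrices A and B remain trainable
print([name for name, p in model.named_parameters() if p.requires_grad])
# e.g. ['0.lora.A', '0.lora.B', '2.0.lora.A', '2.0.lora.B']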
@@ -330,7 +348,7 @@ if __name__ == "__main__":
         type=str,
         default="last_block",
         help=(
-            "Which layers to train. Options: 'all', 'last_block', 'last_two_blocks', 'last_layer', 'lora'."
+            "Which layers to train. Options: 'all', 'last_block', 'last_two_blocks', 'last_layer', 'lora', 'lora_alternative'."
         )
     )
     parser.add_argument(
@@ -474,8 +492,12 @@ if __name__ == "__main__":
     elif args.trainable_layers == "all":
         for param in model.parameters():
             param.requires_grad = True
-    elif args.trainable_layers == "lora":
-        replace_linear_with_lora(model, rank=args.lora_rank, alpha=args.lora_alpha)
+    elif args.trainable_layers in ("lora", "lora_alternative"):
+        if args.trainable_layers == "lora_alternative":
+            alternative = True
+        else:
+            alternative = False
+        replace_linear_with_lora(model, rank=args.lora_rank, alpha=args.lora_alpha, alternative=alternative)
     else:
         raise ValueError("Invalid --trainable_layers argument.")
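With these changes, the merged variant can be selected from the command line. A hypothetical invocation (the script's filename is not shown in this diff; the --lora_rank and --lora_alpha flags are inferred from the args.lora_rank and args.lora_alpha references above, and the rank/alpha values are placeholders):

python additional_experiments.py --trainable_layers lora_alternative --lora_rank 8 --lora_alpha 8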