Exercise solution for LoRA instruction finetuning (#240)

This commit is contained in:
Sebastian Raschka
2024-06-22 10:05:32 -05:00
committed by GitHub
parent ec5baa1f33
commit 72f46297d9
2 changed files with 188 additions and 4 deletions

View File

@@ -8,6 +8,7 @@
from functools import partial
from importlib.metadata import version
import json
import math
import os
import re
import time
@@ -107,6 +108,41 @@ class InstructionDatasetPhi(Dataset):
return len(self.data)
class LinearWithLoRA(torch.nn.Module):
    """Wrap an existing ``nn.Linear`` and add a trainable low-rank (LoRA) path.

    The wrapped layer is kept as-is (typically with its weights frozen by the
    caller); only the attached :class:`LoRALayer` contributes new trainable
    parameters.
    """

    def __init__(self, linear, rank, alpha):
        super().__init__()
        # Keep the original projection untouched.
        self.linear = linear
        # Size the adapter to match the wrapped layer's in/out dimensions.
        self.lora = LoRALayer(
            linear.in_features, linear.out_features, rank, alpha
        )

    def forward(self, x):
        # Frozen output plus the low-rank correction.
        frozen_out = self.linear(x)
        lora_out = self.lora(x)
        return frozen_out + lora_out
class LoRALayer(torch.nn.Module):
    """Trainable low-rank update ``alpha * (x @ A @ B)`` (LoRA).

    ``A`` projects down to ``rank`` dimensions and ``B`` projects back up.
    ``B`` starts at zero, so the adapter initially contributes nothing and
    training begins from the wrapped model's original behavior.
    """

    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        # Kaiming-uniform init for A mirrors nn.Linear's default weight init.
        down_proj = torch.empty(in_dim, rank)
        torch.nn.init.kaiming_uniform_(down_proj, a=math.sqrt(5))
        self.A = torch.nn.Parameter(down_proj)
        # Zero-init B so the initial LoRA contribution is exactly zero.
        self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha = alpha

    def forward(self, x):
        # Project down through A, back up through B, then scale.
        return self.alpha * (x @ self.A @ self.B)
def replace_linear_with_lora(model, rank, alpha):
    """Recursively replace every ``nn.Linear`` in *model* with ``LinearWithLoRA``.

    Mutates *model* in place. Replacing an existing child via ``setattr``
    does not resize the module registry, so iterating ``named_children``
    while swapping is safe.
    """
    for child_name, child in model.named_children():
        if isinstance(child, torch.nn.Linear):
            # Swap in the wrapper; original weights are preserved inside it.
            setattr(model, child_name, LinearWithLoRA(child, rank, alpha))
        else:
            # Descend into containers (Sequential, transformer blocks, ...).
            replace_linear_with_lora(child, rank, alpha)
def custom_collate_fn(
batch,
pad_token_id=50256,
@@ -256,7 +292,7 @@ def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses, plot_name):
# plt.show()
def main(mask_instructions=False, alpaca52k=False, phi3_prompt=False):
def main(mask_instructions=False, alpaca52k=False, phi3_prompt=False, lora=False):
#######################################
# Print package versions
#######################################
@@ -379,6 +415,21 @@ def main(mask_instructions=False, alpaca52k=False, phi3_prompt=False):
print("Loaded model:", CHOOSE_MODEL)
print(50*"-")
if lora:
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters before: {total_params:,}")
for param in model.parameters():
param.requires_grad = False
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters after: {total_params:,}")
replace_linear_with_lora(model, rank=16, alpha=16)
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable LoRA parameters: {total_params:,}")
model.to(device)
#######################################
# Finetuning the model
#######################################
@@ -418,7 +469,9 @@ def main(mask_instructions=False, alpaca52k=False, phi3_prompt=False):
plot_name = plot_name.replace(".pdf", "-alpaca52k.pdf")
if phi3_prompt:
plot_name = plot_name.replace(".pdf", "-phi3-prompt.pdf")
if not any([mask_instructions, alpaca52k, phi3_prompt]):
if lora:
plot_name = plot_name.replace(".pdf", "-lora.pdf")
if not any([mask_instructions, alpaca52k, phi3_prompt, lora]):
plot_name = plot_name.replace(".pdf", "-baseline.pdf")
plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses, plot_name)
@@ -460,7 +513,10 @@ def main(mask_instructions=False, alpaca52k=False, phi3_prompt=False):
if phi3_prompt:
test_data_path = test_data_path.replace(".json", "-phi3-prompt.json")
file_name = file_name.replace(".pth", "-phi3-prompt.pth")
if not any([mask_instructions, alpaca52k, phi3_prompt]):
if lora:
test_data_path = test_data_path.replace(".json", "-lora.json")
file_name = file_name.replace(".pth", "-lora.pth")
if not any([mask_instructions, alpaca52k, phi3_prompt, lora]):
test_data_path = test_data_path.replace(".json", "-baseline.json")
file_name = file_name.replace(".pth", "-baseline.pth")
@@ -479,7 +535,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Instruction finetune a GPT model"
)
options = {"baseline", "mask_instructions", "alpaca_52k", "phi3_prompt"}
options = {"baseline", "mask_instructions", "alpaca_52k", "phi3_prompt", "lora"}
parser.add_argument(
"--exercise_solution",
type=str,
@@ -498,5 +554,7 @@ if __name__ == "__main__":
main(alpaca52k=True)
elif args.exercise_solution == "phi3_prompt":
main(phi3_prompt=True)
elif args.exercise_solution == "lora":
main(lora=True)
else:
raise ValueError(f"{args.exercise_solution} is not a valid --args.exercise_solution option. Options: {options}")