Mirror of https://github.com/rasbt/LLMs-from-scratch.git, synced 2026-04-10 12:33:42 +00:00
Exercise solution for LoRA instruction finetuning (#240)
Commit 72f46297d9 (parent ec5baa1f33), committed via GitHub.
@@ -8,6 +8,7 @@
 from functools import partial
 from importlib.metadata import version
 import json
+import math
 import os
 import re
 import time
@@ -107,6 +108,41 @@ class InstructionDatasetPhi(Dataset):
         return len(self.data)
 
 
+class LinearWithLoRA(torch.nn.Module):
+    def __init__(self, linear, rank, alpha):
+        super().__init__()
+        self.linear = linear
+        self.lora = LoRALayer(
+            linear.in_features, linear.out_features, rank, alpha
+        )
+
+    def forward(self, x):
+        return self.linear(x) + self.lora(x)
+
+
+class LoRALayer(torch.nn.Module):
+    def __init__(self, in_dim, out_dim, rank, alpha):
+        super().__init__()
+        self.A = torch.nn.Parameter(torch.empty(in_dim, rank))
+        torch.nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))  # similar to standard weight initialization
+        self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))
+        self.alpha = alpha
+
+    def forward(self, x):
+        x = self.alpha * (x @ self.A @ self.B)
+        return x
+
+
+def replace_linear_with_lora(model, rank, alpha):
+    for name, module in model.named_children():
+        if isinstance(module, torch.nn.Linear):
+            # Replace the Linear layer with LinearWithLoRA
+            setattr(model, name, LinearWithLoRA(module, rank, alpha))
+        else:
+            # Recursively apply the same function to child modules
+            replace_linear_with_lora(module, rank, alpha)
+
+
 def custom_collate_fn(
     batch,
     pad_token_id=50256,
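The hunk above adds the two LoRA building blocks and a helper to wire them in. LoRALayer holds the low-rank factors: A (in_dim x rank) gets a Kaiming-uniform init while B (rank x out_dim) starts at zero, so alpha * (x @ A @ B) is exactly zero before training and a wrapped layer initially behaves like the original. LinearWithLoRA computes linear(x) + lora(x), and replace_linear_with_lora walks the module tree and swaps every torch.nn.Linear for the wrapper. A minimal sketch (not part of the commit; the toy layer sizes, rank, and alpha are illustrative assumptions) showing the helper in action, assuming the three definitions above are in scope:

# Sketch only (not part of the commit): exercising replace_linear_with_lora
# on a toy network. Assumes LinearWithLoRA, LoRALayer, and
# replace_linear_with_lora from the diff above are in scope.
import torch

# Toy two-layer network standing in for the GPT model
model = torch.nn.Sequential(
    torch.nn.Linear(16, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 4),
)

x = torch.randn(2, 16)
out_before = model(x)

replace_linear_with_lora(model, rank=4, alpha=4)
out_after = model(x)

# B starts at zero, so the LoRA branch contributes nothing yet and the
# wrapped model reproduces the original outputs exactly
assert torch.allclose(out_before, out_after)

# Each wrapped Linear(in, out) adds rank*(in + out) parameters:
# 4*(16+32) + 4*(32+4) = 336
lora_params = sum(p.numel() for n, p in model.named_parameters() if "lora" in n)
print(f"{lora_params:,}")  # 336

Each wrapped Linear(in, out) therefore adds only rank * (in + out) trainable parameters, which is where the parameter savings printed in the later hunk come from.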
@@ -256,7 +292,7 @@ def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses, plot_name):
     # plt.show()
 
 
-def main(mask_instructions=False, alpaca52k=False, phi3_prompt=False):
+def main(mask_instructions=False, alpaca52k=False, phi3_prompt=False, lora=False):
     #######################################
     # Print package versions
     #######################################
@@ -379,6 +415,21 @@ def main(mask_instructions=False, alpaca52k=False, phi3_prompt=False):
     print("Loaded model:", CHOOSE_MODEL)
     print(50*"-")
 
+    if lora:
+        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+        print(f"Total trainable parameters before: {total_params:,}")
+
+        for param in model.parameters():
+            param.requires_grad = False
+
+        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+        print(f"Total trainable parameters after: {total_params:,}")
+        replace_linear_with_lora(model, rank=16, alpha=16)
+
+        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+        print(f"Total trainable LoRA parameters: {total_params:,}")
+        model.to(device)
+
     #######################################
     # Finetuning the model
     #######################################
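Note the ordering in this hunk: every pretrained parameter is frozen first, and only then are the Linear layers wrapped, so the freshly created A and B matrices (which default to requires_grad=True) end up as the only trainable parameters. model.to(device) is repeated because the new LoRA parameters are created on the CPU. A minimal sketch (not from the commit) of the same freeze-then-inject pattern, again assuming the LoRA utilities above; the toy sizes and optimizer hyperparameters are illustrative assumptions:

# Sketch only: freeze-then-inject ordering on a toy model.
import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))

# 1) Freeze everything first ...
for param in model.parameters():
    param.requires_grad = False

# 2) ... then inject LoRA; the new A and B matrices are created with
#    requires_grad=True, so they are the only trainable parameters left
replace_linear_with_lora(model, rank=2, alpha=2)

trainable = [p for p in model.parameters() if p.requires_grad]
print(sum(p.numel() for p in trainable))  # 2*(8+8) + 2*(8+2) = 52

# Training can then be restricted to the LoRA parameters alone
optimizer = torch.optim.AdamW(trainable, lr=5e-5, weight_decay=0.1)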
@@ -418,7 +469,9 @@ def main(mask_instructions=False, alpaca52k=False, phi3_prompt=False):
         plot_name = plot_name.replace(".pdf", "-alpaca52k.pdf")
     if phi3_prompt:
         plot_name = plot_name.replace(".pdf", "-phi3-prompt.pdf")
-    if not any([mask_instructions, alpaca52k, phi3_prompt]):
+    if lora:
+        plot_name = plot_name.replace(".pdf", "-lora.pdf")
+    if not any([mask_instructions, alpaca52k, phi3_prompt, lora]):
         plot_name = plot_name.replace(".pdf", "-baseline.pdf")
 
     plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses, plot_name)
@@ -460,7 +513,10 @@ def main(mask_instructions=False, alpaca52k=False, phi3_prompt=False):
     if phi3_prompt:
         test_data_path = test_data_path.replace(".json", "-phi3-prompt.json")
         file_name = file_name.replace(".pth", "-phi3-prompt.pth")
-    if not any([mask_instructions, alpaca52k, phi3_prompt]):
+    if lora:
+        test_data_path = test_data_path.replace(".json", "-lora.json")
+        file_name = file_name.replace(".pth", "-lora.pth")
+    if not any([mask_instructions, alpaca52k, phi3_prompt, lora]):
         test_data_path = test_data_path.replace(".json", "-baseline.json")
         file_name = file_name.replace(".pth", "-baseline.pth")
 
@@ -479,7 +535,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Instruction finetune a GPT model"
     )
-    options = {"baseline", "mask_instructions", "alpaca_52k", "phi3_prompt"}
+    options = {"baseline", "mask_instructions", "alpaca_52k", "phi3_prompt", "lora"}
     parser.add_argument(
         "--exercise_solution",
         type=str,
@@ -498,5 +554,7 @@ if __name__ == "__main__":
         main(alpaca52k=True)
     elif args.exercise_solution == "phi3_prompt":
         main(phi3_prompt=True)
+    elif args.exercise_solution == "lora":
+        main(lora=True)
     else:
         raise ValueError(f"{args.exercise_solution} is not a valid --args.exercise_solution option. Options: {options}")
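With the dispatch branch above in place, the LoRA variant is selectable like the other exercise solutions. Assuming the script is the repository's exercise_experiments.py (the file name is inferred from the surrounding repository, not shown in this diff), the new option would be invoked as

    python exercise_experiments.py --exercise_solution lora

which calls main(lora=True): the pretrained weights are frozen, rank-16 LoRA adapters are injected, only those adapters are finetuned, and the outputs are saved under the -lora suffixes added above.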