mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
sklearn baseline and roberta-large update
This commit is contained in:
@@ -235,7 +235,14 @@ if __name__ == "__main__":
|
||||
"Number of epochs."
|
||||
)
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=5e-5,
|
||||
help=(
|
||||
"Learning rate."
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.trainable_token == "first":
|
||||
@@ -346,7 +353,7 @@ if __name__ == "__main__":
|
||||
|
||||
start_time = time.time()
|
||||
torch.manual_seed(123)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=0.1)
|
||||
|
||||
train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
|
||||
model, train_loader, val_loader, optimizer, device,
|
||||
|
||||
Reference in New Issue
Block a user