From 192bdc3501e5aeb012f29070b21d7d45027ecab1 Mon Sep 17 00:00:00 2001 From: Sebastian Raschka Date: Mon, 5 Aug 2024 18:27:20 -0500 Subject: [PATCH] improve gradient accumulation (#300) --- ch06/02_bonus_additional-experiments/additional-experiments.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ch06/02_bonus_additional-experiments/additional-experiments.py b/ch06/02_bonus_additional-experiments/additional-experiments.py index 6246c61..bdb94b3 100644 --- a/ch06/02_bonus_additional-experiments/additional-experiments.py +++ b/ch06/02_bonus_additional-experiments/additional-experiments.py @@ -259,7 +259,8 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device, loss.backward() # Calculate loss gradients # Use gradient accumulation if accumulation_steps > 1 - if batch_idx % accumulation_steps == 0: + is_update_step = ((batch_idx + 1) % accumulation_steps == 0) or ((batch_idx + 1) == len(train_loader)) + if is_update_step: optimizer.step() # Update model weights using loss gradients optimizer.zero_grad() # Reset loss gradients from previous batch iteration