mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Update DDP-script.py
Fix for-loop
This commit is contained in:
committed by
GitHub
parent
c9dccb0c40
commit
c071ea73f9
@@ -117,25 +117,25 @@ def main(rank, world_size, num_epochs):
|
|||||||
|
|
||||||
model = DDP(model, device_ids=[rank]) # NEW: wrap model with DDP
|
model = DDP(model, device_ids=[rank]) # NEW: wrap model with DDP
|
||||||
# the core model is now accessible as model.module
|
# the core model is now accessible as model.module
|
||||||
|
|
||||||
for epoch in range(num_epochs):
|
for epoch in range(num_epochs):
|
||||||
|
|
||||||
model.train()
|
model.train()
|
||||||
for features, labels in enumerate(train_loader):
|
for features, labels in train_loader:
|
||||||
|
|
||||||
features, labels = features.to(rank), labels.to(rank) # New: use rank
|
features, labels = features.to(rank), labels.to(rank) # New: use rank
|
||||||
logits = model(features)
|
logits = model(features)
|
||||||
loss = F.cross_entropy(logits, labels) # Loss function
|
loss = F.cross_entropy(logits, labels) # Loss function
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
### LOGGING
|
### LOGGING
|
||||||
print(f"[GPU{rank}] Epoch: {epoch+1:03d}/{num_epochs:03d}"
|
print(f"[GPU{rank}] Epoch: {epoch+1:03d}/{num_epochs:03d}"
|
||||||
f" | Batchsize {labels.shape[0]:03d}"
|
f" | Batchsize {labels.shape[0]:03d}"
|
||||||
f" | Train/Val Loss: {loss:.2f}")
|
f" | Train/Val Loss: {loss:.2f}")
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
train_acc = compute_accuracy(model, train_loader, device=rank)
|
train_acc = compute_accuracy(model, train_loader, device=rank)
|
||||||
print(f"[GPU{rank}] Training accuracy", train_acc)
|
print(f"[GPU{rank}] Training accuracy", train_acc)
|
||||||
@@ -175,4 +175,3 @@ if __name__ == "__main__":
|
|||||||
world_size = torch.cuda.device_count()
|
world_size = torch.cuda.device_count()
|
||||||
mp.spawn(main, args=(world_size, num_epochs), nprocs=world_size)
|
mp.spawn(main, args=(world_size, num_epochs), nprocs=world_size)
|
||||||
# nprocs=world_size spawns one process per GPU
|
# nprocs=world_size spawns one process per GPU
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user