mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
committed by
GitHub
parent
e316cafd9f
commit
9d6da22ebb
@@ -91,11 +91,11 @@ def prepare_dataset():
|
||||
train_loader = DataLoader(
|
||||
dataset=train_ds,
|
||||
batch_size=2,
|
||||
shuffle=False, # NEW: False because of DistributedSampler below
|
||||
shuffle=False, # NEW: False because of DistributedSampler below
|
||||
pin_memory=True,
|
||||
drop_last=True,
|
||||
# NEW: chunk batches across GPUs without overlapping samples:
|
||||
sampler=DistributedSampler(train_ds) # NEW
|
||||
sampler=DistributedSampler(train_ds) # NEW
|
||||
)
|
||||
test_loader = DataLoader(
|
||||
dataset=test_ds,
|
||||
@@ -108,14 +108,14 @@ def prepare_dataset():
|
||||
# NEW: wrapper
|
||||
def main(rank, world_size, num_epochs):
|
||||
|
||||
ddp_setup(rank, world_size) # NEW: initialize process groups
|
||||
ddp_setup(rank, world_size) # NEW: initialize process groups
|
||||
|
||||
train_loader, test_loader = prepare_dataset()
|
||||
model = NeuralNetwork(num_inputs=2, num_outputs=2)
|
||||
model.to(rank)
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)
|
||||
|
||||
model = DDP(model, device_ids=[rank]) # NEW: wrap model with DDP
|
||||
model = DDP(model, device_ids=[rank]) # NEW: wrap model with DDP
|
||||
# the core model is now accessible as model.module
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
@@ -123,15 +123,15 @@ def main(rank, world_size, num_epochs):
|
||||
model.train()
|
||||
for features, labels in train_loader:
|
||||
|
||||
features, labels = features.to(rank), labels.to(rank) # New: use rank
|
||||
features, labels = features.to(rank), labels.to(rank) # New: use rank
|
||||
logits = model(features)
|
||||
loss = F.cross_entropy(logits, labels) # Loss function
|
||||
loss = F.cross_entropy(logits, labels) # Loss function
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
### LOGGING
|
||||
# LOGGING
|
||||
print(f"[GPU{rank}] Epoch: {epoch+1:03d}/{num_epochs:03d}"
|
||||
f" | Batchsize {labels.shape[0]:03d}"
|
||||
f" | Train/Val Loss: {loss:.2f}")
|
||||
@@ -142,7 +142,7 @@ def main(rank, world_size, num_epochs):
|
||||
test_acc = compute_accuracy(model, test_loader, device=rank)
|
||||
print(f"[GPU{rank}] Test accuracy", test_acc)
|
||||
|
||||
destroy_process_group() # NEW: cleanly exit distributed mode
|
||||
destroy_process_group() # NEW: cleanly exit distributed mode
|
||||
|
||||
|
||||
def compute_accuracy(model, dataloader, device):
|
||||
|
||||
Reference in New Issue
Block a user