mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-04-10 12:33:44 +00:00
update pretrain progress bar tutorial
This commit is contained in:
@@ -3,9 +3,7 @@ import torch.nn as nn
|
||||
from tqdm import tqdm
|
||||
from torch.utils.data import TensorDataset, DataLoader
|
||||
|
||||
# Create a simple toy dataset example, normally this
|
||||
# would be doing custom class with __getitem__ etc,
|
||||
# which we have done in custom dataset tutorials
|
||||
# Create a simple toy dataset
|
||||
x = torch.randn((1000, 3, 224, 224))
|
||||
y = torch.randint(low=0, high=10, size=(1000, 1))
|
||||
ds = TensorDataset(x, y)
|
||||
@@ -13,12 +11,12 @@ loader = DataLoader(ds, batch_size=8)
|
||||
|
||||
|
||||
model = nn.Sequential(
|
||||
nn.Conv2d(3, 10, kernel_size=3, padding=1, stride=1),
|
||||
nn.Conv2d(in_channels=3, out_channels=10, kernel_size=3, padding=1, stride=1),
|
||||
nn.Flatten(),
|
||||
nn.Linear(10*224*224, 10),
|
||||
nn.Linear(10 * 224 * 224, 10),
|
||||
)
|
||||
|
||||
NUM_EPOCHS = 100
|
||||
NUM_EPOCHS = 10
|
||||
for epoch in range(NUM_EPOCHS):
|
||||
loop = tqdm(loader)
|
||||
for idx, (x, y) in enumerate(loop):
|
||||
@@ -35,7 +33,3 @@ for epoch in range(NUM_EPOCHS):
|
||||
loop.set_postfix(loss=torch.rand(1).item(), acc=torch.rand(1).item())
|
||||
|
||||
# There you go. Hope it was useful :)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user