Initial commit

Aladdin Persson
2021-01-30 21:49:15 +01:00
commit 65b8c80495
432 changed files with 1290844 additions and 0 deletions


@@ -0,0 +1,45 @@
# Exploring the MNIST dataset with PyTorch
The goal of this small project of mine is to learn different models and see what kind of test accuracies we can get on the MNIST dataset. I implemented some popular models (LeNet, VGG, Inception/GoogLeNet, ResNet) and will likely try more in the future as I learn new network architectures. I used an exponential learning rate decay and data augmentation. In the beginning I was just copying every data augmentation other people were using, but I learned that RandomHorizontalFlip might not be so useful when learning to recognize digits (heh). I also used a pretty standard weight decay (L2 lambda) of 5e-4. My training procedure was: first split off a validation set of about 10000 examples and make sure the model reached high validation accuracy with the current hyperparameters, i.e. that it wasn't just overfitting the training set. After that I trained on all 60000 training examples, and once training accuracy reached about ~99.9% I evaluated once on the test set.
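The exponential decay and augmentation mentioned above are not visible in the training script included in this commit, so the snippet below is only a minimal sketch of how such a setup could look in PyTorch. The `gamma=0.95` decay factor, the `RandomAffine` augmentation parameters, and the tiny stand-in model are illustrative assumptions, not the exact settings behind the numbers in the table.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import ExponentialLR

# digit-friendly augmentation: small rotations/translations, but no RandomHorizontalFlip,
# since a mirrored digit is usually no longer the same digit
transform_train = transforms.Compose(
    [
        transforms.Pad(2),  # 28x28 -> 32x32, matching the --init-padding default in train.py
        transforms.RandomAffine(degrees=10, translate=(0.1, 0.1)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)
# transform_train would be passed to datasets.MNIST(..., transform=transform_train)

model = nn.Sequential(nn.Flatten(), nn.Linear(1 * 32 * 32, 10))  # stand-in for LeNet/VGG/etc.
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)
scheduler = ExponentialLR(optimizer, gamma=0.95)  # assumed decay factor, stepped once per epoch

for epoch in range(3):  # a few dummy epochs just to show where scheduler.step() goes
    x, y = torch.randn(8, 1, 32, 32), torch.randint(0, 10, (8,))  # dummy batch
    loss = F.cross_entropy(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()
    print(epoch, optimizer.param_groups[0]["lr"])
```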
## Accuracy
| Model | Number of epochs | Training set acc. | Test set acc. |
| ----------------- | ----------- | ----------------- | ----------- |
| [LeNet](http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf) | 150 | 99.69% | 99.12% |
| [VGG13](https://arxiv.org/abs/1409.1556) | 100 | 99.95% | 99.67% |
| [VGG16](https://arxiv.org/abs/1409.1556) | 100 | 99.92% | 99.68% |
| [GoogLeNet](https://arxiv.org/abs/1409.4842) | 100 | 99.90% | 99.71% |
| [ResNet101](https://arxiv.org/abs/1512.03385) | 100 | 99.90% | 99.68% |
TODO: MobileNet, ResNeXt, SqueezeNet, ...
### Comments and things to improve
I believe LeNet has more potential, since it isn't really fitting the training set that well yet and needs more epochs. In the original paper, LeCun et al. (1998) reported about 99.1% test accuracy, which is similar to my result, but we also need to remember the limitations they were working with back then. I think training it a bit longer, until it reaches ~99.8-99.9% training accuracy, could push it up to perhaps 99.2-99.3% test accuracy if we're lucky. The other models have performed quite well so far and are, at least from my understanding, close to the current state of the art. If you really wanted to maximize accuracy you would train an ensemble of models and average their predictions, but I haven't done that here since I don't find it that interesting; this was mostly about learning different network architectures and checking that they work as intended. If you find anything I can improve, or any mistakes, please tell me and I'll do my best to fix it!
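For reference, averaging ensemble predictions could look roughly like the sketch below. This is not part of the repository; `ensemble_predict` and the stand-in models are hypothetical placeholders for networks you would actually train and load from checkpoints.

```python
import torch
import torch.nn.functional as F


def ensemble_predict(models, x):
    """Average the softmax outputs of several trained models and return class predictions."""
    for m in models:
        m.eval()
    with torch.no_grad():
        probs = torch.stack([F.softmax(m(x), dim=1) for m in models]).mean(dim=0)
    return probs.argmax(dim=1)


if __name__ == "__main__":
    # stand-ins for trained networks such as LeNet(1), VGG("VGG16", 1), GoogLeNet(1)
    dummy_models = [torch.nn.Linear(32 * 32, 10) for _ in range(3)]
    x = torch.randn(4, 32 * 32)  # a real run would pass normalized 1x32x32 image batches
    print(ensemble_predict(dummy_models, x))
```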
### How to run
```bash
usage: train.py [-h] [--resume PATH] [--lr LR] [--weight-decay R]
[--momentum M] [--epochs N] [--batch-size N]
[--log-interval N] [--seed S] [--number-workers S]
[--init-padding S] [--create-validationset] [--save-model]
PyTorch MNIST
optional arguments:
--resume PATH Saved model. (ex: PATH = checkpoint/mnist_LeNet.pth.tar)
--batch-size N (ex: --batch-size 64), default is 128.
--epochs N (ex: --epochs 10) default is 100.
--lr LR learning rate (ex: --lr 0.01), default is 0.001.
--momentum M SGD w momentum (ex: --momentum 0.5), default is 0.9.
--seed S random seed (ex: --seed 3), default is 1.
--log-interval N print accuracy every N mini-batches (ex: --log-interval 50), default is 240.
--init-padding S Initial padding on images (ex: --init-padding 5), default is 2 to make 28x28 into 32x32.
--create-validationset to create a validation set
--save-model to save weights
--weight-decay R What weight decay you want (ex: --weight-decay 1e-4), default 1e-5.
--number-workers S How many num workers you want in PyTorch (ex --number-workers 2), default is 0.
Example of a run is:
python train.py --save-model --resume checkpoint/mnist_LeNet.pth.tar --weight-decay 1e-5 --number-workers 2
```


@@ -0,0 +1,109 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class Inception(nn.Module):
def __init__(
self, in_channels, out1x1, out3x3reduced, out3x3, out5x5reduced, out5x5, outpool
):
super().__init__()
self.branch_1 = BasicConv2d(in_channels, out1x1, kernel_size=1, stride=1)
self.branch_2 = nn.Sequential(
BasicConv2d(in_channels, out3x3reduced, kernel_size=1),
BasicConv2d(out3x3reduced, out3x3, kernel_size=3, padding=1),
)
        # The original GoogLeNet paper uses a 5x5 conv here, but Inception-v2 showed
        # that two stacked 3x3 convs are more efficient, so that is what is done here.
self.branch_3 = nn.Sequential(
BasicConv2d(in_channels, out5x5reduced, kernel_size=1),
BasicConv2d(out5x5reduced, out5x5, kernel_size=3, padding=1),
BasicConv2d(out5x5, out5x5, kernel_size=3, padding=1),
)
self.branch_4 = nn.Sequential(
nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
BasicConv2d(in_channels, outpool, kernel_size=1),
)
def forward(self, x):
y1 = self.branch_1(x)
y2 = self.branch_2(x)
y3 = self.branch_3(x)
y4 = self.branch_4(x)
return torch.cat([y1, y2, y3, y4], 1)
class GoogLeNet(nn.Module):
def __init__(self, img_channel):
super().__init__()
self.first_layers = nn.Sequential(
BasicConv2d(img_channel, 192, kernel_size=3, padding=1)
)
self._3a = Inception(192, 64, 96, 128, 16, 32, 32)
self._3b = Inception(256, 128, 128, 192, 32, 96, 64)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self._4a = Inception(480, 192, 96, 208, 16, 48, 64)
self._4b = Inception(512, 160, 112, 224, 24, 64, 64)
self._4c = Inception(512, 128, 128, 256, 24, 64, 64)
self._4d = Inception(512, 112, 144, 288, 32, 64, 64)
self._4e = Inception(528, 256, 160, 320, 32, 128, 128)
self._5a = Inception(832, 256, 160, 320, 32, 128, 128)
self._5b = Inception(832, 384, 192, 384, 48, 128, 128)
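        # with 32x32 inputs the feature map is 8x8 at this point (two stride-2 max pools),
        # and _5b outputs 384+384+128+128 = 1024 channels, so the 8x8 average pool below
        # produces a 1x1x1024 tensor for the final linear layer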
self.avgpool = nn.AvgPool2d(kernel_size=8, stride=1)
self.linear = nn.Linear(1024, 10)
def forward(self, x):
out = self.first_layers(x)
out = self._3a(out)
out = self._3b(out)
out = self.maxpool(out)
out = self._4a(out)
out = self._4b(out)
out = self._4c(out)
out = self._4d(out)
out = self._4e(out)
out = self.maxpool(out)
out = self._5a(out)
out = self._5b(out)
out = self.avgpool(out)
out = out.view(out.size(0), -1)
out = self.linear(out)
return out
class BasicConv2d(nn.Module):
def __init__(self, in_channels, out_channels, **kwargs):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return F.relu(x, inplace=True)
def test():
net = GoogLeNet(1)
x = torch.randn(3, 1, 32, 32)
y = net(x)
print(y.size())
# test()


@@ -0,0 +1,4 @@
from networks.vgg import VGG
from networks.lenet import LeNet
from networks.resnet import ResNet, residual_template, ResNet50, ResNet101, ResNet152
from networks.googLeNet import BasicConv2d, Inception, GoogLeNet


@@ -0,0 +1,60 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class LeNet(nn.Module):
    def __init__(self, in_channels, init_weights=True, num_classes=10):
        super(LeNet, self).__init__()
        self.num_classes = num_classes
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)
        # initialize weights after the layers are created; calling this earlier would
        # iterate over an empty module list and silently do nothing
        if init_weights:
            self._initialize_weights()
def forward(self, x):
z1 = self.conv1(x) # 6 x 28 x 28
a1 = F.relu(z1) # 6 x 28 x 28
a1 = F.max_pool2d(a1, kernel_size=2, stride=2) # 6 x 14 x 14
z2 = self.conv2(a1) # 16 x 10 x 10
a2 = F.relu(z2) # 16 x 10 x 10
a2 = F.max_pool2d(a2, kernel_size=2, stride=2) # 16 x 5 x 5
flatten_a2 = a2.view(a2.size(0), -1)
z3 = self.fc1(flatten_a2)
a3 = F.relu(z3)
z4 = self.fc2(a3)
a4 = F.relu(z4)
z5 = self.fc3(a4)
return z5
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)
def test_lenet():
net = LeNet(1)
x = torch.randn(64, 1, 32, 32)
y = net(x)
print(y.size())
# test_lenet()  # commented out so importing networks.lenet has no side effects, matching the other network files


@@ -0,0 +1,151 @@
import torch
import torch.nn as nn
class residual_template(nn.Module):
expansion = 4
def __init__(self, in_channels, out_channels, stride=1, identity_downsample=None):
super().__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(
out_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=1,
bias=False,
)
self.bn2 = nn.BatchNorm2d(out_channels)
self.conv3 = nn.Conv2d(
out_channels, out_channels * self.expansion, kernel_size=1, bias=False
)
self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.identity_downsample = identity_downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.identity_downsample is not None:
residual = self.identity_downsample(x)
out += residual
out = self.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, residual_template, layers, image_channel, num_classes=10):
self.in_channels = 64
super().__init__()
self.conv1 = nn.Conv2d(
in_channels=image_channel,
out_channels=64,
kernel_size=3,
stride=1,
padding=1,
bias=False,
)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.layer1 = self._make_layer(
residual_template, layers[0], channels=64, stride=1
)
self.layer2 = self._make_layer(
residual_template, layers[1], channels=128, stride=2
)
self.layer3 = self._make_layer(
residual_template, layers[2], channels=256, stride=2
)
self.layer4 = self._make_layer(
residual_template, layers[3], channels=512, stride=2
)
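        # for 32x32 inputs the feature map is 4x4 after layer4 (three stride-2 stages),
        # so the 4x4 average pool below reduces it to 1x1 before the fully connected layer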
self.avgpool = nn.AvgPool2d(kernel_size=4, stride=1)
self.fc = nn.Linear(512 * residual_template.expansion, num_classes)
# initialize weights for conv layers, batch layers
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def _make_layer(self, residual_template, num_residuals_blocks, channels, stride):
identity_downsample = None
if stride != 1 or self.in_channels != channels * residual_template.expansion:
identity_downsample = nn.Sequential(
nn.Conv2d(
self.in_channels,
channels * residual_template.expansion,
kernel_size=1,
stride=stride,
bias=False,
),
nn.BatchNorm2d(channels * residual_template.expansion),
)
layers = []
layers.append(
residual_template(self.in_channels, channels, stride, identity_downsample)
)
self.in_channels = channels * residual_template.expansion
for i in range(1, num_residuals_blocks):
layers.append(residual_template(self.in_channels, channels))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
def ResNet50(img_channel):
return ResNet(residual_template, [3, 4, 6, 3], img_channel)
def ResNet101(img_channel):
return ResNet(residual_template, [3, 4, 23, 3], img_channel)
def ResNet152(img_channel):
return ResNet(residual_template, [3, 8, 36, 3], img_channel)
def test():
net = ResNet152(img_channel=1)
y = net(torch.randn(64, 1, 32, 32))
print(y.size())
# test()


@@ -0,0 +1,139 @@
import torch
import torch.nn as nn
VGG_types = {
"VGG11": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
"VGG13": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
"VGG16": [
64,
64,
"M",
128,
128,
"M",
256,
256,
256,
"M",
512,
512,
512,
"M",
512,
512,
512,
"M",
],
"VGG19": [
64,
64,
"M",
128,
128,
"M",
256,
256,
256,
256,
"M",
512,
512,
512,
512,
"M",
512,
512,
512,
512,
"M",
],
}
class VGG(nn.Module):
def __init__(
self, vgg_type, in_channels, init_weights=True, batch_norm=True, num_classes=10
):
super().__init__()
self.batch_norm = batch_norm
self.in_channels = in_channels
self.layout = self.create_architecture(VGG_types[vgg_type])
self.fc = nn.Linear(512, num_classes)
# self.fcs = nn.Sequential(
# nn.Linear(512* 1 * 1, 4096),
# nn.ReLU(inplace = False),
# nn.Dropout(),
# nn.Linear(4096, 4096),
# nn.ReLU(inplace = False),
# nn.Dropout(),
# nn.Linear(4096, num_classes),
# )
if init_weights:
self._initialize_weights()
def forward(self, x):
out = self.layout(x)
out = out.view(out.size(0), -1)
out = self.fc(out)
return out
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)
def create_architecture(self, architecture):
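        # each integer in the config list is the output channel count of a 3x3 conv
        # (plus optional BN and ReLU) block, and "M" inserts a 2x2 max pool that halves
        # the spatial resolution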
layers = []
for x in architecture:
if type(x) == int:
out_channels = x
conv2d = nn.Conv2d(
self.in_channels, out_channels, kernel_size=3, padding=1
)
if self.batch_norm:
layers += [
conv2d,
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=False),
]
else:
layers += [conv2d, nn.ReLU(inplace=False)]
self.in_channels = out_channels
elif x == "M":
layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
return nn.Sequential(*layers)
def test():
net = VGG("VGG16", 1)
x = torch.randn(64, 1, 32, 32)
y = net(x)
print(y.size())
# test()


@@ -0,0 +1,264 @@
import argparse
import os
import shutil
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader, SubsetRandomSampler
from networks.import_all_networks import *
from utils.import_utils import *
class Train_MNIST(object):
def __init__(self):
self.best_acc = 0
self.in_channels = 1 # 1 because MNIST is grayscale
        self.dataset = mnist_data  # mnist_data class from utils that builds the MNIST data loaders
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.dtype = torch.float32
self.args = self.prepare_args()
self.transform_train, self.transform_test = self.prepare_transformations()
if self.args.create_validationset:
(
self.loader_train,
self.loader_validation,
self.loader_test,
) = self.prepare_data()
self.data_check_acc = self.loader_validation
else:
self.loader_train, self.loader_test = self.prepare_data()
self.data_check_acc = self.loader_train
def prepare_args(self):
parser = argparse.ArgumentParser(description="PyTorch MNIST")
parser.add_argument(
"--resume",
default="",
type=str,
metavar="PATH",
help="path to latest checkpoint (default: none)",
)
parser.add_argument(
"--lr",
default=0.001,
type=float,
metavar="LR",
help="initial learning rate",
)
parser.add_argument(
"--weight-decay",
default=1e-5,
type=float,
metavar="R",
help="L2 regularization lambda",
)
parser.add_argument(
"--momentum", default=0.9, type=float, metavar="M", help="SGD with momentum"
)
parser.add_argument(
"--epochs",
type=int,
default=100,
metavar="N",
help="number of epochs to train (default: 100)",
)
parser.add_argument(
"--batch-size",
type=int,
default=128,
metavar="N",
help="input batch size for training (default: 128)",
)
parser.add_argument(
"--log-interval",
type=int,
default=240,
metavar="N",
help="how many batches to wait before logging training status",
)
parser.add_argument(
"--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
)
parser.add_argument(
"--number-workers",
type=int,
default=0,
metavar="S",
help="number of workers (default: 0)",
)
parser.add_argument(
"--init-padding",
type=int,
default=2,
metavar="S",
help=" If use initial padding or not. (default: 2 because mnist 28x28 to make 32x32)",
)
parser.add_argument(
"--create-validationset",
action="store_true",
default=False,
help="If you want to use a validation set (default: False). Default size = 10%",
)
parser.add_argument(
"--save-model",
action="store_true",
default=False,
help="If you want to save this model(default: False).",
)
args = parser.parse_args()
return args
def prepare_transformations(self):
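        # pad 28x28 MNIST images to 32x32 (per --init-padding, default 2) and normalize with
        # the commonly used MNIST mean/std of 0.1307/0.3081; no augmentation is applied here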
transform_train = transforms.Compose(
[
transforms.Pad(self.args.init_padding),
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
]
)
transform_test = transforms.Compose(
[
transforms.Pad(self.args.init_padding),
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
]
)
return transform_train, transform_test
def prepare_data(self, shuffle=True):
data = self.dataset(
shuffle,
self.transform_train,
self.transform_test,
self.args.number_workers,
self.args.create_validationset,
self.args.batch_size,
validation_size=0.1,
random_seed=self.args.seed,
)
if self.args.create_validationset:
loader_train, loader_validation, loader_test = data.main()
return loader_train, loader_validation, loader_test
else:
loader_train, loader_test = data.main()
return loader_train, loader_test
def train(self):
criterion = nn.CrossEntropyLoss()
iter = 0
# vis_plotting = visdom_plotting()
loss_list, batch_list, epoch_list, validation_acc_list, training_acc_list = (
[],
[],
[0],
[0],
[0],
)
for epoch in range(self.args.epochs):
for batch_idx, (x, y) in enumerate(self.loader_train):
self.model.train()
x = x.to(device=self.device, dtype=self.dtype)
y = y.to(device=self.device, dtype=torch.long)
scores = self.model(x)
loss = criterion(scores, y)
loss_list.append(loss.item())
batch_list.append(iter + 1)
iter += 1
if batch_idx % self.args.log_interval == 0:
print(f"Batch {batch_idx}, epoch {epoch}, loss = {loss.item()}")
print()
self.model.eval()
train_acc = check_accuracy(self.data_check_acc, self.model)
# validation_acc = self.check_accuracy(self.data_check_acc)
validation_acc = 0
validation_acc_list.append(validation_acc)
training_acc_list.append(train_acc)
epoch_list.append(epoch + 0.5)
print()
print()
# call to plot in visdom
# vis_plotting.create_plot(loss_list, batch_list, validation_acc_list, epoch_list, training_acc_list)
# save checkpoint
if train_acc > self.best_acc and self.args.save_model:
self.best_acc = train_acc
save_checkpoint(
self.filename,
self.model,
self.optimizer,
self.best_acc,
epoch,
)
self.model.train()
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
def choose_network(self):
self.model = LeNet(
in_channels=self.in_channels, init_weights=True, num_classes=10
)
self.filename = "checkpoint/mnist_LeNet.pth.tar"
# self.model = VGG('VGG16', in_channels = self.in_channels)
# self.filename = 'checkpoint/mnist_VGG16.pth.tar'
# self.model = ResNet50(img_channel=1)
# self.filename = 'checkpoint/mnist_ResNet.pth.tar'
# self.model = GoogLeNet(img_channel=1)
# self.filename = 'checkpoint/mnist_GoogLeNet.pth.tar'
self.model = self.model.to(self.device)
    def main(self):
        self.choose_network()
        self.optimizer = optim.SGD(
            self.model.parameters(),
            lr=self.args.lr,
            weight_decay=self.args.weight_decay,
            momentum=self.args.momentum,
        )
        cudnn.benchmark = True
        if self.args.resume:
            self.model.eval()
            (
                self.model,
                self.optimizer,
                self.checkpoint,
                self.start_epoch,
                self.best_acc,
            ) = load_model(self.args, self.model, self.optimizer)
        else:
            load_model(self.args, self.model, self.optimizer)
        self.train()


if __name__ == "__main__":
    # MNIST: guard at module level so importing train.py does not parse args or start training
    network = Train_MNIST()
    network.main()


@@ -0,0 +1,2 @@
from utils.mnist_data import mnist_data
from utils.utils import check_accuracy, save_checkpoint, visdom_plotting, load_model


@@ -0,0 +1,94 @@
import numpy as np
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, SubsetRandomSampler
class mnist_data(object):
def __init__(
self,
shuffle,
transform_train,
transform_test,
num_workers=0,
create_validation_set=True,
batch_size=128,
validation_size=0.2,
random_seed=1,
):
self.shuffle = shuffle
self.validation_size = validation_size
self.transform_train = transform_train
self.transform_test = transform_test
self.random_seed = random_seed
self.create_validation_set = create_validation_set
self.batch_size = batch_size
self.num_workers = num_workers
def download_data(self):
mnist_trainset = datasets.MNIST(
root="./data", train=True, download=True, transform=self.transform_train
)
mnist_testset = datasets.MNIST(
root="./data", train=False, download=True, transform=self.transform_test
)
return mnist_trainset, mnist_testset
def create_validationset(self, mnist_trainset):
num_train = len(mnist_trainset)
indices = list(range(num_train))
split = int(self.validation_size * num_train)
if self.shuffle:
np.random.seed(self.random_seed)
np.random.shuffle(indices)
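        # the first `split` indices form the validation set and the rest the training set;
        # both samplers draw from the same underlying training dataset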
train_idx, valid_idx = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_idx)
validation_sampler = SubsetRandomSampler(valid_idx)
loader_train = DataLoader(
dataset=mnist_trainset,
batch_size=self.batch_size,
sampler=train_sampler,
num_workers=self.num_workers,
)
loader_validation = DataLoader(
dataset=mnist_trainset,
batch_size=self.batch_size,
sampler=validation_sampler,
num_workers=self.num_workers,
)
return loader_train, loader_validation
def main(self):
mnist_trainset, mnist_testset = self.download_data()
if self.create_validation_set:
loader_train, loader_validation = self.create_validationset(mnist_trainset)
loader_test = DataLoader(
dataset=mnist_testset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
)
return loader_train, loader_validation, loader_test
else:
loader_train = DataLoader(
dataset=mnist_trainset,
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers,
)
loader_test = DataLoader(
dataset=mnist_testset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
)
return loader_train, loader_test


@@ -0,0 +1,130 @@
import torch
import visdom
import os
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float32
def save_checkpoint(filename, model, optimizer, train_acc, epoch):
save_state = {
"state_dict": model.state_dict(),
"acc": train_acc,
"epoch": epoch + 1,
"optimizer": optimizer.state_dict(),
}
print()
print("Saving current parameters")
print("___________________________________________________________")
torch.save(save_state, filename)
def check_accuracy(loader, model):
if loader.dataset.train:
print("Checking accuracy on training or validation set")
else:
print("Checking accuracy on test set")
num_correct = 0
num_samples = 0
# model.eval() # set model to evaluation mode
with torch.no_grad():
for x, y in loader:
x = x.to(device=device, dtype=dtype) # move to device, e.g. GPU
y = y.to(device=device, dtype=torch.long)
scores = model(x)
_, preds = scores.max(1)
num_correct += (preds == y).sum()
num_samples += preds.size(0)
acc = (float(num_correct) / num_samples) * 100.0
print("Got %d / %d correct (%.2f)" % (num_correct, num_samples, acc))
return acc
def load_model(args, model, optimizer):
if args.resume:
model.eval()
if os.path.isfile(args.resume):
print("=> loading checkpoint '{}'".format(args.resume))
checkpoint = torch.load(args.resume)
start_epoch = checkpoint["epoch"]
best_acc = checkpoint["acc"]
model.load_state_dict(checkpoint["state_dict"])
optimizer.load_state_dict(checkpoint["optimizer"])
print(
"=> loaded checkpoint '{}' (epoch {})".format(
args.resume, checkpoint["epoch"]
)
)
            return model, optimizer, checkpoint, start_epoch, best_acc
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        print("No pretrained model. Starting from scratch!")
    # nothing was loaded: return the untouched model/optimizer so the tuple unpacking
    # in train.py does not fail when the checkpoint file is missing
    return model, optimizer, None, 0, 0
class visdom_plotting(object):
def __init__(self):
self.viz = visdom.Visdom()
self.cur_batch_win = None
self.cur_batch_win_opts = {
"title": "Epoch Loss Trace",
"xlabel": "Batch Number",
"ylabel": "Loss",
"width": 600,
"height": 400,
}
self.cur_validation_acc = None
self.cur_validation_acc_opts = {
"title": "Validation accuracy",
"xlabel": "Epochs",
"ylabel": "Validation Accuracy",
"width": 600,
"height": 400,
}
self.cur_training_acc = None
self.cur_training_acc_opts = {
"title": "Training accuracy",
"xlabel": "Epochs",
"ylabel": "Train Accuracy",
"width": 600,
"height": 400,
}
def create_plot(
self, loss_list, batch_list, validation_acc_list, epoch_list, training_acc_list
):
if self.viz.check_connection():
self.cur_batch_win = self.viz.line(
torch.FloatTensor(loss_list),
torch.FloatTensor(batch_list),
win=self.cur_batch_win,
name="current_batch_loss",
update=(None if self.cur_batch_win is None else "replace"),
opts=self.cur_batch_win_opts,
)
self.cur_validation_acc = self.viz.line(
torch.FloatTensor(validation_acc_list),
torch.FloatTensor(epoch_list),
win=self.cur_validation_acc,
name="current_validation_accuracy",
update=(None if self.cur_validation_acc is None else "replace"),
opts=self.cur_validation_acc_opts,
)
self.cur_training_acc = self.viz.line(
torch.FloatTensor(training_acc_list),
torch.FloatTensor(epoch_list),
                win=self.cur_training_acc,  # fixed copy-paste: the training accuracy trace gets its own window
name="current_training_accuracy",
update=(None if self.cur_training_acc is None else "replace"),
opts=self.cur_training_acc_opts,
)
#