Initial commit

This commit is contained in:
Aladdin Persson
2021-01-30 21:49:15 +01:00
commit 65b8c80495
432 changed files with 1290844 additions and 0 deletions

View File

@@ -0,0 +1,264 @@
import argparse
import os
import shutil
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader, SubsetRandomSampler
from networks.import_all_networks import *
from utils.import_utils import *
class Train_MNIST(object):
    """End-to-end training harness for MNIST classification.

    Parses CLI arguments, builds transformations and data loaders, selects a
    network architecture, and runs an SGD training loop with optional
    checkpointing. Relies on project helpers imported at module level
    (``mnist_data``, ``check_accuracy``, ``save_checkpoint``, ``load_model``
    and network classes such as ``LeNet``).
    """

    def __init__(self):
        self.best_acc = 0  # best accuracy seen so far; gates checkpoint saving
        self.in_channels = 1  # 1 because MNIST is grayscale
        self.dataset = mnist_data  # Class that is imported from utils that imports data
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.float32
        self.args = self.prepare_args()
        self.transform_train, self.transform_test = self.prepare_transformations()
        if self.args.create_validationset:
            (
                self.loader_train,
                self.loader_validation,
                self.loader_test,
            ) = self.prepare_data()
            # With a validation split, periodic accuracy checks run on it.
            self.data_check_acc = self.loader_validation
        else:
            # Without one, accuracy checks fall back to the training loader.
            self.loader_train, self.loader_test = self.prepare_data()
            self.data_check_acc = self.loader_train

    def prepare_args(self):
        """Build the CLI parser and parse ``sys.argv``.

        Returns:
            argparse.Namespace: training hyperparameters and feature flags.
        """
        parser = argparse.ArgumentParser(description="PyTorch MNIST")
        parser.add_argument(
            "--resume",
            default="",
            type=str,
            metavar="PATH",
            help="path to latest checkpoint (default: none)",
        )
        parser.add_argument(
            "--lr",
            default=0.001,
            type=float,
            metavar="LR",
            help="initial learning rate",
        )
        parser.add_argument(
            "--weight-decay",
            default=1e-5,
            type=float,
            metavar="R",
            help="L2 regularization lambda",
        )
        parser.add_argument(
            "--momentum", default=0.9, type=float, metavar="M", help="SGD with momentum"
        )
        parser.add_argument(
            "--epochs",
            type=int,
            default=100,
            metavar="N",
            help="number of epochs to train (default: 100)",
        )
        parser.add_argument(
            "--batch-size",
            type=int,
            default=128,
            metavar="N",
            help="input batch size for training (default: 128)",
        )
        parser.add_argument(
            "--log-interval",
            type=int,
            default=240,
            metavar="N",
            help="how many batches to wait before logging training status",
        )
        parser.add_argument(
            "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
        )
        parser.add_argument(
            "--number-workers",
            type=int,
            default=0,
            metavar="S",
            help="number of workers (default: 0)",
        )
        parser.add_argument(
            "--init-padding",
            type=int,
            default=2,
            metavar="S",
            help=" If use initial padding or not. (default: 2 because mnist 28x28 to make 32x32)",
        )
        parser.add_argument(
            "--create-validationset",
            action="store_true",
            default=False,
            help="If you want to use a validation set (default: False). Default size = 10%",
        )
        parser.add_argument(
            "--save-model",
            action="store_true",
            default=False,
            help="If you want to save this model(default: False).",
        )
        return parser.parse_args()

    def prepare_transformations(self):
        """Create the train/test preprocessing pipelines.

        Both pipelines are identical: pad (28x28 -> 32x32 with the default
        ``--init-padding``), convert to tensor, and normalize with the
        standard MNIST mean/std.

        Returns:
            tuple: ``(transform_train, transform_test)`` Compose objects.
        """

        def _pipeline():
            # Built twice (not shared) so the two pipelines stay independent.
            return transforms.Compose(
                [
                    transforms.Pad(self.args.init_padding),
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307,), (0.3081,)),
                ]
            )

        return _pipeline(), _pipeline()

    def prepare_data(self, shuffle=True):
        """Construct the MNIST data loaders via the project's dataset helper.

        Args:
            shuffle (bool): passed through to the dataset helper.

        Returns:
            2-tuple ``(loader_train, loader_test)`` or, when
            ``--create-validationset`` is set, a 3-tuple that also contains
            ``loader_validation``.
        """
        data = self.dataset(
            shuffle,
            self.transform_train,
            self.transform_test,
            self.args.number_workers,
            self.args.create_validationset,
            self.args.batch_size,
            validation_size=0.1,
            random_seed=self.args.seed,
        )
        if self.args.create_validationset:
            loader_train, loader_validation, loader_test = data.main()
            return loader_train, loader_validation, loader_test
        else:
            loader_train, loader_test = data.main()
            return loader_train, loader_test

    def train(self):
        """Run the training loop for ``args.epochs`` epochs.

        Logs the loss every ``--log-interval`` batches, measures accuracy on
        ``self.data_check_acc`` at those points, and saves a checkpoint
        whenever that accuracy improves (if ``--save-model`` was given).
        """
        criterion = nn.CrossEntropyLoss()
        iteration = 0  # renamed from ``iter`` to avoid shadowing the builtin
        # vis_plotting = visdom_plotting()
        loss_list = []
        batch_list = []
        epoch_list = [0]
        validation_acc_list = [0]
        training_acc_list = [0]
        for epoch in range(self.args.epochs):
            for batch_idx, (x, y) in enumerate(self.loader_train):
                self.model.train()
                x = x.to(device=self.device, dtype=self.dtype)
                y = y.to(device=self.device, dtype=torch.long)
                scores = self.model(x)
                loss = criterion(scores, y)
                loss_list.append(loss.item())
                batch_list.append(iteration + 1)
                iteration += 1
                if batch_idx % self.args.log_interval == 0:
                    print(f"Batch {batch_idx}, epoch {epoch}, loss = {loss.item()}")
                    print()
                    self.model.eval()
                    train_acc = check_accuracy(self.data_check_acc, self.model)
                    # validation_acc = self.check_accuracy(self.data_check_acc)
                    # NOTE(review): placeholder while the call above is disabled.
                    validation_acc = 0
                    validation_acc_list.append(validation_acc)
                    training_acc_list.append(train_acc)
                    epoch_list.append(epoch + 0.5)
                    print()
                    print()
                    # call to plot in visdom
                    # vis_plotting.create_plot(loss_list, batch_list, validation_acc_list, epoch_list, training_acc_list)
                    # save checkpoint whenever the measured accuracy improves
                    if train_acc > self.best_acc and self.args.save_model:
                        self.best_acc = train_acc
                        save_checkpoint(
                            self.filename,
                            self.model,
                            self.optimizer,
                            self.best_acc,
                            epoch,
                        )
                # Standard optimization step, once per batch.
                self.model.train()
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

    def choose_network(self):
        """Instantiate the network to train and its checkpoint path.

        Swap in one of the commented alternatives to train a different
        architecture.
        """
        self.model = LeNet(
            in_channels=self.in_channels, init_weights=True, num_classes=10
        )
        self.filename = "checkpoint/mnist_LeNet.pth.tar"
        # self.model = VGG('VGG16', in_channels = self.in_channels)
        # self.filename = 'checkpoint/mnist_VGG16.pth.tar'
        # self.model = ResNet50(img_channel=1)
        # self.filename = 'checkpoint/mnist_ResNet.pth.tar'
        # self.model = GoogLeNet(img_channel=1)
        # self.filename = 'checkpoint/mnist_GoogLeNet.pth.tar'
        self.model = self.model.to(self.device)

    def main(self):
        """Entry point: build model and optimizer, optionally resume, train.

        NOTE(review): the body is guarded by ``__name__ == "__main__"``, so
        calling ``main()`` from an importing module is a silent no-op; kept
        as-is to preserve the original behavior.
        """
        if __name__ == "__main__":
            self.choose_network()
            self.optimizer = optim.SGD(
                self.model.parameters(),
                lr=self.args.lr,
                weight_decay=self.args.weight_decay,
                momentum=self.args.momentum,
            )
            # Autotune cuDNN conv algorithms for the fixed input size.
            cudnn.benchmark = True
            if self.args.resume:
                self.model.eval()
                (
                    self.model,
                    self.optimizer,
                    self.checkpoint,
                    self.start_epoch,
                    self.best_acc,
                ) = load_model(self.args, self.model, self.optimizer)
            else:
                # NOTE(review): load_model is also called when --resume is
                # empty and its return value is discarded; presumably a no-op
                # inside load_model — verify against utils.
                load_model(self.args, self.model, self.optimizer)
            self.train()
## Mnist
if __name__ == "__main__":
    # Guarded so that importing this module does not parse sys.argv and
    # build the MNIST data loaders as an import-time side effect.
    network = Train_MNIST()
    network.main()