huggingface update

This commit is contained in:
Aladdin Persson
2023-03-18 09:51:16 +01:00
parent 94f6c024fe
commit e4659fe56a
14 changed files with 718 additions and 9968 deletions

View File

@@ -6,7 +6,7 @@ from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.data import random_split
import pytorch_lightning as pl
import pytorch_lightning as pl
class NN(pl.LightningModule):
@@ -23,28 +23,28 @@ class NN(pl.LightningModule):
def training_step(self, batch, batch_idx):
loss, scores, y = self._common_step(batch, batch_idx)
self.log('train_loss', loss)
self.log("train_loss", loss)
return loss
def validation_step(self, batch, batch_idx):
loss, scores, y = self._common_step(batch, batch_idx)
self.log('val_loss', loss)
self.log("val_loss", loss)
return loss
def test_step(self, batch, batch_idx):
loss, scores, y = self._common_step(batch, batch_idx)
self.log('test_loss', loss)
self.log("test_loss", loss)
return loss
def _common_step(self, batch, batch_idx):
x, y = batch
x, y = batch
x = x.reshape(x.size(0), -1)
scores = self.forward(x)
loss = self.loss_fn(scores, y)
return loss, scores, y
def predict_step(self, batch, batch_idx):
x, y = batch
x, y = batch
x = x.reshape(x.size(0), -1)
scores = self.forward(x)
preds = torch.argmax(scores, dim=1)
@@ -53,6 +53,7 @@ class NN(pl.LightningModule):
def configure_optimizers(self):
return optim.Adam(self.parameters(), lr=0.001)
# Set device cuda for GPU if it's available otherwise run on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -82,7 +83,13 @@ model = NN(input_size=input_size, num_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
trainer = pl.Trainer(accelerator="gpu", devices=1, min_epochs=1, max_epochs=3, precision=16)
trainer = pl.Trainer(
accelerator="gpu",
devices=1,
min_epochs=1,
max_epochs=3,
precision=16,
)
trainer.fit(model, train_loader, val_loader)
trainer.validate(model, val_loader)
trainer.test(model, test_loader)