import torch
import torchvision
from torch import nn
import pytorch_lightning as pl
from transformers import WhisperProcessor, WhisperTokenizer, WhisperFeatureExtractor
from transformers import WhisperForConditionalGeneration


class WhisperFinetuning(pl.LightningModule):
    def __init__(self, lr, whisper_model="tiny"):
        super().__init__()
        self.lr = lr
        # Load a pretrained Whisper checkpoint, e.g. "openai/whisper-tiny".
        self.model = WhisperForConditionalGeneration.from_pretrained(
            f"openai/whisper-{whisper_model}"
        )
        # Clear the forced decoder ids and suppressed tokens so the model
        # learns the language/task prompt tokens from the training labels.
        self.model.config.forced_decoder_ids = None
        self.model.config.suppress_tokens = []

    def training_step(self, batch, batch_idx):
        # batch[0] carries the log-mel input features, batch[1] the tokenized labels.
        encoder_input = batch[0]["input_features"]
        decoder_labels = batch[1]["labels"]
        # Passing labels makes the model compute the cross-entropy loss internally.
        out = self.model(
            input_features=encoder_input,
            labels=decoder_labels,
        )
        loss = out["loss"]
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

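# Minimal smoke-test sketch: one way this module could be exercised end to end.
# The dummy dataset, collate_fn, and trainer settings below are illustrative
# assumptions (random 80 x 3000 log-mel features and random token ids), not a
# real Whisper data pipeline; call _smoke_test() from the __main__ block below
# to try it.
def _smoke_test():
    from torch.utils.data import DataLoader, Dataset

    class DummySpeechDataset(Dataset):
        def __len__(self):
            return 4

        def __getitem__(self, idx):
            # Shapes mimic WhisperFeatureExtractor output and tokenized text.
            return torch.randn(80, 3000), torch.randint(0, 100, (16,))

    def collate_fn(samples):
        feats = torch.stack([f for f, _ in samples])
        labels = torch.stack([l for _, l in samples])
        # Matches the (features, labels) structure read in training_step.
        return {"input_features": feats}, {"labels": labels}

    loader = DataLoader(DummySpeechDataset(), batch_size=2, collate_fn=collate_fn)
    module = WhisperFinetuning(lr=1e-5, whisper_model="tiny")
    trainer = pl.Trainer(
        max_epochs=1,
        accelerator="cpu",
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(module, train_dataloaders=loader)
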
if __name__ == "__main__":
    pass