Files
Machine-Learning-Collection/ML/Pytorch/more_advanced/image_captioning/model.py
Aladdin Persson 65b8c80495 Initial commit
2021-01-30 21:49:15 +01:00

67 lines
2.4 KiB
Python

import torch
import torch.nn as nn
import statistics
import torchvision.models as models
class EncoderCNN(nn.Module):
def __init__(self, embed_size, train_CNN=False):
super(EncoderCNN, self).__init__()
self.train_CNN = train_CNN
self.inception = models.inception_v3(pretrained=True, aux_logits=False)
self.inception.fc = nn.Linear(self.inception.fc.in_features, embed_size)
self.relu = nn.ReLU()
self.times = []
self.dropout = nn.Dropout(0.5)
def forward(self, images):
features = self.inception(images)
return self.dropout(self.relu(features))
class DecoderRNN(nn.Module):
def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
super(DecoderRNN, self).__init__()
self.embed = nn.Embedding(vocab_size, embed_size)
self.lstm = nn.LSTM(embed_size, hidden_size, num_layers)
self.linear = nn.Linear(hidden_size, vocab_size)
self.dropout = nn.Dropout(0.5)
def forward(self, features, captions):
embeddings = self.dropout(self.embed(captions))
embeddings = torch.cat((features.unsqueeze(0), embeddings), dim=0)
hiddens, _ = self.lstm(embeddings)
outputs = self.linear(hiddens)
return outputs
class CNNtoRNN(nn.Module):
def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
super(CNNtoRNN, self).__init__()
self.encoderCNN = EncoderCNN(embed_size)
self.decoderRNN = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers)
def forward(self, images, captions):
features = self.encoderCNN(images)
outputs = self.decoderRNN(features, captions)
return outputs
def caption_image(self, image, vocabulary, max_length=50):
result_caption = []
with torch.no_grad():
x = self.encoderCNN(image).unsqueeze(0)
states = None
for _ in range(max_length):
hiddens, states = self.decoderRNN.lstm(x, states)
output = self.decoderRNN.linear(hiddens.squeeze(0))
predicted = output.argmax(1)
result_caption.append(predicted.item())
x = self.decoderRNN.embed(predicted).unsqueeze(0)
if vocabulary.itos[predicted.item()] == "<EOS>":
break
return [vocabulary.itos[idx] for idx in result_caption]