Mirror of https://github.com/aladdinpersson/Machine-Learning-Collection.git (synced 2026-02-21 11:18:01 +00:00)
import torch
import torch.nn as nn
import torchvision.models as models

class EncoderCNN(nn.Module):
    def __init__(self, embed_size, train_CNN=False):
        super(EncoderCNN, self).__init__()
        # Fine-tune the full CNN only when train_CNN=True; the actual
        # freezing of parameters is handled in the training script.
        self.train_CNN = train_CNN
        # Note: `pretrained` is deprecated in newer torchvision releases
        # (use weights=models.Inception_V3_Weights.DEFAULT), and some
        # versions require aux_logits=True when loading pretrained weights.
        self.inception = models.inception_v3(pretrained=True, aux_logits=False)
        # Replace the final classifier with a projection into the embedding space.
        self.inception.fc = nn.Linear(self.inception.fc.in_features, embed_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, images):
        features = self.inception(images)  # (batch_size, embed_size)
        return self.dropout(self.relu(features))
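
# A minimal sketch of how the train_CNN flag is typically honored (the
# original repository does this in its training script, not here): keep the
# new fc projection trainable and freeze the rest of Inception unless
# fine-tuning is requested. `set_encoder_trainable` is an illustrative
# helper name, not part of the original file.
def set_encoder_trainable(encoder, train_CNN=False):
    for name, param in encoder.inception.named_parameters():
        if "fc.weight" in name or "fc.bias" in name:
            param.requires_grad = True
        else:
            param.requires_grad = train_CNN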

class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        # batch_first=False (default): inputs are (seq_len, batch, embed_size).
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.5)

    def forward(self, features, captions):
        embeddings = self.dropout(self.embed(captions))  # (seq_len, batch, embed_size)
        # Prepend the image features as the first "token" of the sequence.
        embeddings = torch.cat((features.unsqueeze(0), embeddings), dim=0)
        hiddens, _ = self.lstm(embeddings)
        outputs = self.linear(hiddens)  # (seq_len + 1, batch, vocab_size)
        return outputs

class CNNtoRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super(CNNtoRNN, self).__init__()
        self.encoderCNN = EncoderCNN(embed_size)
        self.decoderRNN = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers)

    def forward(self, images, captions):
        features = self.encoderCNN(images)
        outputs = self.decoderRNN(features, captions)
        return outputs

    def caption_image(self, image, vocabulary, max_length=50):
        # Greedy decoding; call model.eval() first so Inception's batch norm
        # and the dropout layers behave deterministically.
        result_caption = []

        with torch.no_grad():
            x = self.encoderCNN(image).unsqueeze(0)  # (1, 1, embed_size)
            states = None

            for _ in range(max_length):
                hiddens, states = self.decoderRNN.lstm(x, states)
                output = self.decoderRNN.linear(hiddens.squeeze(0))  # (1, vocab_size)
                predicted = output.argmax(1)
                result_caption.append(predicted.item())
                # Feed the predicted token back in as the next input.
                x = self.decoderRNN.embed(predicted).unsqueeze(0)

                if vocabulary.itos[predicted.item()] == "<EOS>":
                    break

        return [vocabulary.itos[idx] for idx in result_caption]
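
# A minimal smoke test, assuming vocab_size=1000 and random inputs; the
# hyperparameters and the dummy vocabulary below are illustrative only
# (running this downloads the pretrained Inception weights).
if __name__ == "__main__":
    embed_size, hidden_size, vocab_size, num_layers = 256, 256, 1000, 1
    model = CNNtoRNN(embed_size, hidden_size, vocab_size, num_layers)

    images = torch.randn(4, 3, 299, 299)  # Inception v3 expects 299x299 inputs
    captions = torch.randint(0, vocab_size, (20, 4))  # (seq_len, batch)
    outputs = model(images, captions)
    print(outputs.shape)  # torch.Size([21, 4, 1000]) -> (seq_len + 1, batch, vocab)

    # Greedy captioning with a stand-in vocabulary that only defines itos.
    class DummyVocab:
        def __init__(self, size):
            self.itos = {i: f"tok{i}" for i in range(size)}
            self.itos[2] = "<EOS>"

    model.eval()
    caption = model.caption_image(torch.randn(1, 3, 299, 299), DummyVocab(vocab_size))
    print(caption[:10])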