mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-04-10 12:33:44 +00:00
Initial commit
This commit is contained in:
12
ML/Pytorch/more_advanced/image_captioning/README.md
Normal file
12
ML/Pytorch/more_advanced/image_captioning/README.md
Normal file
@@ -0,0 +1,12 @@
|
||||
### Image Captioning
|
||||
|
||||
Download the dataset used: https://www.kaggle.com/dataset/e1cd22253a9b23b073794872bf565648ddbe4f17e7fa9e74766ad3707141adeb
|
||||
Then set images folder, captions.txt inside a folder Flickr8k.
|
||||
|
||||
train.py: For training the network
|
||||
|
||||
model.py: creating the encoderCNN, decoderRNN and hooking them togethor
|
||||
|
||||
get_loader.py: Loading the data, creating vocabulary
|
||||
|
||||
utils.py: Load model, save model, printing few test cases downloaded online
|
||||
142
ML/Pytorch/more_advanced/image_captioning/get_loader.py
Normal file
142
ML/Pytorch/more_advanced/image_captioning/get_loader.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import os # when loading file paths
|
||||
import pandas as pd # for lookup in annotation file
|
||||
import spacy # for tokenizer
|
||||
import torch
|
||||
from torch.nn.utils.rnn import pad_sequence # pad batch
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from PIL import Image # Load img
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
|
||||
# We want to convert text -> numerical values
|
||||
# 1. We need a Vocabulary mapping each word to a index
|
||||
# 2. We need to setup a Pytorch dataset to load the data
|
||||
# 3. Setup padding of every batch (all examples should be
|
||||
# of same seq_len and setup dataloader)
|
||||
# Note that loading the image is very easy compared to the text!
|
||||
|
||||
# Download with: python -m spacy download en
|
||||
spacy_eng = spacy.load("en")
|
||||
|
||||
|
||||
class Vocabulary:
|
||||
def __init__(self, freq_threshold):
|
||||
self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
|
||||
self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
|
||||
self.freq_threshold = freq_threshold
|
||||
|
||||
def __len__(self):
|
||||
return len(self.itos)
|
||||
|
||||
@staticmethod
|
||||
def tokenizer_eng(text):
|
||||
return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]
|
||||
|
||||
def build_vocabulary(self, sentence_list):
|
||||
frequencies = {}
|
||||
idx = 4
|
||||
|
||||
for sentence in sentence_list:
|
||||
for word in self.tokenizer_eng(sentence):
|
||||
if word not in frequencies:
|
||||
frequencies[word] = 1
|
||||
|
||||
else:
|
||||
frequencies[word] += 1
|
||||
|
||||
if frequencies[word] == self.freq_threshold:
|
||||
self.stoi[word] = idx
|
||||
self.itos[idx] = word
|
||||
idx += 1
|
||||
|
||||
def numericalize(self, text):
|
||||
tokenized_text = self.tokenizer_eng(text)
|
||||
|
||||
return [
|
||||
self.stoi[token] if token in self.stoi else self.stoi["<UNK>"]
|
||||
for token in tokenized_text
|
||||
]
|
||||
|
||||
|
||||
class FlickrDataset(Dataset):
|
||||
def __init__(self, root_dir, captions_file, transform=None, freq_threshold=5):
|
||||
self.root_dir = root_dir
|
||||
self.df = pd.read_csv(captions_file)
|
||||
self.transform = transform
|
||||
|
||||
# Get img, caption columns
|
||||
self.imgs = self.df["image"]
|
||||
self.captions = self.df["caption"]
|
||||
|
||||
# Initialize vocabulary and build vocab
|
||||
self.vocab = Vocabulary(freq_threshold)
|
||||
self.vocab.build_vocabulary(self.captions.tolist())
|
||||
|
||||
def __len__(self):
|
||||
return len(self.df)
|
||||
|
||||
def __getitem__(self, index):
|
||||
caption = self.captions[index]
|
||||
img_id = self.imgs[index]
|
||||
img = Image.open(os.path.join(self.root_dir, img_id)).convert("RGB")
|
||||
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
|
||||
numericalized_caption = [self.vocab.stoi["<SOS>"]]
|
||||
numericalized_caption += self.vocab.numericalize(caption)
|
||||
numericalized_caption.append(self.vocab.stoi["<EOS>"])
|
||||
|
||||
return img, torch.tensor(numericalized_caption)
|
||||
|
||||
|
||||
class MyCollate:
|
||||
def __init__(self, pad_idx):
|
||||
self.pad_idx = pad_idx
|
||||
|
||||
def __call__(self, batch):
|
||||
imgs = [item[0].unsqueeze(0) for item in batch]
|
||||
imgs = torch.cat(imgs, dim=0)
|
||||
targets = [item[1] for item in batch]
|
||||
targets = pad_sequence(targets, batch_first=False, padding_value=self.pad_idx)
|
||||
|
||||
return imgs, targets
|
||||
|
||||
|
||||
def get_loader(
|
||||
root_folder,
|
||||
annotation_file,
|
||||
transform,
|
||||
batch_size=32,
|
||||
num_workers=8,
|
||||
shuffle=True,
|
||||
pin_memory=True,
|
||||
):
|
||||
dataset = FlickrDataset(root_folder, annotation_file, transform=transform)
|
||||
|
||||
pad_idx = dataset.vocab.stoi["<PAD>"]
|
||||
|
||||
loader = DataLoader(
|
||||
dataset=dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
shuffle=shuffle,
|
||||
pin_memory=pin_memory,
|
||||
collate_fn=MyCollate(pad_idx=pad_idx),
|
||||
)
|
||||
|
||||
return loader, dataset
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
transform = transforms.Compose(
|
||||
[transforms.Resize((224, 224)), transforms.ToTensor(),]
|
||||
)
|
||||
|
||||
loader, dataset = get_loader(
|
||||
"flickr8k/images/", "flickr8k/captions.txt", transform=transform
|
||||
)
|
||||
|
||||
for idx, (imgs, captions) in enumerate(loader):
|
||||
print(imgs.shape)
|
||||
print(captions.shape)
|
||||
66
ML/Pytorch/more_advanced/image_captioning/model.py
Normal file
66
ML/Pytorch/more_advanced/image_captioning/model.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import statistics
|
||||
import torchvision.models as models
|
||||
|
||||
|
||||
class EncoderCNN(nn.Module):
|
||||
def __init__(self, embed_size, train_CNN=False):
|
||||
super(EncoderCNN, self).__init__()
|
||||
self.train_CNN = train_CNN
|
||||
self.inception = models.inception_v3(pretrained=True, aux_logits=False)
|
||||
self.inception.fc = nn.Linear(self.inception.fc.in_features, embed_size)
|
||||
self.relu = nn.ReLU()
|
||||
self.times = []
|
||||
self.dropout = nn.Dropout(0.5)
|
||||
|
||||
def forward(self, images):
|
||||
features = self.inception(images)
|
||||
return self.dropout(self.relu(features))
|
||||
|
||||
|
||||
class DecoderRNN(nn.Module):
|
||||
def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
|
||||
super(DecoderRNN, self).__init__()
|
||||
self.embed = nn.Embedding(vocab_size, embed_size)
|
||||
self.lstm = nn.LSTM(embed_size, hidden_size, num_layers)
|
||||
self.linear = nn.Linear(hidden_size, vocab_size)
|
||||
self.dropout = nn.Dropout(0.5)
|
||||
|
||||
def forward(self, features, captions):
|
||||
embeddings = self.dropout(self.embed(captions))
|
||||
embeddings = torch.cat((features.unsqueeze(0), embeddings), dim=0)
|
||||
hiddens, _ = self.lstm(embeddings)
|
||||
outputs = self.linear(hiddens)
|
||||
return outputs
|
||||
|
||||
|
||||
class CNNtoRNN(nn.Module):
|
||||
def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
|
||||
super(CNNtoRNN, self).__init__()
|
||||
self.encoderCNN = EncoderCNN(embed_size)
|
||||
self.decoderRNN = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers)
|
||||
|
||||
def forward(self, images, captions):
|
||||
features = self.encoderCNN(images)
|
||||
outputs = self.decoderRNN(features, captions)
|
||||
return outputs
|
||||
|
||||
def caption_image(self, image, vocabulary, max_length=50):
|
||||
result_caption = []
|
||||
|
||||
with torch.no_grad():
|
||||
x = self.encoderCNN(image).unsqueeze(0)
|
||||
states = None
|
||||
|
||||
for _ in range(max_length):
|
||||
hiddens, states = self.decoderRNN.lstm(x, states)
|
||||
output = self.decoderRNN.linear(hiddens.squeeze(0))
|
||||
predicted = output.argmax(1)
|
||||
result_caption.append(predicted.item())
|
||||
x = self.decoderRNN.embed(predicted).unsqueeze(0)
|
||||
|
||||
if vocabulary.itos[predicted.item()] == "<EOS>":
|
||||
break
|
||||
|
||||
return [vocabulary.itos[idx] for idx in result_caption]
|
||||
BIN
ML/Pytorch/more_advanced/image_captioning/test_examples/boat.png
Normal file
BIN
ML/Pytorch/more_advanced/image_captioning/test_examples/boat.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 369 KiB |
BIN
ML/Pytorch/more_advanced/image_captioning/test_examples/bus.png
Normal file
BIN
ML/Pytorch/more_advanced/image_captioning/test_examples/bus.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 866 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 92 KiB |
BIN
ML/Pytorch/more_advanced/image_captioning/test_examples/dog.jpg
Normal file
BIN
ML/Pytorch/more_advanced/image_captioning/test_examples/dog.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 133 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 641 KiB |
96
ML/Pytorch/more_advanced/image_captioning/train.py
Normal file
96
ML/Pytorch/more_advanced/image_captioning/train.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torchvision.transforms as transforms
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from utils import save_checkpoint, load_checkpoint, print_examples
|
||||
from get_loader import get_loader
|
||||
from model import CNNtoRNN
|
||||
|
||||
|
||||
def train():
|
||||
transform = transforms.Compose(
|
||||
[
|
||||
transforms.Resize((356, 356)),
|
||||
transforms.RandomCrop((299, 299)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
||||
]
|
||||
)
|
||||
|
||||
train_loader, dataset = get_loader(
|
||||
root_folder="flickr8k/images",
|
||||
annotation_file="flickr8k/captions.txt",
|
||||
transform=transform,
|
||||
num_workers=2,
|
||||
)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
load_model = False
|
||||
save_model = False
|
||||
train_CNN = False
|
||||
|
||||
# Hyperparameters
|
||||
embed_size = 256
|
||||
hidden_size = 256
|
||||
vocab_size = len(dataset.vocab)
|
||||
num_layers = 1
|
||||
learning_rate = 3e-4
|
||||
num_epochs = 100
|
||||
|
||||
# for tensorboard
|
||||
writer = SummaryWriter("runs/flickr")
|
||||
step = 0
|
||||
|
||||
# initialize model, loss etc
|
||||
model = CNNtoRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)
|
||||
criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi["<PAD>"])
|
||||
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
|
||||
|
||||
# Only finetune the CNN
|
||||
for name, param in model.encoderCNN.inception.named_parameters():
|
||||
if "fc.weight" in name or "fc.bias" in name:
|
||||
param.requires_grad = True
|
||||
else:
|
||||
param.requires_grad = train_CNN
|
||||
|
||||
if load_model:
|
||||
step = load_checkpoint(torch.load("my_checkpoint.pth.tar"), model, optimizer)
|
||||
|
||||
model.train()
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
# Uncomment the line below to see a couple of test cases
|
||||
# print_examples(model, device, dataset)
|
||||
|
||||
if save_model:
|
||||
checkpoint = {
|
||||
"state_dict": model.state_dict(),
|
||||
"optimizer": optimizer.state_dict(),
|
||||
"step": step,
|
||||
}
|
||||
save_checkpoint(checkpoint)
|
||||
|
||||
for idx, (imgs, captions) in tqdm(
|
||||
enumerate(train_loader), total=len(train_loader), leave=False
|
||||
):
|
||||
imgs = imgs.to(device)
|
||||
captions = captions.to(device)
|
||||
|
||||
outputs = model(imgs, captions[:-1])
|
||||
loss = criterion(
|
||||
outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1)
|
||||
)
|
||||
|
||||
writer.add_scalar("Training loss", loss.item(), global_step=step)
|
||||
step += 1
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward(loss)
|
||||
optimizer.step()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
69
ML/Pytorch/more_advanced/image_captioning/utils.py
Normal file
69
ML/Pytorch/more_advanced/image_captioning/utils.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def print_examples(model, device, dataset):
|
||||
transform = transforms.Compose(
|
||||
[
|
||||
transforms.Resize((299, 299)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
||||
]
|
||||
)
|
||||
|
||||
model.eval()
|
||||
test_img1 = transform(Image.open("test_examples/dog.jpg").convert("RGB")).unsqueeze(
|
||||
0
|
||||
)
|
||||
print("Example 1 CORRECT: Dog on a beach by the ocean")
|
||||
print(
|
||||
"Example 1 OUTPUT: "
|
||||
+ " ".join(model.caption_image(test_img1.to(device), dataset.vocab))
|
||||
)
|
||||
test_img2 = transform(
|
||||
Image.open("test_examples/child.jpg").convert("RGB")
|
||||
).unsqueeze(0)
|
||||
print("Example 2 CORRECT: Child holding red frisbee outdoors")
|
||||
print(
|
||||
"Example 2 OUTPUT: "
|
||||
+ " ".join(model.caption_image(test_img2.to(device), dataset.vocab))
|
||||
)
|
||||
test_img3 = transform(Image.open("test_examples/bus.png").convert("RGB")).unsqueeze(
|
||||
0
|
||||
)
|
||||
print("Example 3 CORRECT: Bus driving by parked cars")
|
||||
print(
|
||||
"Example 3 OUTPUT: "
|
||||
+ " ".join(model.caption_image(test_img3.to(device), dataset.vocab))
|
||||
)
|
||||
test_img4 = transform(
|
||||
Image.open("test_examples/boat.png").convert("RGB")
|
||||
).unsqueeze(0)
|
||||
print("Example 4 CORRECT: A small boat in the ocean")
|
||||
print(
|
||||
"Example 4 OUTPUT: "
|
||||
+ " ".join(model.caption_image(test_img4.to(device), dataset.vocab))
|
||||
)
|
||||
test_img5 = transform(
|
||||
Image.open("test_examples/horse.png").convert("RGB")
|
||||
).unsqueeze(0)
|
||||
print("Example 5 CORRECT: A cowboy riding a horse in the desert")
|
||||
print(
|
||||
"Example 5 OUTPUT: "
|
||||
+ " ".join(model.caption_image(test_img5.to(device), dataset.vocab))
|
||||
)
|
||||
model.train()
|
||||
|
||||
|
||||
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
|
||||
print("=> Saving checkpoint")
|
||||
torch.save(state, filename)
|
||||
|
||||
|
||||
def load_checkpoint(checkpoint, model, optimizer):
|
||||
print("=> Loading checkpoint")
|
||||
model.load_state_dict(checkpoint["state_dict"])
|
||||
optimizer.load_state_dict(checkpoint["optimizer"])
|
||||
step = checkpoint["step"]
|
||||
return step
|
||||
Reference in New Issue
Block a user