mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-04-10 12:33:44 +00:00
Initial commit
This commit is contained in:
84
ML/Pytorch/more_advanced/Seq2Seq/utils.py
Normal file
84
ML/Pytorch/more_advanced/Seq2Seq/utils.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import torch
|
||||
import spacy
|
||||
from torchtext.data.metrics import bleu_score
|
||||
import sys
|
||||
|
||||
|
||||
def translate_sentence(model, sentence, german, english, device, max_length=50):
    """Greedily translate a German sentence into a list of English tokens.

    Args:
        model: Seq2Seq model exposing ``encoder(src)`` returning
            ``(hidden, cell)`` and ``decoder(token, hidden, cell)`` returning
            ``(output, hidden, cell)``.
        sentence: Either a raw string (tokenized here with spaCy) or an
            iterable of already-tokenized strings.
        german: Source-language field; its ``init_token``, ``eos_token`` and
            ``vocab.stoi`` are used to numericalize the input.
        english: Target-language field; its ``vocab.stoi`` and ``vocab.itos``
            map between tokens and indices.
        device: Device the index tensors are moved to.
        max_length: Maximum number of decoding steps.

    Returns:
        The predicted tokens without the leading ``<sos>``. The list ends with
        ``<eos>`` only if the model emitted it within ``max_length`` steps.
    """
    # Everything is lower-cased because that is how the vocabulary was built.
    if isinstance(sentence, str):
        # Load the German tokenizer only when raw text has to be split;
        # pre-tokenized input (e.g. every example scored by ``bleu``) must not
        # pay for loading a spaCy model on each call.
        # NOTE(review): the "de" shortcut name presumably requires spaCy 2.x;
        # newer releases need the full package name - confirm against the
        # installed spaCy version.
        spacy_ger = spacy.load("de")
        tokens = [token.text.lower() for token in spacy_ger(sentence)]
    else:
        tokens = [token.lower() for token in sentence]

    # Add <SOS> and <EOS> at the beginning and end respectively.
    tokens = [german.init_token, *tokens, german.eos_token]

    # Convert each German token to its vocabulary index.
    text_to_indices = [german.vocab.stoi[token] for token in tokens]

    # 1-D index tensor -> shape (seq_len, 1) via unsqueeze(1).
    sentence_tensor = torch.LongTensor(text_to_indices).unsqueeze(1).to(device)

    eos_index = english.vocab.stoi["<eos>"]
    outputs = [english.vocab.stoi["<sos>"]]

    with torch.no_grad():
        # Build the encoder hidden and cell state.
        hidden, cell = model.encoder(sentence_tensor)

        for _ in range(max_length):
            # Feed the most recent prediction back into the decoder.
            previous_word = torch.LongTensor([outputs[-1]]).to(device)
            output, hidden, cell = model.decoder(previous_word, hidden, cell)
            best_guess = output.argmax(1).item()
            outputs.append(best_guess)

            # Model predicts it's the end of the sentence.
            if best_guess == eos_index:
                break

    translated_sentence = [english.vocab.itos[idx] for idx in outputs]

    # Remove the start token.
    return translated_sentence[1:]
|
||||
|
||||
|
||||
def bleu(data, model, german, english, device):
    """Compute the corpus BLEU score of ``model`` over ``data``.

    Args:
        data: Iterable of examples whose attributes include ``src`` (source
            tokens or text) and ``trg`` (reference target tokens).
        model: Seq2Seq model passed through to ``translate_sentence``.
        german: Source-language field passed to ``translate_sentence``.
        english: Target-language field passed to ``translate_sentence``.
        device: Device passed to ``translate_sentence``.

    Returns:
        The value returned by ``bleu_score`` for all predictions against their
        single references.
    """
    targets = []
    outputs = []

    for example in data:
        fields = vars(example)
        src = fields["src"]
        trg = fields["trg"]

        prediction = translate_sentence(model, src, german, english, device)

        # Strip the trailing <eos> only when it is really there: if decoding
        # stopped at max_length without emitting <eos>, the last token is a
        # real word and must stay in the hypothesis.
        if prediction and prediction[-1] == "<eos>":
            prediction = prediction[:-1]

        # bleu_score expects a list of references per hypothesis.
        targets.append([trg])
        outputs.append(prediction)

    return bleu_score(outputs, targets)
|
||||
|
||||
|
||||
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Serialize ``state`` to ``filename`` with ``torch.save``.

    Args:
        state: Object to persist (whatever ``torch.save`` accepts).
        filename: Destination handed to ``torch.save``.
    """
    print("=> Saving checkpoint")
    torch.save(obj=state, f=filename)
|
||||
|
||||
|
||||
def load_checkpoint(checkpoint, model, optimizer):
    """Restore model and optimizer state from a checkpoint mapping.

    Args:
        checkpoint: Mapping with a ``"state_dict"`` entry for the model and an
            ``"optimizer"`` entry for the optimizer.
        model: Object whose ``load_state_dict`` receives
            ``checkpoint["state_dict"]``.
        optimizer: Object whose ``load_state_dict`` receives
            ``checkpoint["optimizer"]``.
    """
    print("=> Loading checkpoint")
    # Model first, then optimizer - same order as the keys are consumed.
    restore_plan = ((model, "state_dict"), (optimizer, "optimizer"))
    for target, key in restore_plan:
        target.load_state_dict(checkpoint[key])
|
||||
Reference in New Issue
Block a user