import torch
import spacy
from torchtext.data.metrics import bleu_score
import sys


def translate_sentence(model, sentence, german, english, device, max_length=50):
    # Load German tokenizer
    spacy_ger = spacy.load("de")

    # Create tokens using spacy and lowercase everything (which is what our vocab is)
    if isinstance(sentence, str):
        tokens = [token.text.lower() for token in spacy_ger(sentence)]
    else:
        tokens = [token.lower() for token in sentence]

    # Add <sos> and <eos> at the beginning and end respectively
    tokens.insert(0, german.init_token)
    tokens.append(german.eos_token)

    # Go through each German token and convert it to an index
    text_to_indices = [german.vocab.stoi[token] for token in tokens]

    # Convert to a tensor of shape (seq_len, 1): a batch containing a single sentence
    sentence_tensor = torch.LongTensor(text_to_indices).unsqueeze(1).to(device)

    # Build encoder outputs and the initial hidden/cell state for the decoder
    with torch.no_grad():
        outputs_encoder, hiddens, cells = model.encoder(sentence_tensor)

    # Start decoding from <sos> and greedily pick the most likely word at each step
    outputs = [english.vocab.stoi["<sos>"]]

    for _ in range(max_length):
        previous_word = torch.LongTensor([outputs[-1]]).to(device)

        with torch.no_grad():
            output, hiddens, cells = model.decoder(
                previous_word, outputs_encoder, hiddens, cells
            )
            best_guess = output.argmax(1).item()

        outputs.append(best_guess)

        # Model predicts it's the end of the sentence
        if best_guess == english.vocab.stoi["<eos>"]:
            break

    translated_sentence = [english.vocab.itos[idx] for idx in outputs]

    # Remove the start token
    return translated_sentence[1:]


def bleu(data, model, german, english, device):
    targets = []
    outputs = []

    for example in data:
        src = vars(example)["src"]
        trg = vars(example)["trg"]

        prediction = translate_sentence(model, src, german, english, device)
        prediction = prediction[:-1]  # remove <eos> token

        targets.append([trg])
        outputs.append(prediction)

    return bleu_score(outputs, targets)


def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    torch.save(state, filename)


def load_checkpoint(checkpoint, model, optimizer):
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
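

# --- Usage sketch (illustrative only, not part of the original utilities) ---
# A minimal sketch of how these helpers are typically wired into a training
# script, assuming a Seq2Seq model exposing `encoder`/`decoder` submodules and
# torchtext Field objects named `german`/`english` defined elsewhere. The
# names Seq2Seq, encoder_net, decoder_net, test_data, learning_rate, and
# load_model are placeholders for this example.
#
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = Seq2Seq(encoder_net, decoder_net).to(device)
# optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
#
# if load_model:
#     # Restore model weights and optimizer state from a previous run
#     load_checkpoint(torch.load("my_checkpoint.pth.tar"), model, optimizer)
#
# # ... training loop ...
#
# checkpoint = {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()}
# save_checkpoint(checkpoint)
#
# # Greedy-decode a single German sentence and evaluate BLEU on the test set
# print(translate_sentence(model, "ein pferd geht unter einer brücke", german, english, device))
# print(f"BLEU score: {bleu(test_data, model, german, english, device) * 100:.2f}")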