mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-02-20 13:50:41 +00:00
Initial commit
This commit is contained in:
75
ML/Projects/DeepSort/utils.py
Normal file
75
ML/Projects/DeepSort/utils.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import torch
|
||||
|
||||
|
||||
def ask_user():
|
||||
print("Write your array as a list [i,j,k..] with arbitrary positive numbers")
|
||||
array = input("Input q if you want to quit \n")
|
||||
return array
|
||||
|
||||
|
||||
def sort_array(encoder, decoder, device, arr=None):
|
||||
"""
|
||||
A very simple example of use of the model
|
||||
Input: encoder nn.Module
|
||||
decoder nn.Module
|
||||
device
|
||||
array to sort (optional)
|
||||
"""
|
||||
|
||||
if arr is None:
|
||||
arr = ask_user()
|
||||
|
||||
with torch.no_grad():
|
||||
while arr != "q":
|
||||
# Avoid numerical errors by rounding to max_len
|
||||
arr = eval(arr)
|
||||
lengths = [
|
||||
len(str(elem).split(".")[1]) if len(str(elem).split(".")) > 1 else 0
|
||||
for elem in arr
|
||||
]
|
||||
max_len = max(lengths)
|
||||
source = torch.tensor(arr, dtype=torch.float).to(device).unsqueeze(1)
|
||||
batch_size = source.shape[1]
|
||||
target_len = source.shape[0] + 1
|
||||
|
||||
outputs = torch.zeros(target_len, batch_size, target_len - 1).to(device)
|
||||
encoder_states, hidden, cell = encoder(source)
|
||||
|
||||
# First input will be <SOS> token
|
||||
x = torch.tensor([-1], dtype=torch.float).to(device)
|
||||
predictions = torch.zeros((target_len)).to(device)
|
||||
|
||||
for t in range(1, target_len):
|
||||
# At every time step use encoder_states and update hidden, cell
|
||||
attention, energy, hidden, cell = decoder(
|
||||
x, encoder_states, hidden, cell
|
||||
)
|
||||
|
||||
# Store prediction for current time step
|
||||
outputs[t] = energy.permute(1, 0)
|
||||
|
||||
# Get the best word the Decoder predicted (index in the vocabulary)
|
||||
best_guess = attention.argmax(0)
|
||||
predictions[t] = best_guess.item()
|
||||
x = torch.tensor([best_guess.item()], dtype=torch.float).to(device)
|
||||
|
||||
output = [
|
||||
round(source[predictions[1:].long()][i, :].item(), max_len)
|
||||
for i in range(source.shape[0])
|
||||
]
|
||||
|
||||
print(f"Here's the result: {output}")
|
||||
arr = ask_user()
|
||||
|
||||
|
||||
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
|
||||
print("=> Saving checkpoint")
|
||||
torch.save(state, filename)
|
||||
|
||||
|
||||
def load_checkpoint(checkpoint, model, optimizer): # , steps):
|
||||
print("=> Loading checkpoint")
|
||||
model.load_state_dict(checkpoint["state_dict"])
|
||||
optimizer.load_state_dict(checkpoint["optimizer"])
|
||||
# steps = checkpoint['steps']
|
||||
# return steps
|
||||
Reference in New Issue
Block a user