So I think I may have built an attention transformer

I pretty faithfully followed the code found at Machine-Learning-Collection/transformer_from_scratch.py at master · aladdinpersson/Machine-Learning-Collection · GitHub

I even wrote a little training function for it.

import torch.optim as optim

device= torch.device("cuda" if torch.cuda.is_available() else "cpu")

src= torch.tensor([[1,2,3,4,5,0],[3,4,5,6,7,8]]).to(device)
trg=torch.tensor([[6,7,8,9,10,11],[8,9,10,11,12,13]]).to(device)

model=Transformer(src_vocab_size=14, trg_vocab_size=14, src_pad_idx=0, trg_pad_idx=0).to(device)

opt=torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn=nn.functional.cross_entropy

def train_model(epochs=501):
    total_loss=0
    train_loss_list, validation_loss_list = [], []
    for e in range(epochs):
        model.train()
        pred=model(src, trg)
        if e%50==0:
            print("-"*25, f"Epoch {e + 1}","-"*25)
            print(pred.shape)
            print(torch.argmax(pred,dim=2))
            print(trg)
        loss=loss_fn(pred, torch.nn.functional.one_hot(trg, num_classes=14).type(torch.FloatTensor))
        
        opt.zero_grad()
        loss.backward()
        opt.step()
        
        total_loss += loss.detach().item()/(e+1)
        train_loss_list += [total_loss]
        
        if e%50==0: print(f"Training loss: {total_loss:.4f}")
        print('...')
        
    return train_loss_list
print(model)
train_loss_list = train_model()

After 400 or so epochs, it actually predicts the right target sequence given the source!

What I am asking you is,

  1. Did I actually implement this training function correctly? Is my model learning or do I have data leakage? I’m pretty sure my loss is being calculated incorrectly, so if you have any tips on that too, that’d be great.

  2. How do I test this function on data it hasn’t seen? For example, the sequence [2,3,4,5,6,7] should predict [7,8,9,10,11,12], but I don’t know how to run the model with only inputs (it calls for src and trg).

Your guidance is really appreciated. I’m looking to take this transformer and generate song lyrics with it.