How do I format my transformer loss computation?

I built a transformer from scratch for machine translation, where the src and trg dimensions are both batch_size x seq_len, and the output prediction dimensions are batch_size x seq_len x trg_vocab_size. I am using torch.nn.CrossEntropyLoss(), but I’m not sure how to format the dimensions correctly before passing them into the loss function. Does anyone know what the exact dimensions of ‘trg’ and ‘out’ should be, and does anyone have tips/suggestions on how to make those changes?

# train loop
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()
for epoch in range(0, EPOCHS + 1):

    for i, (src, trg) in enumerate(train_data):

        src = torch.Tensor(src).to(DEVICE).long()
        trg = torch.Tensor(trg).to(DEVICE).long()

        # forward pass
        out = model(src, trg)
        print('pred size: ', out.size())  # size = batch_size x seq_len x trg_vocab_size
        print('trg size: ', trg.size())   # size = batch_size x seq_len

        # # compute loss
        # loss = loss_fn(out, trg)
        # print(loss)
        # optimizer.zero_grad()
        # loss.backward()
        # optimizer.step()

If I understand correctly, you want the loss between the ground truth of shape (batch_size*seq_len,) and the predictions of shape (batch_size*seq_len, vocab_size).

You can do something like this, I guess:

loss = loss_fn(out.view(-1, trg_vocab_size), trg.view(-1))
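
For reference, here is a minimal sketch with dummy tensors (the sizes are made up, nothing comes from your code) showing how that reshape lines up with what nn.CrossEntropyLoss expects, i.e. predictions of shape (N, C) against targets of shape (N,):

import torch
import torch.nn as nn

# dummy sizes, just to check that the reshape lines up
batch_size, seq_len, trg_vocab_size = 2, 5, 100
out = torch.randn(batch_size, seq_len, trg_vocab_size)          # model logits
trg = torch.randint(0, trg_vocab_size, (batch_size, seq_len))   # target token indices

loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(out.view(-1, trg_vocab_size), trg.view(-1))      # (N, C) vs (N,)
print(loss)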

I tried the code you suggested; the loss is decreasing slightly but does not go below 7.2. I have a mini dataset for debugging purposes, so I wouldn’t expect it to learn completely. Why don’t we want out and trg to have the exact same dimensions? In the code you suggested, the last dimension is still different.

Ideally we would one-hot encode the ground truth to get the shape (batch_size, seq_len, trg_vocab_size), where each entry is 1 at the target index and 0 otherwise. But this is not needed, as nn.CrossEntropyLoss can compute the cross-entropy loss even when the ground truth is not one-hot encoded (it accepts class indices directly). You can read the docs.
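
To illustrate that equivalence, here is a small sketch with made-up sizes (none of these names come from your code): passing class indices gives the same loss as the explicit one-hot computation.

import torch
import torch.nn as nn
import torch.nn.functional as F

N, C = 4, 10                                   # number of positions, number of classes
logits = torch.randn(N, C)
target = torch.randint(0, C, (N,))             # class indices, not one-hot

loss_indices = nn.CrossEntropyLoss()(logits, target)

one_hot = F.one_hot(target, num_classes=C).float()
loss_one_hot = -(one_hot * F.log_softmax(logits, dim=-1)).sum(dim=-1).mean()

print(loss_indices, loss_one_hot)              # same value, up to floating point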

Can I ask: in the output of the model, out = model(src, trg), did you apply softmax after the last linear layer?

The reason I ask is that nn.CrossEntropyLoss should be given logits (the values before softmax) as the predictions, not probabilities.
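
To make that concrete, here is a small sketch (dummy tensors, nothing from your model): nn.CrossEntropyLoss applies log-softmax internally, so the model should output raw logits; applying softmax in the decoder means it effectively gets applied twice.

import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.randn(4, 10)
target = torch.randint(0, 10, (4,))

# CrossEntropyLoss == log_softmax followed by NLLLoss
loss_from_logits = nn.CrossEntropyLoss()(logits, target)
loss_manual = F.nll_loss(F.log_softmax(logits, dim=-1), target)
print(loss_from_logits, loss_manual)   # identical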

I did a torch.nn.functional.softmax() inside the last layer of the decoder. Now I removed that, and I get a loss curve that looks like this:

[loss curve screenshot]

So it seems to be working, but I will read the docs. Also, if I pad each sentence with zeros up to the MAX_LENGTH of my dataset, is there anything I should account for in the loss function with regard to this?

For the loss calculation you will need the predictions and the ground-truth labels (the index of each word in the vocabulary).

If you pad the sentences (both the source and target language in your case), then when computing the loss it is recommended to ignore the pad positions (though it is your decision whether to ignore them or not).

You can do this manually by first extracting the predictions and ground truth at the non-pad positions and then calculating the loss, for example:
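
Here is a rough sketch of that manual masking with dummy tensors; PAD_IDX = 0 is an assumption, use whatever id your pad token actually has.

import torch
import torch.nn as nn

PAD_IDX = 0                                          # assumed pad token id
batch_size, seq_len, trg_vocab_size = 2, 5, 100
out = torch.randn(batch_size, seq_len, trg_vocab_size)
trg = torch.randint(1, trg_vocab_size, (batch_size, seq_len))
trg[:, -2:] = PAD_IDX                                # pretend the last two positions are padding

out_flat = out.view(-1, trg_vocab_size)              # (batch*seq_len, vocab)
trg_flat = trg.view(-1)                              # (batch*seq_len,)

mask = trg_flat != PAD_IDX                           # keep only the non-pad positions
loss = nn.CrossEntropyLoss()(out_flat[mask], trg_flat[mask])
print(loss)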

PyTorch provides an easy way to do this via the ignore_index parameter (by default it is -100). If your ground truth contains -100 at some positions, nn.CrossEntropyLoss will ignore those positions. For example:

# here I am using a different ignore_index
loss_fn = nn.CrossEntropyLoss(ignore_index=-999)
loss = loss_fn(out.view(-1, trg_vocab_size), trg.view(-1))

You can choose a different ignore_index, but you have to be consistent with it in your preprocessing scripts as well.
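
For example, continuing with ignore_index=-999 from above, the preprocessing side would have to rewrite the pad positions to that same sentinel (the pad id of 0 below is just an assumption):

import torch
import torch.nn as nn

PAD_IDX = 0                                          # assumed pad token id -- replace with yours

# preprocessing: replace pad positions with the same sentinel used as ignore_index
trg = torch.tensor([[5, 8, 3, PAD_IDX, PAD_IDX]])    # toy padded target sequence
trg_for_loss = trg.clone()
trg_for_loss[trg == PAD_IDX] = -999                  # must match ignore_index=-999 above

loss_fn = nn.CrossEntropyLoss(ignore_index=-999)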

Let me know if you want more details on this.

Note: This is off-topic, but I have some questions:

  • What tokenizer are you using?
  • Are you also building the tokenizer from scratch?
  • Do you use the same tokenizer for both the input and target languages, or different ones?

Thanks.


How would I ignore the index if the padding value is ‘0’? Would I say ignore_index=0?
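
I.e., would something like this be the right way (assuming my pad id really is 0)?

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=0)   # assuming 0 is the pad token id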

  • I am using word_tokenize from the nltk library.
  • I’m using the same tokenizer for both the src and tgt languages, with slightly different lines of code for adjustments.

I sent you a DM with more details!