Chatbot Tutorial on Pytorch documentation, gathering indices that are invalid

(Tony Nguyen) #1

Hi,
From this tutorial here: https://pytorch.org/tutorials/beginner/chatbot_tutorial.html,
at the part where we calculate the loss using:

def maskNLLLoss(inp, target, mask):
    # inp: softmax output, target: word indices, mask: bool padding mask
    nTotal = mask.sum()
    crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    loss = crossEntropy.masked_select(mask).mean()
    loss = loss.to(device)
    return loss, nTotal.item()
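For context, here is a runnable toy call of that function as I understand it (the shapes and the values here are my own assumption, not from the tutorial):

```python
import torch

device = torch.device("cpu")  # stand-in for the tutorial's device

def maskNLLLoss(inp, target, mask):
    nTotal = mask.sum()
    crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    loss = crossEntropy.masked_select(mask).mean()
    loss = loss.to(device)
    return loss, nTotal.item()

# assumed shapes: inp is [batch_size, vocab_size] softmax probabilities,
# target is [batch_size] word indices, mask is [batch_size] bool
batch_size, vocab_size = 4, 7
inp = torch.softmax(torch.randn(batch_size, vocab_size), dim=1)
target = torch.randint(0, vocab_size, (batch_size,))
mask = torch.tensor([True, True, True, False])  # last position is padding

loss, n = maskNLLLoss(inp, target, mask)
print(loss.item(), n)
```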

Why are we using target.view(-1, 1) as the index to gather with? If I'm not mistaken, target contains word indices, which, given a large corpus, can be huge, and thus won't be valid indices for the function to select with.

Also, why do we only gather along the first dimension?

#2

inp will probably contain the prediction probabilities with shape [batch_size, nb_words].
If that's the case, the gather operation will pick, for each sample in the batch, the predicted probability at the current target word index:

batch_size = 5
nb_words = 10

inp = torch.randn(batch_size, nb_words)
target = torch.randint(0, nb_words, (batch_size,))
torch.gather(inp, 1, target.view(-1, 1))
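To make the indexing explicit: gathering along dim 1 picks one entry per row, so target only needs to hold values smaller than nb_words (the vocabulary size), not the corpus size. A quick sanity check showing the equivalence to plain per-row indexing:

```python
import torch

batch_size, nb_words = 5, 10
inp = torch.randn(batch_size, nb_words)
target = torch.randint(0, nb_words, (batch_size,))

# gather along dim 1: for each row i, pick inp[i, target[i]]
gathered = torch.gather(inp, 1, target.view(-1, 1)).squeeze(1)
manual = inp[torch.arange(batch_size), target]
print(torch.equal(gathered, manual))  # True
```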
(Tony Nguyen) #3

Ah, I see, so this function requires the input probabilities to be of shape (batch_size, number of words).

I’ve been reading articles about implementing a chatbot using the Transformer model from Google, and I was borrowing the training loop from here. However, the Transformer model requires the user to specify an output size, or as the paper from Google calls it, d_model. So currently I am getting decoder outputs of size [batch_size, d_model]. How would I go about calculating the loss for this? Do you have any pointers? Thanks.