Isn’t it a simple solution to use the `weight` argument
of `CrossEntropyLoss` or `NLLLoss`? AFAIK the loss at padding steps will be zero if the weight for the padding class is set to 0, and with the default `reduction='mean'` the loss is divided by the sum of the per-target weights, so padded positions don’t inflate the average either.
import torch
import torch.nn as nn

weight = torch.ones(vocab_size)
weight[pad_id] = 0.0  # zero weight for the padding class
criterion = nn.CrossEntropyLoss(weight=weight)
Or am I misunderstanding something?
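For what it’s worth, here is a minimal sketch (with made-up `vocab_size`, `pad_id`, and toy tensors) comparing the zero-weight trick against `ignore_index=pad_id`, which is the more common way to mask padding:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, pad_id = 5, 0

# Four token positions; two of them are padding (target == pad_id).
logits = torch.randn(4, vocab_size)
targets = torch.tensor([2, pad_id, 3, pad_id])

# Zero out the padding class via the weight argument.
weight = torch.ones(vocab_size)
weight[pad_id] = 0.0
loss_weight = nn.CrossEntropyLoss(weight=weight)(logits, targets)

# The usual approach: ignore_index skips padded targets entirely.
loss_ignore = nn.CrossEntropyLoss(ignore_index=pad_id)(logits, targets)

# Both average over the 2 non-padding tokens only.
print(loss_weight.item(), loss_ignore.item())
```

One caveat either way: if a batch contains only padding targets, the denominator is zero and the mean loss comes out NaN.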