Isn’t it a simple solution to use the `weight` argument of `CrossEntropyLoss` or `NLLLoss`? AFAIK, the loss at padding steps will be zeroed out if the corresponding class weight is set to 0:
import torch
import torch.nn as nn

weight = torch.ones(vocab_size)
weight[pad_id] = 0.0
criterion = nn.CrossEntropyLoss(weight=weight)
Or am I misunderstanding something?
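For what it’s worth, here’s a quick sanity check (with a made-up tiny vocab and `pad_id = 0`) suggesting that with the default `reduction='mean'` this gives the same result as `ignore_index`, since the weighted mean divides by the sum of per-target weights:

```python
import torch
import torch.nn as nn

vocab_size, pad_id = 5, 0
torch.manual_seed(0)
logits = torch.randn(4, vocab_size)   # 4 token positions, 5-class vocab
targets = torch.tensor([0, 2, 3, 0])  # positions 0 and 3 are padding

# zero out the weight for the pad class
weight = torch.ones(vocab_size)
weight[pad_id] = 0.0

weighted = nn.CrossEntropyLoss(weight=weight)(logits, targets)
ignored = nn.CrossEntropyLoss(ignore_index=pad_id)(logits, targets)

# both average the loss over the 2 non-pad positions only
print(torch.allclose(weighted, ignored))  # True
```

One caveat I’d watch for: with `reduction='sum'` the two still agree, but the zero-weight trick only excludes the pad *class* as a target; it doesn’t change normalization if you later divide by total token count yourself.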