Ignore padding area in loss computation

I am doing sequence labelling on short texts. Specifically, I use a BERT model from the Hugging Face library (BertModel in particular), and I tokenize every text with the library's tokenizer before feeding it to the model. Since the texts are short, I set the tokenizer's output sequence length to 256. My labels are binary (1 and 0), and every sequence element (BERT input token) is assigned a label.

For the loss I use binary cross-entropy (BCEWithLogitsLoss), but the function also includes the padding tokens when computing the loss, which in turn affects backpropagation.
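To make the problem concrete, here is a small sketch in which random tensors stand in for the model output and labels (the toy shapes and the `attention_mask` values are illustrative assumptions, not the original setup). With the default `reduction='mean'`, the padded positions still receive non-zero gradients:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch_size, seq_len = 2, 8  # toy sizes; the setup described above uses seq_len=256
logits = torch.randn(batch_size, seq_len, requires_grad=True)  # stand-in for per-token model outputs
targets = torch.randint(0, 2, (batch_size, seq_len)).float()   # binary per-token labels

# Hypothetical attention mask: 1 for real tokens, 0 for padding
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 1, 1, 1, 1, 0]])

# Default reduction='mean' averages over every position, padding included
loss = nn.BCEWithLogitsLoss()(logits, targets)
loss.backward()

pad_positions = attention_mask == 0
print(logits.grad[pad_positions])  # non-zero entries: padding tokens influence the update
```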

I want BCEWithLogitsLoss to compute the loss only on the tokens of the text, not on the padding tokens. What is the best way to achieve that?

I think you could use the raw loss output (via reduction='none'), set the unwanted loss entries to zero, reduce the loss, and calculate the gradients via loss.backward(). I'm not sure if there is a better way to mask the loss.
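A minimal sketch of that suggestion, assuming the mask comes from the tokenizer's `attention_mask` (1 for real tokens, 0 for padding); random tensors stand in for the model output and labels:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

logits = torch.randn(2, 8, requires_grad=True)   # stand-in for per-token logits
targets = torch.randint(0, 2, (2, 8)).float()    # binary per-token labels
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 1, 1, 1, 1, 0]])

criterion = nn.BCEWithLogitsLoss(reduction='none')
raw_loss = criterion(logits, targets)            # shape (2, 8): one loss value per token

# Zero out the entries that belong to padding tokens
masked_loss = raw_loss * attention_mask.float()

# Reduce (here: a plain mean, as suggested) and backpropagate
loss = masked_loss.mean()
loss.backward()
```

Because the padded entries are multiplied by zero, their logits get exactly zero gradient, while real tokens are trained as usual.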

@ptrblck thank you for your response.

You are right, this is a clean way to implement it.
I tried it and even though the loss was different, the model metrics did not change.
I suppose that when you said to reduce the loss (since reduction='none'), you meant to call torch.mean() after zeroing out the loss of the padded tokens.

I only mention this because the resulting tensor's grad_fn is the mean, although the backward computation still goes through all steps, including BCEWithLogits.

I also cannot think of a better way to implement this.

Regards.

Yes, that’s what I had in mind. The backward call should still work as intended and internally the mean reduction would do the same. In your manual approach the grad_fn would point to the mean operation, which shouldn’t be a concern.
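A quick way to confirm this, again with random stand-in tensors and an illustrative mask: the last node in the graph is the mean, yet the gradient still flows through BCEWithLogits, and the padded positions receive exactly zero gradient. The hand-derived reference relies on the derivative of BCE-with-logits with respect to the logit being `sigmoid(x) - y`:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

logits = torch.randn(2, 8, requires_grad=True)
targets = torch.randint(0, 2, (2, 8)).float()
mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0],
                     [1, 1, 1, 1, 1, 1, 1, 0]], dtype=torch.bool)

raw_loss = nn.BCEWithLogitsLoss(reduction='none')(logits, targets)
loss = torch.where(mask, raw_loss, torch.zeros_like(raw_loss)).mean()

print(loss.grad_fn)  # the last recorded op is the mean reduction

# Reference gradient: (sigmoid(logit) - target) / N on real tokens, 0 on padding
expected_grad = (torch.sigmoid(logits) - targets) * mask / mask.numel()

loss.backward()
print(torch.allclose(logits.grad, expected_grad.detach()))  # True
```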

That is also what I thought. Thanks again for your help.

Regards.

@ptrblck, I have a follow-up question to that.
Using torch.mean() means that the zeroed-out elements are still counted when calculating the average, and thus affect backpropagation. I'm wondering if it makes more sense to divide by the number of elements that were not zeroed out.

```python
import torch

pad = 2

tags = torch.tensor([0, 1, 1, 0, 1, 2])

# for this example, let's pretend this is our loss tensor that we got from the unreduced BCEWithLogitsLoss
loss = torch.tensor([0.001, -0.3, 0.9, 0.7, 0.6, 0.8])

loss_mask = tags != pad
# loss_mask tensor([ True,  True,  True,  True,  True, False])

loss_masked = loss.where(loss_mask, torch.tensor(0.0))
# loss_masked tensor([ 0.0010, -0.3000,  0.9000,  0.7000,  0.6000,  0.0000])

loss_masked.mean()  # tensor(0.3168)

loss_masked.sum() / loss_mask.sum()  # tensor(0.3802)
```
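Written out, with per-token losses l_i, mask values m_i in {0, 1} and N positions in total (notation introduced here for illustration), the two reductions differ only by the fraction of unmasked tokens:

```latex
\mathcal{L}_A = \frac{1}{N}\sum_{i=1}^{N} m_i\,\ell_i,
\qquad
\mathcal{L}_B = \frac{\sum_{i=1}^{N} m_i\,\ell_i}{\sum_{i=1}^{N} m_i},
\qquad
\mathcal{L}_A = \frac{\sum_{i=1}^{N} m_i}{N}\,\mathcal{L}_B
```

In the example above, 5 of the 6 positions are unmasked, so 0.3802 × 5/6 ≈ 0.3168. The mean approach therefore shrinks the loss (and its gradients) in proportion to how much padding the batch contains.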

Yes, that sounds like an interesting idea! If you use this approach and compare it to the mean approach, it would be interesting to see your findings and whether you've seen any advantage of one over the other during training.

FYI, I just compared the two approaches @izaskr described on an NMT problem:

```python
# option A
loss_masked.mean()  # tensor(0.3168)

# option B
loss_masked.sum() / loss_mask.sum()  # tensor(0.3802)
```

Option B works much better: the loss decreases monotonically and the model generalizes better.
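One plausible reason for the difference (an interpretation, not something the comparison above measured): with option A the loss scale depends on how much padding a batch happens to contain, while option B does not. A small sketch with made-up per-token losses; `option_a` and `option_b` are hypothetical helper names:

```python
import torch

def option_a(loss, mask):
    # Mean over all positions: padded entries contribute zeros but still count in N
    return torch.where(mask, loss, torch.zeros_like(loss)).mean()

def option_b(loss, mask):
    # Sum over real tokens divided by the number of real tokens
    return torch.where(mask, loss, torch.zeros_like(loss)).sum() / mask.sum()

# Identical per-token losses (0.5) on the real tokens, different amounts of padding
per_token = torch.full((2, 10), 0.5)
light_mask = (torch.arange(10) < 9).expand(2, 10)  # 9 real tokens per row
heavy_mask = (torch.arange(10) < 3).expand(2, 10)  # 3 real tokens per row

print(option_a(per_token, light_mask), option_a(per_token, heavy_mask))  # about 0.45 vs 0.15
print(option_b(per_token, light_mask), option_b(per_token, heavy_mask))  # 0.5 vs 0.5
```

Under option A, a heavily padded batch produces a smaller loss and smaller gradients even though every real token is predicted equally well.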

An alternative way to implement option B:

```python
loss_masked = torch.masked_select(loss, loss_mask)
loss_masked.mean()
```
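A sketch checking that these reductions agree on the example tensors from earlier in the thread, plus one design note that is an addition rather than something from the thread: `masked_select(...).mean()` and `sum() / mask.sum()` both yield NaN if a batch has no unmasked tokens, so clamping the denominator is a cheap guard.

```python
import torch

pad = 2
tags = torch.tensor([0, 1, 1, 0, 1, 2])
loss = torch.tensor([0.001, -0.3, 0.9, 0.7, 0.6, 0.8])
loss_mask = tags != pad

via_masked_select = torch.masked_select(loss, loss_mask).mean()
via_sum_div = loss.where(loss_mask, torch.tensor(0.0)).sum() / loss_mask.sum()
via_indexing = loss[loss_mask].mean()  # boolean indexing selects the same elements

print(via_masked_select, via_sum_div, via_indexing)  # all approximately 0.3802

# Guard against an all-padding batch: clamp the denominator to at least 1
empty_mask = torch.zeros_like(loss_mask)  # all False
safe = loss.where(empty_mask, torch.tensor(0.0)).sum() / empty_mask.sum().clamp(min=1)
print(safe)  # 0 instead of nan
```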

@ptrblck, can this also be done by specifying ignore_index=pad_token_id in the cross-entropy function, instead of doing the zeroing-out manually? (See the torch.nn.functional.cross_entropy page in the PyTorch 2.1 documentation.)

Yes, this should work if you are using nn.CrossEntropyLoss with class indices as the target. Note that the original poster used nn.BCEWithLogitsLoss, which expects "soft" (floating-point) targets and has no ignore_index argument, so it would need manual masking.
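For the multi-class case, a sketch of the `ignore_index` route, assuming padded positions are labelled with -100 (the default `ignore_index` of `F.cross_entropy`, and a common convention in Hugging Face token-classification code). It matches manually masking and dividing by the number of real tokens, i.e. option B:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

batch_size, seq_len, num_classes = 2, 6, 3
logits = torch.randn(batch_size, seq_len, num_classes)
labels = torch.randint(0, num_classes, (batch_size, seq_len))
labels[0, 4:] = -100  # hypothetical padding positions
labels[1, 5:] = -100

# cross_entropy expects (N, C) logits and (N,) class indices, so flatten the token dimension
ce = F.cross_entropy(logits.view(-1, num_classes), labels.view(-1), ignore_index=-100)

# Manual equivalent: unreduced loss on real tokens only, averaged over their count
mask = labels != -100
per_token = F.cross_entropy(logits[mask], labels[mask], reduction='none')
manual = per_token.sum() / mask.sum()

print(torch.allclose(ce, manual))  # True
```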
