I am training BERT on the masked language modeling (MLM) and next sentence prediction (NSP) tasks, but I am confused about the loss function.
import torch
from torch import tensor
import torch.nn as nn
Let’s start with NSP.
Let's suppose that I have three pairs of sentences (i.e. batch_size=3) and that the labels for these three pairs are the following (0 = noNext, 1 = isNext):
is_next = tensor([0, 1, 1]) # batch_size
Since in my dataset some sentences are too short for the NSP task, I also have a tensor that tells me whether the pair was built (1) or not (0), i.e. whether the entry is of the form [CLS] segment_A [SEP] segment_B [SEP] (1) or just [CLS] segment_A [SEP] (0).
is_next_weight = tensor([1, 1, 0]) # batch_size, no NSP for the last pair
Let’s suppose that my model returns the following logits:
logits_clsf = tensor([[0.3797, 0.6203], [0.4363, 0.5637], [0.3797, 0.6203]]) # batch_size x 2
I have a loss function (with reduction='none' so that I keep the per-sample losses, as for MLM below):
criterion2 = nn.CrossEntropyLoss(reduction='none')
At this point I have two ways to obtain my final loss:
- either I compute the loss for every entry without excluding the unwanted ones, then average the weighted values:
loss_clsf = criterion2(logits_clsf, is_next)
loss_clsf = (loss_clsf*is_next_weight.float()).mean()
>>> tensor(0.4841)
- or I select only the desired entries first and compute the loss on them:
is_next = is_next[is_next_weight == 1]
logits_clsf = logits_clsf[is_next_weight == 1]
loss_clsf = criterion2(logits_clsf, is_next)
>>> tensor(0.7261)
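For what it's worth, here is a small sketch (my own variant, not from the original code) that keeps all entries but divides the weighted sum by the number of real pairs instead of calling .mean() over the whole batch; it reproduces the select-then-average value:

```python
import torch
from torch import tensor
import torch.nn as nn

is_next = tensor([0, 1, 1])
is_next_weight = tensor([1, 1, 0])
logits_clsf = tensor([[0.3797, 0.6203], [0.4363, 0.5637], [0.3797, 0.6203]])

criterion = nn.CrossEntropyLoss(reduction='none')  # keep per-sample losses
per_sample = criterion(logits_clsf, is_next)       # shape: (batch_size,)

# Divide by the number of real pairs, not by the batch size.
loss_clsf = (per_sample * is_next_weight.float()).sum() / is_next_weight.sum()
print(loss_clsf)  # tensor(0.7261), same as selecting the entries first
```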
I have the same concern at the MLM level. Since the number of masked tokens varies per sentence, I pad masked_ids (which here contains the ids of the masked tokens) with torch.nn.utils.rnn.pad_sequence so that all rows in a batch have the same length (a PyTorch constraint).
Let us take an example with a batch_size of 3 and a vocabulary of size n_vocab=5 (just to keep things simple: 0 = pad_index, and we ignore the ids of special tokens like [UNK], [CLS] and [SEP]): 1 token of id [2] is masked in the first sentence of the batch, 2 tokens of ids [1, 2] in the second, and 3 tokens of ids [1, 3, 4] in the last. Then max_len = 3 (the number of masked tokens in the last sentence) and:
masked_ids = tensor([[2, 0, 0], [1, 2, 0], [1, 3, 4]]) # batch_size x max_len
masked_weights = tensor([[1, 0, 0], [1, 1, 0], [1, 1, 1]]) # batch_size x max_len
masked_weights is 1 for real masked tokens and 0 for padding tokens.
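As a side note, both tensors above can be built with pad_sequence from the variable-length id lists; this is only a sketch of that step (the list name is hypothetical):

```python
import torch
from torch import tensor
from torch.nn.utils.rnn import pad_sequence

# Ids of the masked tokens in each of the three sentences.
ids_per_sentence = [tensor([2]), tensor([1, 2]), tensor([1, 3, 4])]

# Pad the ids with pad_index=0 up to max_len.
masked_ids = pad_sequence(ids_per_sentence, batch_first=True, padding_value=0)

# Build the weights by padding all-ones tensors of the same lengths.
masked_weights = pad_sequence([torch.ones_like(t) for t in ids_per_sentence],
                              batch_first=True, padding_value=0)

print(masked_ids)      # tensor([[2, 0, 0], [1, 2, 0], [1, 3, 4]])
print(masked_weights)  # tensor([[1, 0, 0], [1, 1, 0], [1, 1, 1]])
```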
Let’s suppose that our model returns the following logits:
batch_size, max_len, n_vocab = 3, 3, 5
torch.manual_seed(0)
logits_lm = torch.rand(batch_size, max_len, n_vocab) # batch_size x max_len x n_vocab
>>> tensor([[[0.4963, 0.7682, 0.0885, 0.1320, 0.3074],
[0.6341, 0.4901, 0.8964, 0.4556, 0.6323],
[0.3489, 0.4017, 0.0223, 0.1689, 0.2939]],
[[0.5185, 0.6977, 0.8000, 0.1610, 0.2823],
[0.6816, 0.9152, 0.3971, 0.8742, 0.4194],
[0.5529, 0.9527, 0.0362, 0.1852, 0.3734]],
[[0.3051, 0.9320, 0.1759, 0.2698, 0.1507],
[0.0317, 0.2081, 0.9298, 0.7231, 0.7423],
[0.5263, 0.2437, 0.5846, 0.0332, 0.1387]]])
I also have a loss function:
criterion1 = nn.CrossEntropyLoss(reduction='none')
As with NSP, I have two choices:
- calculate the loss before reducing it:
loss_lm = criterion1(logits_lm.transpose(1, 2), masked_ids)
loss_lm = (loss_lm*masked_weights.float()).mean()
>>> tensor(1.0664)
- or select the entries before computing the loss:
masked_ids = masked_ids[masked_weights == 1]
logits_lm = logits_lm[masked_weights == 1]
loss_lm = criterion1(logits_lm, masked_ids)
loss_lm = loss_lm.mean()
>>> tensor(1.5996)
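Here too, the two numbers differ only through the denominator (batch_size * max_len = 9 positions versus 6 real masked tokens). A sketch (my own variant) that normalizes the weighted sum by masked_weights.sum() and thereby matches the select-first result:

```python
import torch
from torch import tensor
import torch.nn as nn

masked_ids = tensor([[2, 0, 0], [1, 2, 0], [1, 3, 4]])
masked_weights = tensor([[1, 0, 0], [1, 1, 0], [1, 1, 1]])

torch.manual_seed(0)
logits_lm = torch.rand(3, 3, 5)  # batch_size x max_len x n_vocab

criterion = nn.CrossEntropyLoss(reduction='none')
loss_lm = criterion(logits_lm.transpose(1, 2), masked_ids)  # batch_size x max_len

# Normalize by the number of real masked tokens, not batch_size * max_len.
loss_lm = (loss_lm * masked_weights.float()).sum() / masked_weights.sum()
print(loss_lm)  # tensor(1.5996), same as selecting the real tokens first
```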
I would like to know which of these choices is better, and why, since they return different values.
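For comparison (not part of my original code): when the pad target id is never a real target, as here where pad_index = 0 and the real ids are 1 to 4, nn.CrossEntropyLoss(ignore_index=...) gives the select-first behavior directly, since it excludes ignored targets from both the sum and the denominator:

```python
import torch
from torch import tensor
import torch.nn as nn

masked_ids = tensor([[2, 0, 0], [1, 2, 0], [1, 3, 4]])  # 0 = pad_index
torch.manual_seed(0)
logits_lm = torch.rand(3, 3, 5)  # batch_size x max_len x n_vocab

# ignore_index drops the pad targets from both the loss sum and the average.
criterion = nn.CrossEntropyLoss(ignore_index=0)
loss_lm = criterion(logits_lm.transpose(1, 2), masked_ids)
print(loss_lm)  # tensor(1.5996), same as selecting the real tokens first
```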