Loss for BERT masked language modeling and next sentence prediction

I am training BERT on the masked language modeling (MLM) and next sentence prediction (NSP) tasks, but I am confused about the loss function.

import torch
from torch import tensor
import torch.nn as nn

Let’s start with NSP.
Suppose I have three pairs of sentences (i.e. batch_size=3) and that the labels for these three pairs are the following (0 = noNext, 1 = isNext):

is_next = tensor([0, 1, 1]) # batch_size

Since some sentences in my dataset are too short for the NSP task, I also have a tensor that tells me whether NSP was applied (1) or not (0), i.e. whether the entry has the form [CLS] segment_A [SEP] segment_B [SEP] (1) or just [CLS] segment_A [SEP] (0).

is_next_weight = tensor([1, 1, 0]) # batch_size, no NSP for the last pair
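
Just to illustrate the two layouts (the token ids below are completely made up; 101 and 102 only stand in for [CLS] and [SEP] here):

input_with_pair = [101, 7, 12, 102, 25, 9, 102] # [CLS] segment_A [SEP] segment_B [SEP] -> is_next_weight = 1
input_single = [101, 7, 12, 102]                # [CLS] segment_A [SEP]                 -> is_next_weight = 0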

Let’s suppose that my model returns the following logits:

logits_clsf = tensor([[0.3797, 0.6203], [0.4363, 0.5637], [0.3797, 0.6203]]) # batch_size x 2

I have a loss function:

criterion2 = nn.CrossEntropyLoss() # default reduction='mean', so it returns a single averaged scalar

At this point I have two ways to obtain my final loss (a small check of the per-pair values follows the list):

  • either I compute the loss without excluding the unwanted entries and then average over the desired entries:
loss_clsf = criterion2(logits_clsf, is_next)
loss_clsf = (loss_clsf*is_next_weight.float()).mean()
>>> tensor(0.4516)
  • or I select only the desired entries and compute the loss on them:
is_next = is_next[is_next_weight == 1]
logits_clsf = logits_clsf[is_next_weight == 1]
loss_clsf = criterion2(logits_clsf, is_next)
>>> tensor(0.7261)
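
To make the comparison concrete, here is a small check (not part of my training code) that uses a separate criterion with reduction='none' to look at the per-pair losses both variants start from; it assumes the original, unfiltered logits_clsf, is_next and is_next_weight from above:

criterion_check = nn.CrossEntropyLoss(reduction='none') # check only, not the criterion used above
per_pair_loss = criterion_check(logits_clsf, is_next)   # batch_size, one loss per sentence pair
>>> tensor([0.8207, 0.6315, 0.5801])

# mean over the pairs that actually have an NSP label (sum of weights = 2 here)
(per_pair_loss * is_next_weight.float()).sum() / is_next_weight.float().sum()
>>> tensor(0.7261)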

I have the same concern at the MLM level. Since the number of masked tokens varies from sentence to sentence, I pad masked_ids, which contains the ids of the masked tokens, with torch.nn.utils.rnn.pad_sequence so that all the entries of a batch have the same length (a PyTorch constraint).
Let's take an example with batch_size=3 and a vocabulary of size n_vocab=5 (just to keep things simple: 0 = pad_index, and I ignore the ids of special tokens like [UNK], [CLS] and [SEP]): 1 token with id [2] is masked in the first sentence of the batch, 2 tokens with ids [1, 2] in the second, and 3 tokens with ids [1, 3, 4] in the last.
Then max_len = 3 (the number of masked tokens in the last sentence) and:

masked_ids = tensor([[2, 0, 0], [1, 2, 0], [1, 3, 4]]) # batch_size x max_len
masked_weights = tensor([[1, 0, 0], [1, 1, 0], [1, 1, 1]]) # batch_size x max_len

masked_weights is 1 for real masked tokens and 0 for padding tokens.
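
For reference, here is a small sketch of how such padded tensors can be built with pad_sequence; padding a tensor of ones to obtain masked_weights is just one possible way of doing it:

from torch.nn.utils.rnn import pad_sequence

# variable-length lists of masked token ids, as in the example above
ids_per_sentence = [tensor([2]), tensor([1, 2]), tensor([1, 3, 4])]

masked_ids = pad_sequence(ids_per_sentence, batch_first=True, padding_value=0) # batch_size x max_len
masked_weights = pad_sequence([torch.ones_like(t) for t in ids_per_sentence],
                              batch_first=True, padding_value=0)               # 1 = real masked token, 0 = padding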

Let’s suppose that our model returns the following logits:

batch_size, max_len, n_vocab = 3, 3, 5
torch.manual_seed(0)
logits_lm = torch.rand(batch_size, max_len, n_vocab) # batch_size x max_len x n_vocab
>>> tensor([[[0.4963, 0.7682, 0.0885, 0.1320, 0.3074],
             [0.6341, 0.4901, 0.8964, 0.4556, 0.6323],
             [0.3489, 0.4017, 0.0223, 0.1689, 0.2939]],

            [[0.5185, 0.6977, 0.8000, 0.1610, 0.2823],
             [0.6816, 0.9152, 0.3971, 0.8742, 0.4194],
             [0.5529, 0.9527, 0.0362, 0.1852, 0.3734]],

            [[0.3051, 0.9320, 0.1759, 0.2698, 0.1507],
             [0.0317, 0.2081, 0.9298, 0.7231, 0.7423],
             [0.5263, 0.2437, 0.5846, 0.0332, 0.1387]]])

I also have a loss function:

criterion1 = nn.CrossEntropyLoss(reduction='none')

As with NSP, I have two choices:

  • calculate the loss before reducing it:
loss_lm = criterion1(logits_lm.transpose(1, 2), masked_ids) # batch_size x max_len, one loss per position, padding included
loss_lm = (loss_lm*masked_weights.float()).mean()
>>> tensor(1.0664)
  • or select the entries before computing the loss:
masked_ids = masked_ids[masked_weights == 1]
logits_lm = logits_lm[masked_weights == 1]

loss_lm = criterion1(logits_lm, masked_ids) # one loss per selected (real) masked token
loss_lm = loss_lm.mean()
>>> tensor(1.5996)
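
As a purely numerical check of this particular example (an observation about these tensors, not a general statement): both variants share the same numerator, the sum of the losses over the real masked tokens, and only the denominator of the mean changes:

numerator = 1.5996 * 6               # approximate sum of the per-token losses over the real masked tokens
n_positions = batch_size * max_len   # 9 positions in total, padding included
n_masked = int(masked_weights.sum()) # 6 real masked tokens
# first variant:  numerator / n_positions ~ 1.0664
# second variant: numerator / n_masked    = 1.5996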

I would like to know which choice is better, and why, since they return different values.