Gradient accumulation loss computation

Hello, everyone!
Suppose we have data of shape [b, s, dim]. I recently noticed that CrossEntropyLoss (1) computes the average over all tokens (b * s) in a batch, instead of (2) computing the loss for each sentence and then averaging over sentences.
Here is the code that computes the loss in the Hugging Face Transformers BertLMHeadModel:

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        lm_loss = None
        if labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one
            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss()  # default reduction="mean": averages over all (non-ignored) tokens
            lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

I know that (1) and (2) make no difference in this situation (as long as every sentence in the batch contributes the same number of tokens, the two averages coincide). But when we apply gradient accumulation, I think the situation is different.

Suppose I have a batch size of 4, and the lengths of the 4 sentences are 100, 200, 300, and 400 tokens.
With batch_size = 4, the loss is computed as the average over all 1000 tokens.

But with batch_size = 1 and gradient_accumulation_steps = 4, I think the loss is different. We first compute the loss on each sentence separately and then average: for the sentence of 100 tokens, we take the mean loss over its 100 tokens, divide it by 4, and add it to the total loss, and likewise for the other 3 sentences. Each sentence then carries a weight of 1/4 regardless of its length (so a token of the 100-token sentence is weighted 1/400 instead of 1/1000), and I think the loss computed this way is different from the one computed with batch_size = 4.
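
For concreteness, here is a minimal sketch (with made-up logits and labels, only the sentence lengths match the example above) comparing the two reductions:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    vocab_size = 32
    lengths = [100, 200, 300, 400]

    # Fake per-sentence logits and labels with the lengths from the example.
    logits = [torch.randn(n, vocab_size) for n in lengths]
    labels = [torch.randint(vocab_size, (n,)) for n in lengths]

    # (1) One batch of 4 sentences: average over all 1000 tokens at once.
    loss_all_tokens = F.cross_entropy(torch.cat(logits), torch.cat(labels))

    # (2) Per-sentence mean, then average over the 4 sentences
    #     (what batch_size = 1 + gradient accumulation effectively does).
    per_sentence = [F.cross_entropy(lg, lb) for lg, lb in zip(logits, labels)]
    loss_per_sentence = torch.stack(per_sentence).mean()

    print(loss_all_tokens.item(), loss_per_sentence.item())  # generally not equal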

Did I misunderstand something?

If you want to reduce the loss in a custom way, use reduction="none" while initializing the criterion, and apply your custom reduction afterwards.
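
For example, a minimal sketch (not the Transformers implementation; `micro_batch_loss` and `total_tokens` are hypothetical names, and `total_tokens` would have to be the number of non-ignored label tokens in the whole effective batch, which you track yourself across accumulation steps):

    import torch
    from torch.nn import CrossEntropyLoss

    loss_fct = CrossEntropyLoss(reduction="none", ignore_index=-100)

    def micro_batch_loss(logits, labels, total_tokens):
        # logits: [b, s, vocab_size], labels: [b, s]; shift for next-token prediction.
        shifted_logits = logits[:, :-1, :].contiguous()
        shifted_labels = labels[:, 1:].contiguous()
        per_token = loss_fct(
            shifted_logits.view(-1, shifted_logits.size(-1)),
            shifted_labels.view(-1),
        )  # with reduction="none", positions labeled ignore_index contribute 0
        # Sum instead of mean, and divide by the token count of the whole
        # effective batch, so that accumulating the backward passes of all
        # micro batches reproduces the plain average over all tokens.
        return per_token.sum() / total_tokens

Calling `.backward()` on this loss for each micro batch (with no extra division by gradient_accumulation_steps) would then give the same gradients as one big batch.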

I noticed that most gradient accumulation implementations do not take this into consideration;

for example, the Hugging Face Transformers Trainer:

    if self.args.gradient_accumulation_steps > 1:
        loss = loss / self.args.gradient_accumulation_steps

where loss is the usual mean-reduced CE loss of one micro batch.
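
To make the effect concrete, here is a minimal runnable sketch of one accumulation cycle with this scaling (a tiny stand-in linear model, random data, and made-up hyperparameters, not the Trainer itself):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    vocab_size = 32
    accumulation_steps = 4
    lengths = [100, 200, 300, 400]

    # Tiny stand-in for a language model head so the loop runs end to end.
    model = torch.nn.Linear(16, vocab_size)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    optimizer.zero_grad()
    for n in lengths:  # one micro batch (= one sentence) per step
        features = torch.randn(n, 16)
        labels = torch.randint(vocab_size, (n,))
        loss = F.cross_entropy(model(features), labels)  # per-token mean over this sentence
        # Trainer-style scaling: each micro batch contributes its own mean / accumulation_steps,
        # so every sentence gets the same weight regardless of how many tokens it has.
        (loss / accumulation_steps).backward()
    optimizer.step()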

Does this mean that this difference in how the loss is computed does not influence performance?