Why does CrossEntropyLoss consume so much memory in a Transformer?

I’m trying to train a Transformer with limited resources.

In the usual setup, the loss is computed as:

lang_prediction_scores = self.cls(feats)

masked_lm_loss = CrossEntropyLoss(ignore_index=-1)(
    lang_prediction_scores.view(-1, self.config.vocab_size),
    masked_lm_labels.view(-1)
)

Here ignore_index makes the loss ignore every position whose label is -1.
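
For illustration, here is a minimal toy sketch (not my actual model code, and the sizes are made up) showing that positions labeled -1 contribute neither to the loss value nor to the gradient:

import torch
from torch.nn import CrossEntropyLoss

vocab_size = 10                                   # toy vocabulary, made up
logits = torch.randn(6, vocab_size, requires_grad=True)
labels = torch.tensor([3, -1, 7, -1, -1, 2])      # -1 = position to ignore

loss = CrossEntropyLoss(ignore_index=-1)(logits, labels)
loss.backward()
print(logits.grad[1].abs().sum())                 # ~0: ignored rows get no gradient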

But I find this costs a lot of GPU memory, which is unacceptable given my limited GPU memory. For example, when the number of tokens fed into the loss function is 36, training is fine. But at 72, it consumes roughly 6 GB more GPU memory. The number of input tokens is always 72; it’s only the number of tokens going into the loss function that changes here.
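
To make this concrete, here is roughly how the growth can be measured in isolation on a GPU (the batch size and the BERT-style vocabulary size below are just placeholders, not my actual numbers):

import torch
from torch.nn import CrossEntropyLoss

vocab_size = 30522                                # placeholder, BERT-sized vocabulary
batch_size = 32                                   # placeholder
for n_tokens in (36, 72):
    torch.cuda.reset_peak_memory_stats()
    logits = torch.randn(batch_size, n_tokens, vocab_size,
                         device="cuda", requires_grad=True)
    labels = torch.randint(0, vocab_size, (batch_size, n_tokens), device="cuda")
    loss = CrossEntropyLoss(ignore_index=-1)(
        logits.view(-1, vocab_size), labels.view(-1))
    loss.backward()
    print(n_tokens, torch.cuda.max_memory_allocated() / 2**20, "MiB")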

I found a workaround. Since the Transformer doesn’t really care about token order, I move all the tokens to be predicted together at the end of the sequence. Then I only need to run the prediction head on far fewer tokens, like this:

lang_prediction_scores = self.cls(feats[:, TOKENS_TO_BE_PREDICTED:])

masked_lm_loss = CrossEntropyLoss(ignore_index=-1)(
    lang_prediction_scores.view(-1, self.config.vocab_size),
    masked_lm_labels.view(-1)
)

This reduces the memory cost considerably and lets me continue training the model.
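
For completeness, a minimal sketch of how I think this has to fit together; num_predicted and the label slicing are just illustrative names, the point being that the labels must cover exactly the same positions as the sliced features:

num_predicted = 12                                # illustrative only
feats_to_predict = feats[:, -num_predicted:]      # predicted tokens grouped at the end
scores = self.cls(feats_to_predict)               # (batch, num_predicted, vocab_size)
labels_to_predict = masked_lm_labels[:, -num_predicted:]
masked_lm_loss = CrossEntropyLoss(ignore_index=-1)(
    scores.view(-1, self.config.vocab_size),
    labels_to_predict.reshape(-1)
)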

Now my questions are:

  1. Why does CrossEntropyLoss cost so much memory? It’s just a loss function, and it ignores every position labeled -1 anyway. As I understand it, the loss is just summed and then averaged. Should it really consume this much more memory, even during the backward pass?
  2. Am I doing this correctly? Does it compute gradients any differently than just using ignore_index? (See the sketch right after this list.)
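
As a toy sanity check (all sizes below are made up), one can verify that when every position outside the predicted block carries label -1, the sliced loss matches the full loss computed with ignore_index:

import torch
from torch.nn import CrossEntropyLoss

B, L, V, start = 2, 8, 50, 5                      # made-up sizes; predicted tokens at start..L-1
logits = torch.randn(B, L, V)
labels = torch.full((B, L), -1, dtype=torch.long)          # -1 everywhere...
labels[:, start:] = torch.randint(0, V, (B, L - start))    # ...except the predicted block

loss_fn = CrossEntropyLoss(ignore_index=-1)
full_loss = loss_fn(logits.view(-1, V), labels.view(-1))
sliced_loss = loss_fn(logits[:, start:].reshape(-1, V), labels[:, start:].reshape(-1))
print(torch.allclose(full_loss, sliced_loss))     # True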

Really need help here. :neutral_face:

To clear something up: in the example above, the input length is always 72 tokens. The difference is only that I now compute CrossEntropyLoss over all of those tokens. So this memory consumption is not caused by the maximum sequence length; I’m confident the number of tokens going into the loss function is the primary cause here.

One reason I can think of is that the attention weights and other intermediate activations inside the Transformer double when the input length doubles. Since the attention matrices are quadratic in the sequence length, some of them can even grow by a factor of four.
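
As a rough back-of-the-envelope illustration (batch size, head count, and layer count below are made up):

batch, heads, layers = 32, 12, 12                 # made-up numbers for illustration
for L in (36, 72):
    floats = batch * heads * L * L * layers       # attention weight matrices alone
    print(L, floats * 4 / 2**20, "MiB")           # float32 = 4 bytes each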

As I mentioned above, the input length doesn’t change. The only thing that changes is the number of tokens fed into CrossEntropyLoss.
That’s weird.