Running out of GPU memory when generating sequences autoregressively with a multi-head attention model

I am working on time series forecasting with a GPT-like model. The idea is similar to training a language model that, given the beginning of a sentence, generates the rest of that sentence.

“I’m hungry, I want a hamburger and a cup of cola.”

Input: I’m hungry

Predict: I want a hamburger and a cup of cola.

An autoregressive language model will generate words step by step.

I’m hungry, I
I’m hungry, I want
I’m hungry, I want a
I’m hungry, I want a hamburger and a cup of cola.

That is, each newly generated word is appended to the end of the previous input sequence to form the next input sequence. During training, I calculate the loss on the generated content “I want a hamburger and a cup of cola” and use back-propagation to update the model parameters. The generation process is implemented with a for-loop around a “decoder-only” module.
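The question does not show its code, so here is a minimal sketch of such a training step, assuming a univariate series and PyTorch. `TinyDecoder` and `generate` are hypothetical names, and a stack of `nn.TransformerEncoderLayer` blocks with a causal mask stands in for the "decoder-only" module (self-attention only, no cross-attention).

```python
import torch
import torch.nn as nn


class TinyDecoder(nn.Module):
    """Hypothetical GPT-style model for a univariate series: self-attention
    blocks with a causal mask and no cross-attention."""

    def __init__(self, d_model=32, nhead=4, num_layers=2, max_len=256):
        super().__init__()
        self.inp = nn.Linear(1, d_model)           # embed each scalar value
        self.pos = nn.Embedding(max_len, d_model)  # learned positional embedding
        layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=64, batch_first=True)
        self.blocks = nn.TransformerEncoder(layer, num_layers)
        self.out = nn.Linear(d_model, 1)           # one predicted value per position

    def forward(self, x):                          # x: (batch, seq_len, 1)
        L = x.size(1)
        h = self.inp(x) + self.pos(torch.arange(L, device=x.device))
        # True above the diagonal: a position may not attend to later positions
        causal = torch.triu(torch.ones(L, L, dtype=torch.bool, device=x.device), diagonal=1)
        return self.out(self.blocks(h, mask=causal))  # (batch, seq_len, 1)


def generate(model, prompt, steps):
    """Append `steps` predicted values to `prompt`, feeding each one back in."""
    seq, preds = prompt, []
    for _ in range(steps):
        next_val = model(seq)[:, -1:, :]           # prediction for the next position
        preds.append(next_val)
        seq = torch.cat([seq, next_val], dim=1)    # the new, longer input sequence
    return torch.cat(preds, dim=1)                 # (batch, steps, 1)


# Usage on dummy data: 4 series of 30 steps, 10-step prompt, 20 predicted steps.
torch.manual_seed(0)
model = TinyDecoder()
series = torch.randn(4, 30, 1)
prompt, target = series[:, :10], series[:, 10:]
pred = generate(model, prompt, steps=target.size(1))
loss = nn.functional.mse_loss(pred, target)
loss.backward()  # the graphs of all 20 iterations are held in memory until here
```

Because the loss depends on every generated value, autograd has to keep the activations of every iteration, each computed over a longer input than the one before, until `backward()` runs.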

However, GPU memory usage always climbs sharply in this for-loop and eventually causes an out-of-GPU-memory error. If I add the “@torch.no_grad()” decorator, the problem disappears. So I guess the problem is caused by the intermediate data stored for back-propagation.
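One way to check this guess is to count what autograd saves for the backward pass. The sketch below, assuming PyTorch 1.10 or later, registers hooks with `torch.autograd.graph.saved_tensors_hooks` and runs the same forward pass with and without `torch.no_grad()`; the small `nn.Sequential` is an arbitrary stand-in for the real model.

```python
import torch
import torch.nn as nn

saved_numel = []  # element count of every tensor autograd saves for backward


def pack(t):
    saved_numel.append(t.numel())
    return t


def unpack(t):
    return t


model = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 16))  # stand-in model
x = torch.randn(4, 16)

with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = model(x)                      # grad mode: activations are saved for backward
    saved_with_grad = sum(saved_numel)
    with torch.no_grad():
        model(x)                      # no_grad: no graph is recorded, nothing is saved
    saved_after_no_grad = sum(saved_numel)

print(saved_with_grad > 0, saved_after_no_grad == saved_with_grad)  # True True
```

In grad mode every forward pass adds saved tensors, so a loop of forward passes keeps adding to them until `backward()` frees the graph; under `no_grad()` nothing is added.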

Do you think my implementation is the right way to generate word sequences? Do you have any suggestions for optimizing my implementation?

My time series forecasting sequence contains around 100 elements, so the for-loop runs about 100 times.

The increase in memory is expected, as you are storing all computation graphs, including all intermediate tensors needed for the gradient computation.
I don’t know if your auto-regressive approach is suitable for your problem, but you won’t be able to reduce the memory usage as the memory is indeed needed. Reducing the number of iterations, the model size, etc. might help but also changes your use case.
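A rough count shows how 100 iterations add up. Assuming, as the question describes, that each step feeds the whole growing sequence back into the model, and picking a hypothetical prompt length of 20, the positions whose activations must be kept grow with the sum of the sequence lengths, and the attention score entries (per layer and head) with the sum of their squares:

```python
def positions_processed(prompt_len, steps):
    # Step t feeds prompt_len + t positions through the model; with autograd
    # enabled, their activations stay alive until backward().
    return sum(prompt_len + t for t in range(steps))


def attention_scores(prompt_len, steps):
    # At step t each attention layer/head builds an L x L score matrix,
    # with L = prompt_len + t.
    return sum((prompt_len + t) ** 2 for t in range(steps))


P = 20  # hypothetical prompt length
print(positions_processed(P, 100))  # 6950
print(positions_processed(P, 50))   # 2225
print(attention_scores(P, 100))     # 566350
```

Halving the number of generated steps cuts the processed positions from 6950 to 2225, roughly a third, which is why reducing the number of iterations helps more than proportionally.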


There is a PyTorch tutorial, “Language Translation with nn.Transformer and torchtext”. Since its decoder does essentially the same thing as yours, i.e., generating text, its implementation might be worth a look. The greedy_decode() method contains the for-loop that generates the next word. I ran it locally and it works just fine. To be fair, the translations, and with them the number of iterations in greedy_decode(), are much shorter than 100 steps.
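A greedy loop in the spirit of greedy_decode() can be sketched as follows. It is not the tutorial's code: `greedy_generate`, `EOS_ID`, and the placeholder model are hypothetical, and the sketch assumes a decoder-only model that returns next-token logits rather than the tutorial's encoder-decoder. It is meant for inference, so it runs under `torch.no_grad()` and keeps no autograd graph between steps.

```python
import torch
import torch.nn as nn

EOS_ID = 0  # hypothetical end-of-sequence token id


@torch.no_grad()  # inference only: no autograd graph is kept between steps
def greedy_generate(model, ids, max_new_tokens):
    """Greedily append the most likely next token (batch size 1 assumed)."""
    for _ in range(max_new_tokens):
        logits = model(ids)[:, -1, :]                  # scores for the next token
        next_id = logits.argmax(dim=-1, keepdim=True)  # shape (1, 1)
        ids = torch.cat([ids, next_id], dim=1)         # append and feed back in
        if next_id.item() == EOS_ID:                   # stop early at end of sequence
            break
    return ids


# Hypothetical stand-in mapping token ids (1, L) to logits (1, L, vocab); not a real LM.
torch.manual_seed(0)
vocab = 50
model = nn.Sequential(nn.Embedding(vocab, 16), nn.Linear(16, vocab)).eval()
prompt = torch.tensor([[5, 7, 9]])  # arbitrary token ids
print(greedy_generate(model, prompt, max_new_tokens=10))
```

Any module with the same input and output shapes, such as a causal Transformer language model, can replace the placeholder without changing the loop.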