OOM due to computation graph when stacking embeddings

Hi,

I am trying out the following code, which stacks encoder embeddings for later manipulation.

from transformers import BartTokenizer, BartForConditionalGeneration
import torch


def encode():

    print("Loading model")
    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
    model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
    model.to("cuda")

    encoder = model.get_encoder()  # returns the BartEncoder module

    sentences_to_encode = [str(i)+" This is my paragraph."*50  for i in range(50)]
    tensorStack = []
    for idx, para in enumerate(sentences_to_encode):
        inputs = tokenizer.encode(para, return_tensors='pt', truncation=True, padding=True)
        tensorStack.append(encoder(inputs.to("cuda")).last_hidden_state)
        print("Encoded para " + str(idx), "Mem:", round(torch.cuda.memory_allocated() / 1024**3, 2), "GB")


if __name__ == '__main__':
    encode()

When stacking the encoder outputs like this, GPU memory grows after every iteration. I believe this is because each stored tensor keeps its computation graph alive.
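
As a sanity check (not a fix, since it also discards the gradients I need later), I tried appending a detached copy instead, and memory stays roughly flat, which seems to confirm that the graph is what accumulates. A minimal sketch of what I tried:

# Sanity check only: .detach() drops the graph, so memory stays roughly flat,
# but it also breaks the backpropagation I need later.
detachedStack = []
for idx, para in enumerate(sentences_to_encode):
    inputs = tokenizer.encode(para, return_tensors='pt', truncation=True).to("cuda")
    out = encoder(inputs).last_hidden_state
    detachedStack.append(out.detach())  # graph-free copy, so memory no longer grows per iteration
    print("Encoded para " + str(idx), "Mem:", round(torch.cuda.memory_allocated() / 1024**3, 2), "GB")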

Is there a way to avoid keeping a separate set of intermediate tensors on the GPU for every iteration, for example by sharing/tying them or by checkpointing? I intend to backpropagate through these embeddings, so torch.no_grad() is not a suitable solution for me.
Thanks

You might be looking for torch.utils.checkpoint — PyTorch 2.2 documentation
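
Roughly something like this (untested sketch; gradient_checkpointing_enable() is the transformers convenience wrapper, and use_reentrant=False matters if you call checkpoint yourself, because the integer input_ids never require grad):

from torch.utils.checkpoint import checkpoint

# Option 1: let transformers insert checkpointing into the encoder/decoder layers.
model.gradient_checkpointing_enable()

# Option 2: checkpoint the whole encoder call yourself.
# With use_reentrant=False, gradients still reach the model parameters even
# though the integer input_ids do not require grad.
hidden = checkpoint(
    lambda ids: encoder(ids).last_hidden_state,
    inputs.to("cuda"),
    use_reentrant=False,
)

Either way it is a compute-for-memory trade: activations are recomputed during the backward pass, so each stacked output keeps only the checkpoint boundaries alive rather than the full graph.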


Thanks! Let me check if I can work this out.
I was also thinking of distributing the encoder, but since the same encoder object is reused for every paragraph, I am not entirely sure how to distribute the computation across multiple GPUs.
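
In case it helps anyone, this is the rough, untested idea I had in mind, assuming the paragraphs can simply be batched: wrap the encoder so it returns a plain tensor and let torch.nn.DataParallel split the batch across the available GPUs (the wrapper class and variable names here are just illustrative).

import torch

class EncoderWrapper(torch.nn.Module):
    # Thin wrapper so DataParallel only has to gather a plain tensor
    # instead of the ModelOutput object the encoder normally returns.
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def forward(self, input_ids, attention_mask):
        return self.encoder(input_ids=input_ids,
                            attention_mask=attention_mask).last_hidden_state

# Tokenize all paragraphs as one padded batch (or smaller chunks if that is too large).
batch = tokenizer(sentences_to_encode, return_tensors='pt',
                  truncation=True, padding=True).to("cuda")
parallel_encoder = torch.nn.DataParallel(EncoderWrapper(encoder))
hidden = parallel_encoder(batch.input_ids, batch.attention_mask)  # batch split across GPUs

This only spreads the forward pass, though; the gathered outputs and their graphs still end up on the first GPU, so it would probably need to be combined with checkpointing.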