CUDA OOM because of tensor gradients?

I have a rather complicated use case that touches several frameworks and models, but I am asking here because this seemed like the most appropriate place. I hope that is okay.

The gist of the problem:
I am trying to build a new architecture for a multimodal language model based on Llama 7B. This requires me to send the output of a second model through a projection layer and use the projector's output to edit Llama's input embeddings; the resulting tensor is then fed into the Llama model as usual. The problem is that doing this causes a CUDA out-of-memory error a long way further down inside Llama (see the stack trace below). The minimal example below reproduces the bug in a significantly simpler form; what I am doing in my own model is rather complicated, but conceptually similar.

What I have found out so far:
I suspect the issue has to do with the gradients that are created when the tensor passes through the projection layer. If I .detach() the tensor after the projector but before editing the input embeddings, so that only the gradients computed later inside Llama remain, there is no problem (the one-line variant is shown right after the example below). But obviously I still need those gradients to train the projection layer.
Analyzing memory consumption with nvitop and torch.cuda.memory_summary() suggests that the stack trace is accurate, i.e. there is a massive memory spike at roughly the point where the error occurs (the GPU has ~60 GB of memory free before I send the tensor through the model). At least to me, memory consumption looks normal around the point where I edit the embedding tensor.
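
Roughly what I mean by checking memory (a simplified version of what I log around the editing step; the names match the example below):

def report_memory(tag):
    # memory tracked by the PyTorch allocator and memory actually free on the device
    free, total = torch.cuda.mem_get_info()
    allocated = torch.cuda.memory_allocated()
    print(f"{tag}: {allocated / 1024**3:.2f} GiB allocated, "
          f"{free / 1024**3:.2f} GiB free of {total / 1024**3:.2f} GiB")

report_memory("before editing the embeddings")
# ... edit input_embeddings here ...
report_memory("after editing the embeddings")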

Specs / memory considerations:
I use 8×A100 GPUs with 80 GB VRAM each for my main script. This example obviously only uses one of those A100s, but I think that should still be enough memory for a single forward pass.

Example to reproduce:

from transformers import AutoModel, AutoTokenizer

import torch


#    Normal model stuff     #

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
tokenizer.pad_token = tokenizer.eos_token
tokenized_prompt = tokenizer(
    "This is some silly prompt.",
    padding="max_length",
    truncation=True,
    max_length=4096,
    return_tensors="pt"
).to("cuda")

model = AutoModel.from_pretrained("meta-llama/Llama-2-7b-chat-hf").to("cuda")

generative_embedding_layer = model.get_input_embeddings()
input_embeddings = generative_embedding_layer(tokenized_prompt["input_ids"])


#   Editing the embedding tensor with some projected tensor      #

# Stand-ins for the second model's output and the real projection layer
fake_projector = torch.nn.Linear(512, 512).to("cuda")
some_model_output = torch.ones(512, 512).to("cuda")
projected_model_output = fake_projector(some_model_output)

# Write the projected output into the last 512 positions / last 512 embedding dims
input_embeddings[:, -512:, -512:] = projected_model_output
attention_mask = tokenized_prompt["attention_mask"]
attention_mask[:, -512:] = 1

#   Send through model  #
outputs = model.forward(
    inputs_embeds=input_embeddings,
    attention_mask=attention_mask
)
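
For completeness, the one-line change that makes this example run without an OOM is detaching the projector output before writing it into the embeddings (which of course also removes exactly the gradients I need to train the projector):

# variant that does NOT run out of memory:
input_embeddings[:, -512:, -512:] = projected_model_output.detach()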

Error Trace:


Traceback (most recent call last):
  File "/raid/marie/PLACEHOLDER/minimal_example.py", line 36, in <module>
    outputs = model.forward(
  File "/home/mbauer/miniconda3/envs/open_llm/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 925, in forward
    layer_outputs = decoder_layer(
  File "/home/mbauer/miniconda3/envs/open_llm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mbauer/miniconda3/envs/open_llm/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 635, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/mbauer/miniconda3/envs/open_llm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mbauer/miniconda3/envs/open_llm/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 373, in forward
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 GiB (GPU 0; 79.15 GiB total capacity; 75.22 GiB already allocated; 1.13 GiB free; 77.52 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

Based on the stack trace you are running out of memory in the forward pass, so I don't think gradients are directly related to it. You could add debug print statements into the forward pass of your model to check which layer significantly increases the memory usage.
E.g. a forward activation that is kept around for the gradient computation could be huge and use a lot of memory. Once you have an idea which layers (or rather which activations) increase the memory usage significantly, you could consider offloading them to the CPU as described here.
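
A rough sketch of how that could look with forward hooks, so you don't have to edit modeling_llama.py directly (adapt the names to your setup):

def log_memory(name):
    def hook(module, inputs, output):
        # memory tracked by the PyTorch allocator right after this module's forward
        print(f"{name}: {torch.cuda.memory_allocated() / 1024**3:.2f} GiB allocated")
    return hook

# register a hook on every submodule of the Llama model
for name, module in model.named_modules():
    module.register_forward_hook(log_memory(name))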

Thanks a lot for your answer to my post!

I’ve tried analyzing the code with print statements for memory allocation and a debugger before making this post, and it did not lead to much: I did not manage to detect any significant increase in memory usage before the crash happens. I also do not think that the activations per se are the problem, for the following reasons:

  1. Sending the same tensor through the model works perfectly fine and does not even come close to maxing out memory, as long as I detach it once after the nn.Linear. That is why I thought gradients might have something to do with it, even though that does not make much sense given that it is the forward pass.

  2. I’ve tried simply giving the model more memory, but even with DeepSpeed parameter and gradient offloading (rough config sketch below) it made a whole DGX node run out of memory, which should not happen with a 7B Llama model and a single linear layer.
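
Roughly the kind of ZeRO-3 offloading config I mean (illustrative values, not my exact setup):

# illustrative DeepSpeed ZeRO-3 config with parameter and optimizer/gradient offloading
# (not my exact config, just the general shape of it)
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {
        "stage": 3,
        "offload_param": {"device": "cpu"},
        "offload_optimizer": {"device": "cpu"},
    },
}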

I don’t fully understand why your debugging didn’t yield much, since the stack trace shows that this operation is causing the OOM:

File "/home/mbauer/miniconda3/envs/open_llm/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 373, in forward
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

So at least a memory check right before that op should already show almost fully occupied GPU memory. From there you could go backwards to check which op increased it the most.
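
For what it’s worth, the failed 2 GiB allocation lines up with the size of a single layer’s full attention-weight tensor for your setup (assuming fp32, batch size 1, 32 heads, and the 4096-token padded prompt):

# rough size of one layer's attention weights: batch x heads x seq_len x seq_len
batch, heads, seq_len, bytes_per_element = 1, 32, 4096, 4  # fp32
print(batch * heads * seq_len * seq_len * bytes_per_element / 1024**3)
# -> 2.0 GiB, matching "Tried to allocate 2.00 GiB" in the trace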