I have a rather complicated use case that touches several frameworks and models, but I am asking here because this seemed like the most appropriate place; I hope that is okay.
The gist of the problem:
I am trying to create a new architecture for a multimodal language model based on Llama 7b. This requires me to send the output of a second model through a projection layer and use the projector's output to edit Llama's input embeddings; the resulting tensor then goes into the Llama model as usual. The problem is that doing this causes a CUDA out-of-memory error a long way down the line inside Llama (see the stack trace below). The attached minimal example reproduces the bug in a significantly simpler fashion; what I am doing in my own model is rather complicated, but conceptually similar.
What I have found out so far:
I suspect the issue has to do with the gradients created when the tensor passes through the projection layer. If I .detach() the tensor after the projector but before editing the input embeddings, and only use embeddings computed afterwards, there is no problem. (But obviously I still need those gradients to train the projection layer.)
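To make the detach observation concrete, here is a CPU-only sketch of the two variants (the names and sizes mirror the minimal example below; in my real script this all lives on the GPU):

```python
import torch

# Stand-ins for my projector and the second model's output
# (sizes mirror the minimal example; no GPU needed for this sketch).
fake_projector = torch.nn.Linear(512, 512)
some_model_output = torch.ones(512, 512)

projected = fake_projector(some_model_output)
print(projected.requires_grad)   # True -> the edited embeddings join the autograd graph

detached = projected.detach()
print(detached.requires_grad)    # False -> no OOM, but the projector gets no gradient
```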
Analyzing memory consumption with nvitop and torch.cuda.memory_summary() suggests that the stack trace is accurate, i.e. there is a massive memory spike right around the point where the error occurs. (Meaning the GPU has ~60 GB of memory free before I send the tensor through the model.) At least to me, memory consumption looks normal around the point where I edit the tensor.
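For reference, this is the kind of probe I sprinkle through the script between the individual steps (report_gpu_memory is just a helper name I made up):

```python
import torch

def report_gpu_memory(tag: str) -> None:
    """Print allocated/reserved CUDA memory in GiB (no-op without a GPU)."""
    if not torch.cuda.is_available():
        print(f"{tag}: no CUDA device")
        return
    allocated = torch.cuda.memory_allocated() / 2**30
    reserved = torch.cuda.memory_reserved() / 2**30
    print(f"{tag}: allocated={allocated:.2f} GiB, reserved={reserved:.2f} GiB")

# e.g. call once after editing the embeddings and once right before model.forward()
report_gpu_memory("after embedding edit")
```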
Specs / memory considerations:
I use 8×A100 GPUs with 80 GB VRAM each for my main script. This example obviously only uses one of those A100s, but I believe that should still be enough memory for one forward pass.
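For what it's worth, a back-of-the-envelope estimate (assuming fp32 and Llama-2-7b's 32 attention heads, with my prompt padded to 4096 tokens) shows that a single layer's attention-weight tensor alone is already sizeable, and matches the failed allocation in the stack trace below:

```python
# Size of one layer's attention-weight tensor:
# batch=1, 32 heads, 4096x4096 scores, fp32 = 4 bytes per element.
batch, heads, seq_len, bytes_per_el = 1, 32, 4096, 4
gib = batch * heads * seq_len * seq_len * bytes_per_el / 2**30
print(f"{gib:.2f} GiB per layer")  # 2.00 GiB per layer
```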
Example to reproduce:
```python
from transformers import AutoModel, AutoTokenizer
import torch

# Normal model stuff #
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
tokenizer.pad_token = tokenizer.eos_token
tokenized_prompt = tokenizer(
    "This is some silly prompt.",
    padding="max_length",
    truncation=True,
    max_length=4096,
    return_tensors="pt"
).to("cuda")
model = AutoModel.from_pretrained("meta-llama/Llama-2-7b-chat-hf").to("cuda")

generative_embedding_layer = model.get_input_embeddings()
input_embeddings = generative_embedding_layer(tokenized_prompt["input_ids"])

# Editing the embedding tensor with some projected tensor #
fake_projector = torch.nn.Linear(512, 512).to("cuda")
some_model_output = torch.ones(512, 512).to("cuda")
projected_model_output = fake_projector(some_model_output)
input_embeddings[:, -512:, -512:] = projected_model_output
attention_mask = tokenized_prompt["attention_mask"]
attention_mask[:, -512:] = 1

# Send through model #
outputs = model.forward(
    inputs_embeds=input_embeddings,
    attention_mask=attention_mask
)
```
```
Traceback (most recent call last):
  File "/raid/marie/PLACEHOLDER/minimal_example.py", line 36, in <module>
    outputs = model.forward(
  File "/home/mbauer/miniconda3/envs/open_llm/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 925, in forward
    layer_outputs = decoder_layer(
  File "/home/mbauer/miniconda3/envs/open_llm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mbauer/miniconda3/envs/open_llm/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 635, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/mbauer/miniconda3/envs/open_llm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mbauer/miniconda3/envs/open_llm/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 373, in forward
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 GiB (GPU 0; 79.15 GiB total capacity; 75.22 GiB already allocated; 1.13 GiB free; 77.52 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
```