I have a rather complicated use case that touches several frameworks and models, but I am asking here because this seemed like the most appropriate place; I hope that is okay.
The gist of the problem:
I am trying to create a new architecture for a multimodal language model based on Llama 7b. This requires me to send the output of a second model through a projection layer and use the projector's output to edit Llama's input embeddings; the resulting tensor then goes into the Llama model as usual. The problem is that doing this causes a CUDA out-of-memory error a long way down the line inside Llama (see the stack trace below). The attached minimal example reproduces the bug in a significantly simpler fashion; what I am doing in my own model is rather complicated, but conceptually similar.
What I have found out so far:
I suspect the issue has to do with the gradients created when the tensor passes through the projection layer. If I .detach() the tensor after the projector but before editing the input embeddings, and only use embeddings computed afterwards, there is no problem. (But obviously I still need those gradients to train the projection layer.)
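To make the detach observation concrete, here is a CPU-only sketch of the two variants (the names and sizes mirror the minimal example below; in my real script this all lives on the GPU):

```python
import torch

# Stand-ins for my projector and the second model's output
# (sizes mirror the minimal example; no GPU needed for this sketch).
fake_projector = torch.nn.Linear(512, 512)
some_model_output = torch.ones(512, 512)

projected = fake_projector(some_model_output)
print(projected.requires_grad)   # True -> the edited embeddings join the autograd graph

detached = projected.detach()
print(detached.requires_grad)    # False -> no OOM, but the projector gets no gradient
```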
Analyzing memory consumption with nvitop and torch.cuda.memory_summary() suggests that the stack trace is accurate, i.e. there is a massive memory spike right around the point where the error occurs. (Meaning the GPU has ~60 GB of memory free before I send the tensor through the model.) At least to me, memory consumption looks normal around the point where I edit the tensor.
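For reference, this is the kind of probe I sprinkle through the script between the individual steps (report_gpu_memory is just a helper name I made up):

```python
import torch

def report_gpu_memory(tag: str) -> None:
    """Print allocated/reserved CUDA memory in GiB (no-op without a GPU)."""
    if not torch.cuda.is_available():
        print(f"{tag}: no CUDA device")
        return
    allocated = torch.cuda.memory_allocated() / 2**30
    reserved = torch.cuda.memory_reserved() / 2**30
    print(f"{tag}: allocated={allocated:.2f} GiB, reserved={reserved:.2f} GiB")

# e.g. call once after editing the embeddings and once right before model.forward()
report_gpu_memory("after embedding edit")
```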
Specs / memory considerations:
I use 8×A100 GPUs with 80 GB VRAM each for my main script. This example obviously only uses one of those A100s, but I believe that should still be enough memory for one forward pass.
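For what it's worth, a back-of-the-envelope estimate (assuming fp32 and Llama-2-7b's 32 attention heads, with my prompt padded to 4096 tokens) shows that a single layer's attention-weight tensor alone is already sizeable, and matches the failed allocation in the stack trace below:

```python
# Size of one layer's attention-weight tensor:
# batch=1, 32 heads, 4096x4096 scores, fp32 = 4 bytes per element.
batch, heads, seq_len, bytes_per_el = 1, 32, 4096, 4
gib = batch * heads * seq_len * seq_len * bytes_per_el / 2**30
print(f"{gib:.2f} GiB per layer")  # 2.00 GiB per layer
```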
Example to reproduce:
```python
from transformers import AutoModel, AutoTokenizer
import torch

# Normal model stuff #
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
tokenizer.pad_token = tokenizer.eos_token
tokenized_prompt = tokenizer(
    "This is some silly prompt.",
    padding="max_length",
    truncation=True,
    max_length=4096,
    return_tensors="pt"
).to("cuda")
model = AutoModel.from_pretrained("meta-llama/Llama-2-7b-chat-hf").to("cuda")

generative_embedding_layer = model.get_input_embeddings()
input_embeddings = generative_embedding_layer(tokenized_prompt["input_ids"])

# Editing the embedding tensor with some projected tensor #
fake_projector = torch.nn.Linear(512, 512).to("cuda")
some_model_output = torch.ones(512, 512).to("cuda")
projected_model_output = fake_projector(some_model_output)
input_embeddings[:, -512:, -512:] = projected_model_output
attention_mask = tokenized_prompt["attention_mask"]
attention_mask[:, -512:] = 1

# Send through model #
outputs = model.forward(
    inputs_embeds=input_embeddings,
    attention_mask=attention_mask
)
```
```
Traceback (most recent call last):
  File "/raid/marie/PLACEHOLDER/minimal_example.py", line 36, in <module>
    outputs = model.forward(
  File "/home/mbauer/miniconda3/envs/open_llm/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 925, in forward
    layer_outputs = decoder_layer(
  File "/home/mbauer/miniconda3/envs/open_llm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mbauer/miniconda3/envs/open_llm/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 635, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/mbauer/miniconda3/envs/open_llm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mbauer/miniconda3/envs/open_llm/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 373, in forward
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 GiB (GPU 0; 79.15 GiB total capacity; 75.22 GiB already allocated; 1.13 GiB free; 77.52 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
```