Forward Hook CUDA out of memory despite detach(), gc.collect() and cuda.empty_cache()

Hi, as part of a research project we are exploring the feature space of the GPT-NeoX LLM.
To do so, we extract layer activations using forward hooks.
Unfortunately, we run out of GPU memory after a couple of inferences.
I already searched the forum and found possible solutions, but I am still unable to fix the problem.

When we remove the hooks, the GPU does not run out of memory.
We use detach() and torch.cuda.empty_cache(), to no avail (we also tried gc.collect() after the inference).
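register_forward_hook returns a torch.utils.hooks.RemovableHandle, and calling its remove() method detaches the hook again, which is a convenient way to run the with-hook/without-hook comparison described above. A minimal CPU-only sketch, with a toy nn.Linear standing in for the GPT-NeoX layers (the toy module and the names are illustrative assumptions):

```python
import torch
from torch import nn

captured = {}

def get_activation(name):
    # store a detached version of the module output under `name`
    def hook(module, inputs, output):
        captured[name] = output.detach()
    return hook

layer = nn.Linear(4, 2)  # toy stand-in for a real transformer layer
handle = layer.register_forward_hook(get_activation("toy"))

with torch.inference_mode():
    layer(torch.ones(1, 4))
print("toy" in captured)  # True: the hook fired

handle.remove()  # detach the hook from the module
captured.clear()
with torch.inference_mode():
    layer(torch.ones(1, 4))
print("toy" in captured)  # False: the hook no longer runs
```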

Any insight is highly appreciated!
Thanks in advance!

Here is a snippet:

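import copy
import os

import torch
from accelerate import init_empty_weights, infer_auto_device_map, load_checkpoint_and_dispatch
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

activations = {}
# hook courtesy of ptrblck https://discuss.pytorch.org/t/how-can-l-load-my-best-model-as-a-feature-extractor-evaluator/17254/5
# returns a hook function to pass to register_forward_hook; the hook stores the layer activations under `name`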
def get_activation(name):
    def hook(model, input, output):
        activations[name] = output.detach()[0]
    return hook

def inference(prompt, temp=0.9, max_length=20, do_sample=True):

    # tokenizing the prompt
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    # generating the output
    gen_tokens = model.generate(
        input_ids.to(0),
        do_sample=do_sample,
        temperature=temp,
        max_length=max_length,
    )

    # decoding the generated text
    return tokenizer.batch_decode(gen_tokens)[0]

model_name = "EleutherAI/gpt-neox-20b"
weights_path = "path_to_weights"
if not os.path.exists(weights_path):
    os.makedirs(weights_path)
    save_weights(fp16=False, bf16=True)  # helper defined elsewhere (not shown in the post)

tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)
config.use_cache = False
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

# load in bfloat16 where the GPU supports it, otherwise fall back to float16
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
device_map = infer_auto_device_map(model, no_split_module_classes=["GPTNeoXLayer"], dtype=dtype)
load_checkpoint_and_dispatch(
    model,
    weights_path,
    device_map=device_map,
    offload_folder=None,
    offload_state_dict=False,
    dtype=dtype,
)

# register a hook on the MLP output projection (dense_4h_to_h) of every layer,
# keeping the handles so the hooks can be removed later with handle.remove()
hook_handles = []
for n in range(len(model.gpt_neox.layers)):
    hook_handles.append(
        model.gpt_neox.layers[n].mlp.dense_4h_to_h.register_forward_hook(get_activation(f"layer_{n}_activation"))
    )

model.eval()

# df_train, nlp and the result lists (prompts_list, labels, gen_text_list, ...) are defined elsewhere in the script
for index, row in tqdm(df_train.iterrows(), total=len(df_train)):
    prompt = row.iloc[5]
    prompts_list.append(prompt)
    labels.append(row.iloc[0])

    # generating and decoding the text; the hooks fire on every forward pass inside generate()
    # (max_length counts tokens, so the word count is only an approximation of the prompt length)
    gen_text = inference(prompt, max_length=len(prompt.split())+20)
    gen_text_list.append(gen_text)
    only_gen_text.append(gen_text[len(prompt.split())+1:])
    store_activations.append(copy.deepcopy(activations)) 
    output_sent = nlp(gen_text)
    input_sent = nlp(prompt)
    output_sentiment.append(output_sent.cats)
    input_sentiment.append(input_sent.cats)
    
    torch.cuda.empty_cache()

I'm not familiar with your model, but are you already executing the forward pass in a with torch.no_grad() or with torch.inference_mode() guard?
If so, are you seeing an unexpected increase in memory, or does the memory increase in each iteration correspond to, e.g., the size of:

store_activations.append(copy.deepcopy(activations)) 
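One way to answer this is to estimate the size of each stored snapshot. The hook keeps output[0], a (seq_len, hidden_size) tensor per layer, so one copy of the activations dict holds roughly num_layers × seq_len × hidden_size × bytes_per_element bytes. A rough sketch; the defaults (44 layers, hidden size 6144, 2 bytes for bf16/fp16) are the published GPT-NeoX-20B configuration values as far as we know and should be checked against config.num_hidden_layers and config.hidden_size, and seq_len=40 is a hypothetical prompt-plus-generation length:

```python
def snapshot_bytes(seq_len, num_layers=44, hidden_size=6144, bytes_per_element=2):
    """Approximate size of one copy of the activations dict.

    Assumes one (seq_len, hidden_size) tensor per hooked layer, as produced
    by output[0] in the hook, stored in a 2-byte dtype by default.
    """
    return num_layers * seq_len * hidden_size * bytes_per_element

per_prompt = snapshot_bytes(seq_len=40)
print(per_prompt)  # 21626880
print(f"{per_prompt / 2**20:.1f} MiB added per iteration")
```

If the observed growth per iteration matches this figure, the stored copies rather than a leak in the hooks explain the OOM.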

Yes, with store_activations.append(copy.deepcopy(activations)) we were storing a new copy of the activations on the GPU in every iteration.
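Why detach() and empty_cache() could not help can be checked on the CPU: detach() returns a tensor that shares storage with the original and only drops the autograd link, while copy.deepcopy allocates a new tensor on the same device. On the GPU, every deep copy therefore stays in GPU memory for as long as store_activations references it, and torch.cuda.empty_cache() can only release cached blocks that no tensor is using. A small CPU sketch (the toy tensors are illustrative):

```python
import copy

import torch

x = torch.randn(3, 4, requires_grad=True)
y = x * 2  # part of an autograd graph

d = y.detach()
print(d.data_ptr() == y.data_ptr())  # True: detach() shares the same storage
print(d.requires_grad)               # False: only the graph link is dropped

snapshot = copy.deepcopy({"layer_0": d})
c = snapshot["layer_0"]
print(c.data_ptr() == d.data_ptr())  # False: deepcopy allocated new memory
print(c.device == d.device)          # True: on the same device as the source
```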

Storing the activations in CPU memory instead keeps the GPU from running out of memory:

def hook(model, input, output):
    activations[name] = output.detach().cpu()[0]
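A self-contained sketch of the same idea on a toy model (the nn.Sequential model and the input shapes are illustrative assumptions, not the GPT-NeoX setup). It indexes the first batch element before moving it, so only that slice is copied off the accelerator, and it stores a shallow dict copy per iteration; copy.deepcopy also works but duplicates every CPU tensor once more:

```python
import torch
from torch import nn

activations = {}

def get_activation(name):
    def hook(module, inputs, output):
        # take the first batch element first, then move only that slice to the CPU;
        # .cpu() makes a copy when the tensor lives on the GPU
        activations[name] = output[0].detach().cpu()
    return hook

device = "cuda" if torch.cuda.is_available() else "cpu"
# toy model standing in for the real network (assumption)
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8)).to(device)
for i, layer in enumerate(model):
    if isinstance(layer, nn.Linear):
        layer.register_forward_hook(get_activation(f"layer_{i}_activation"))

store_activations = []
with torch.inference_mode():
    for _ in range(3):
        model(torch.randn(2, 5, 8, device=device))
        # a shallow copy is enough: the hook assigns new tensors on every call
        # instead of modifying the stored ones in place
        store_activations.append(dict(activations))

print(len(store_activations))                            # 3
print(store_activations[0]["layer_0_activation"].shape)  # torch.Size([5, 8])
```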

This, together with saving the activations to a file every few iterations so that the CPU memory does not fill up either, did the trick.
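The periodic saving could look roughly like the sketch below. The helper flush_activations, the file naming scheme and the chunk size are hypothetical choices for illustration, and toy tensors stand in for the real snapshots; torch.save serializes the list at call time, so the buffer can be cleared right afterwards:

```python
import os
import tempfile

import torch

def flush_activations(buffer, out_dir, chunk_index):
    """Write the buffered snapshots to one .pt file and empty the buffer (hypothetical helper)."""
    path = os.path.join(out_dir, f"activations_{chunk_index:05d}.pt")
    torch.save(buffer, path)
    buffer.clear()  # drop the references so the CPU memory can be reclaimed
    return path

FLUSH_EVERY = 2  # hypothetical chunk size; a real run would use a larger value
out_dir = tempfile.mkdtemp()
buffer, written = [], []
for step in range(5):
    # toy snapshot standing in for dict(activations) from the hook
    buffer.append({"layer_0_activation": torch.full((3, 4), float(step))})
    if len(buffer) == FLUSH_EVERY:
        written.append(flush_activations(buffer, out_dir, len(written)))
if buffer:  # write whatever is left after the loop
    written.append(flush_activations(buffer, out_dir, len(written)))

print(len(written))                # 3
print(len(torch.load(written[0])))  # 2
```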

Thank you very much!