Memory leak when computing higher-order gradients using hooks

Recently I have needed to double-backpropagate through the gradients of the embedding layer for NLP tasks.

I essentially have two ways of doing it: one with autograd.grad, the other with register_backward_hook on the embedding layer.

But here is the thing:

If I just use autograd.grad to get the gradient with respect to the embedding layer’s weight, let’s call it W’, then I can call backward on W’ without any memory leak. The problem is that this is not exactly what I want: if the same token appears more than once in the sentence, the gradient returned for that token is the sum of the gradients over all of its occurrences.
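
To illustrate what I mean by the summing, here is a small sketch (not part of my actual code): token index 2 appears twice in the input, but the gradient returned for the weight contains only a single, already-summed row for it.

import torch

emb = torch.nn.Embedding(10, 5)
ids = torch.tensor([2, 7, 2])  # token 2 occurs twice
loss = emb(ids).norm(2)

# the gradient w.r.t. the embedding *weight* has shape (10, 5), so token 2 gets a
# single row that already sums the contributions of both of its occurrences
(grad_w,) = torch.autograd.grad(loss, emb.weight)
print(grad_w[2])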

So as far as I know, the only way to do it in this case is to backward through the gradients captured by a hook. This is achieved by registering a hook on the embedding layer with register_backward_hook and then calling either loss.backward(create_graph=True) or autograd.grad(loss, embedding_layer.weight, create_graph=True). The gradients captured by the hook are then kept separate per occurrence rather than summed for repeated tokens. However, this incurs a GPU memory leak.

Here is a minimal example that reproduces it:

import torch
import os
import numpy as np
np.random.seed(0)
torch.manual_seed(0)
os.environ['CUDA_VISIBLE_DEVICES']="1"
dev = torch.device('cuda')

embedding_gradient = []
def hook_layer(module, grad_in, grad_out):
    embedding_gradient.append(grad_out[0])
arr = np.random.randint(10, size=3)
random_ix = torch.LongTensor(arr).to(dev)
print(arr)
embedding_layer = torch.nn.Embedding(10, 5, padding_idx=0).to(dev)
net = torch.nn.Linear(5,1).to(dev)
for i in range(1000):
    print(i)
    print(torch.cuda.memory_summary(device=0, abbreviated=True))
    #1 set the hook to embedding layer
    embedding_gradient = []
    hook = embedding_layer.register_backward_hook(hook_layer)
    #2 forward pass
    embeds = embedding_layer(random_ix)
    out = net(embeds)
    #3 backward pass
    summed = out.norm(2)
    summed.backward(create_graph=True)
    #4 remove the hook
    hook.remove()
    final = embedding_gradient[0].sum()
    final.backward()

If I use grad_auto = torch.autograd.grad(summed, embedding_layer.weight, create_graph=True) rather than summed.backward(create_graph=True), the memory leak goes away.
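
For completeness, here is a minimal sketch of that variant: the same setup as the reproducer above, but with the backward call in step #3 swapped for an explicit autograd.grad call.

import torch

dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

embedding_gradient = []
def hook_layer(module, grad_in, grad_out):
    embedding_gradient.append(grad_out[0])

random_ix = torch.randint(0, 10, (3,), device=dev)
embedding_layer = torch.nn.Embedding(10, 5, padding_idx=0).to(dev)
net = torch.nn.Linear(5, 1).to(dev)

for i in range(1000):
    embedding_gradient = []
    hook = embedding_layer.register_backward_hook(hook_layer)
    summed = net(embedding_layer(random_ix)).norm(2)
    # compute the gradient explicitly instead of calling summed.backward(create_graph=True);
    # the backward hook still fires and captures the per-position gradient
    grad_auto = torch.autograd.grad(summed, embedding_layer.weight, create_graph=True)
    hook.remove()
    final = embedding_gradient[0].sum()
    final.backward()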

But that is not the case in my actual code: there, either way there is a memory leak. It would be great if anyone knows how to solve this issue or how to circumvent it. I am using torch 1.5 and CUDA 10.0 with a GTX 1080 Ti GPU. This is also reproducible on 1.4. Thanks in advance!


I think the memory leak in this code comes from the fact that you are constantly appending gradient tensors to the embedding_gradient list.
Since you do not free them after the calculation, they remain in GPU memory.

However, I’m not quite sure what you are trying to achieve here.
First, calling gradient.backward() before optimizer.step() treats the gradient as your loss function.
So you are actually training the parameters to minimize the gradient; is this what you want?
Second, I think that in most NLP tasks it’s common to accumulate the gradients of multiple occurrences of a single token, since a token needs information from multiple positions to update its representation.
Yet if you think you really need to retain gradients from specific positions, you should block the gradients by customizing a torch.autograd.Function instead of using a hook.
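
For example, a rough sketch of such a Function (my own illustration; the mask shape and positions are just placeholders):

import torch

class BlockGrad(torch.autograd.Function):
    # identity in the forward pass; zeroes the gradient at masked positions in the backward pass
    @staticmethod
    def forward(ctx, x, mask):
        ctx.save_for_backward(mask)
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        (mask,) = ctx.saved_tensors
        return grad_output.masked_fill(mask, 0.0), None

# usage: block the gradient flowing back through the second position
emb = torch.nn.Embedding(10, 5)
ids = torch.tensor([2, 7, 2])
mask = torch.tensor([False, True, False]).unsqueeze(-1)  # (seq_len, 1), broadcasts over the embedding dim
embeds = BlockGrad.apply(emb(ids), mask)
embeds.norm(2).backward()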


I tried deleting the gradient tensors in my embedding_gradient list each iteration with del, and added a torch.cuda.empty_cache() afterward, for both the toy example and my actual code. Sadly it didn’t solve the problem. I am not really sure where the issue is.

Yes, you are correct here: making all the gradients low is actually my goal. In that case, you can see why accumulating the gradients would not be my best bet. Since my objective is to drive all gradients toward 0, I need to either take the absolute value of each gradient or square it. But the operation gives a different result when what I get is a sum of gradients, since |a| + |b| is not the same as |a + b|.

Customizing my own torch.autograd.Function seems like an interesting direction. How exactly do I go about that? Do I define something similar to torch.autograd.grad? Is there any example you can point me to? Thanks a lot for the response btw!

Have you tried detaching the positions of the tokens you don’t want gradients for?
The detach() operation keeps them from accumulating gradients in the backward pass.
Just detach the positions where you want to block the gradient, right after the embedding layer, as in the sketch below.
This may be easier than customizing an autograd Function.
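
Something along these lines (a rough sketch; the kept positions are just an example):

import torch

emb = torch.nn.Embedding(10, 5)
ids = torch.tensor([2, 7, 2])
embeds = emb(ids)                         # (seq_len, dim)

keep = torch.tensor([True, False, True])  # positions whose gradient you keep
# detached positions still contribute to the forward pass, but no gradient
# flows back to the embedding weight through them
mixed = torch.where(keep.unsqueeze(-1), embeds, embeds.detach())
mixed.norm(2).backward()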

This could be a good strategy if I only wanted a subset of the gradients, right? Currently I want the gradients for all the tokens, since I want all the gradients to be low. Also, if I run in batches, the gradients of the same token within a batch also get summed up. I don’t know whether detach() would be suitable in this case?

If you want to train the gradient to be low by minimizing it as an objective, retrieve the gradient with torch.autograd.grad and then minimize its norm.

But if you want to block gradients for certain positions, then you should detach the positions where you don’t want the gradient to be accumulated.
I don’t know why you need to block the gradient for the same token, since the embedding layer is designed to receive different information from multiple positions.

I wonder if you are misunderstanding how the gradient is summed over the occurrences of each token?
Different occurrences of the same token do not result in the same gradient.

I see what you are saying, and I am aware that different occurrences of the same token do not result in the same gradient. Let’s say I have the sentence abca; then the gradient of a at different positions would have different values. And the gradient of token a at the embedding layer is Wa = Wa1 + Wa2, where Wa1 is the gradient w.r.t. the embedding layer contributed by the first occurrence of a, and Wa2 the one contributed by the second.

You are saying that I can basically compute norm(Wa) + norm(Wb) + norm(Wc) or something similar and minimize that. For simplicity, I will define norm(W) as sum(abs(W)**2). You can see why norm(Wa) would be different from norm(Wa1) + norm(Wa2). For this specific application, I think taking the norms separately is more correct, which is why I want norm(Wa1) + norm(Wa2) + norm(Wb) + norm(Wc).
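
Here is a tiny numeric illustration with made-up values of why the two differ (the summed gradient can even partially cancel out):

import torch

w_a1 = torch.tensor([1.0, -2.0])   # made-up gradient from the first occurrence of a
w_a2 = torch.tensor([-1.0, 3.0])   # made-up gradient from the second occurrence

norm = lambda w: (w.abs() ** 2).sum()

print(norm(w_a1 + w_a2))           # norm of the summed gradient: 0**2 + 1**2 = 1
print(norm(w_a1) + norm(w_a2))     # sum of per-occurrence norms: (1 + 4) + (1 + 9) = 15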

I see… so you basically want to separate the gradients of the different token occurrences and minimize each of them.

I think you can just take the gradient with respect to the output of the embedding layer and then minimize that?
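
Something like this, as a rough sketch with a toy model rather than your actual setup:

import torch

emb = torch.nn.Embedding(10, 5)
net = torch.nn.Linear(5, 1)
ids = torch.tensor([2, 7, 2])

embeds = emb(ids)                  # output of the embedding layer, one row per position
loss = net(embeds).norm(2)

# the gradient w.r.t. the *output* of the embedding layer has shape (seq_len, dim),
# so repeated tokens keep separate rows instead of being summed into one weight row
(grad_embeds,) = torch.autograd.grad(loss, embeds, create_graph=True)
penalty = (grad_embeds ** 2).sum()
penalty.backward()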

That’s a good point. Currently I am struggling to test it, though, because I can’t find a reference to the output of the embedding layer. I am using huggingface’s BERT model, and the call is

embedding_layer = self._model.bert_model.embeddings.word_embeddings.weight
embedding_gradients_auto = torch.autograd.grad(loss, embedding_layer, create_graph=True)

self._model.bert_model.embeddings.word_embeddings above is just an nn.Embedding module, as shown here: link to huggingface’s code. Currently this takes the gradient w.r.t. the weight of the nn.Embedding. How exactly do I reference the output? Right now I am going to try a forward hook.

Actually, this seems to work. If I use a forward hook to capture the embedding output and then call autograd.grad on it, the result is correct and it doesn’t incur a memory leak. Anyway, thanks a lot for the help! Really appreciate it.
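
For anyone who runs into the same thing, here is roughly what I ended up doing. This is only a sketch: model, batch, and compute_loss below are placeholders for my actual BERT model and forward pass, not real code from it.

import torch

captured = []
def capture_output(module, inputs, output):
    captured.append(output)

# word_embeddings is the nn.Embedding referenced above; model, batch, and compute_loss are placeholders
word_embeddings = model.bert_model.embeddings.word_embeddings
handle = word_embeddings.register_forward_hook(capture_output)
loss = compute_loss(model, batch)          # placeholder forward pass
handle.remove()

embedding_output = captured[0]             # (batch, seq_len, dim), one row per position
(grads,) = torch.autograd.grad(loss, embedding_output, create_graph=True)
penalty = (grads ** 2).sum()
penalty.backward()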
