backward() to compute partial derivatives without retain_graph=True

Dear all, I am implementing the GradNorm paper. I have found two implementations, but both of them require a lot of memory.

Let's suppose I have a loss L and a network with weights W:
L = loss1 + loss2 + ... + lossN
W are the weights with requires_grad=True

I want to:
(dloss1)/dW
save gradient norm
(dloss2)/dW
save gradient norm
(dL)/dW
update parameters

An example implementation can be found at:

The interesting lines are:

    # compute and retain gradients
    total_weighted_loss.backward(retain_graph=True)
    
    # GRADNORM - learn the weights for each tasks gradients
    
    # zero the w_i(t) gradients since we want to update the weights using gradnorm loss
    self.weights.grad = 0.0 * self.weights.grad
    
    W = list(self.model.mtn.shared_block.parameters())
    norms = []
    
    for w_i, L_i in zip(self.weights, task_losses):
        # gradient of L_i(t) w.r.t. W
        gLgW = torch.autograd.grad(L_i, W, retain_graph=True)
        
        # G^{(i)}_W(t); note that only the first tensor of W (gLgW[0]) enters the norm
        norms.append(torch.norm(w_i * gLgW[0]))
    
    norms = torch.stack(norms)
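For reference, here is a minimal self-contained sketch of the same pattern, assuming a hypothetical toy model with one shared trunk (`shared`) and two task heads (`head1`, `head2`); these names, the data and the loss choice are illustrative only. Unlike the excerpt, it takes the norm over all shared parameters rather than only `gLgW[0]`. It makes the memory issue visible: every `autograd.grad` call walks the shared part of the graph, so the graph has to stay alive until the final backward.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy multi-task model: one shared trunk, two task heads.
shared = nn.Linear(4, 8)
head1 = nn.Linear(8, 1)
head2 = nn.Linear(8, 1)
task_weights = nn.Parameter(torch.ones(2))  # w_i(t), learnt by GradNorm

x = torch.randn(16, 4)
y1, y2 = torch.randn(16, 1), torch.randn(16, 1)

h = torch.relu(shared(x))
task_losses = [
    nn.functional.mse_loss(head1(h), y1),
    nn.functional.mse_loss(head2(h), y2),
]
W = list(shared.parameters())

norms = []
for w_i, L_i in zip(task_weights, task_losses):
    # Each call traverses the shared graph again, so it must be retained.
    gLgW = torch.autograd.grad(L_i, W, retain_graph=True)
    # Flatten and concatenate the gradients of *all* shared parameters.
    g = torch.cat([t.flatten() for t in gLgW])
    # Equals ||w_i * grad L_i|| for non-negative w_i; differentiable in w_i.
    norms.append(w_i * g.norm())
norms = torch.stack(norms)

# The usual weighted backward; as the last use of the graph it may free it.
total_weighted_loss = (task_weights * torch.stack(task_losses)).sum()
total_weighted_loss.backward()
```

The per-task gradients are computed without `create_graph=True`, so `norms` only depends on the task weights, which is all GradNorm needs to update them.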

But I want to avoid retain_graph=True, as it leads to out-of-memory errors for my network.

A similar (and very good) question has already been asked, but it has no answer:

Is there a way to efficiently compute the derivatives w.r.t. the other losses without retain_graph?

  • e.g., calling backward() on copies of W (W1, W2) w.r.t. L1, then L2…
  • e.g., a multidimensional loss L = [l1, l2, lTOT] and then L.backward(Tensor), with Tensor selecting the right loss: [1,0,0], [0,1,0], [0,0,1]
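The second idea is essentially a vector-Jacobian product: passing a one-hot vector as `grad_outputs` picks out the gradient of one entry of the stacked loss. A minimal sketch with a hypothetical parameter `W` and two hypothetical losses shows that it gives the right per-loss gradients, but also that every selection is still a separate pass through the graph, so `retain_graph=True` is still needed for all but the last call.

```python
import torch

torch.manual_seed(0)
W = torch.randn(3, requires_grad=True)
x = torch.randn(3)

l1 = (W * x).sum() ** 2            # hypothetical task loss 1
l2 = (W ** 2).sum()                # hypothetical task loss 2
L = torch.stack([l1, l2, l1 + l2]) # multidimensional "loss" [l1, l2, lTOT]

grads = []
for k in range(3):
    selector = torch.zeros(3)
    selector[k] = 1.0              # one-hot: [1,0,0], [0,1,0], [0,0,1]
    # One backward pass per selector; keep the graph for all but the last one.
    (g,) = torch.autograd.grad(L, W, grad_outputs=selector, retain_graph=k < 2)
    grads.append(g)
```

As far as I know, recent PyTorch versions also accept `is_grads_batched=True` in `torch.autograd.grad` to push a batch of such selectors through in one call (via vmap); whether that lowers peak memory for a given network is something to measure rather than assume.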

Additionally, could someone clarify how to inspect and visualize which tensors are responsible for memory leaks?

@albanD @colensbury, could you help? Your opinions are very valuable.

Many thanks for your help,
Stefano

Hi,

All the losses seem to share the same computation graph. So to be able to use that graph multiple times when you do multiple backward calls, you have to keep it around with retain_graph=True. I don't see any way around this other than reducing memory usage, e.g. by reducing the batch size.
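A small sketch of why the shared graph forces `retain_graph=True`: a backward pass frees the buffers saved by every node it traverses, so a second backward through a shared node fails. The tensor `W` and the two losses are hypothetical stand-ins for the network and task losses.

```python
import torch

torch.manual_seed(0)
W = torch.randn(3, requires_grad=True)

# Two losses sharing the same intermediate node (tanh saves its output for backward).
h = torch.tanh(W)
loss1 = h.sum()
loss2 = (h ** 2).sum()

loss1.backward()  # frees the buffers saved by every node it traverses, incl. tanh
try:
    loss2.backward()  # needs tanh's saved output again
    second_backward_failed = False
except RuntimeError:
    second_backward_failed = True

# Same computation, but keep the graph alive for the second pass.
W.grad = None
h = torch.tanh(W)
loss1 = h.sum()
loss2 = (h ** 2).sum()
loss1.backward(retain_graph=True)  # saved buffers are kept
loss2.backward()                   # last use of the graph: buffers are freed here
# W.grad now holds d(loss1)/dW + d(loss2)/dW
```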

Thanks so much for your prompt reply!
I understand what you mean. So this is one of the cases where retain_graph=True cannot be avoided.
However, this does not work for me. With any batch size, at some point during training (after anywhere from a few thousand to many thousands of iterations) I run into a GPU OOM, for a model that needs about 3 GB on average without retain_graph.

I have tried to avoid all memory leaks, and still the GPU fills up when using retain_graph. If I may ask: why does the network grow so much in size, and why does it keep track of all previous iterations?

Additional question: what is the recommended way to analyze tensor and graph memory usage for debugging?
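One common debugging trick (not an official tool) is to ask Python's garbage collector for every tensor it can currently see and print their shapes, so that tensors that pile up across iterations stand out. The helper name `live_tensors` is hypothetical. It only finds tensors reachable as Python objects; buffers saved inside the autograd graph live on the C++ side and will not show up, so for overall GPU usage `torch.cuda.memory_allocated()` and `torch.cuda.memory_summary()` are the complementary checks.

```python
import gc
import torch


def live_tensors(min_numel=0):
    """Return (shape, dtype, device) for every tensor the Python GC can see."""
    found = []
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) and obj.numel() >= min_numel:
                found.append((tuple(obj.shape), obj.dtype, str(obj.device)))
        except Exception:
            # Some objects raise during inspection; they are not tensors we care about.
            continue
    return found
```

Calling it once per iteration (e.g. `len(live_tensors(min_numel=10**5))`) and watching whether the count keeps rising is usually enough to spot a tensor being kept alive by accident.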

Thanks a lot!

The memory leak is most likely not related to retain_graph but to other changes you made. The only difference retain_graph makes is that it delays the deletion of some buffers until the graph itself is deleted.
So the only way for these buffers to leak is if you never delete the graph. But if you never delete it, you would end up running out of memory even without retain_graph.

I would use a package like torchviz to check that the graph does not grow forever. If it does, you need to fix your code to avoid that (most likely there is a missing .detach()).
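Without torchviz, a rough way to check for an ever-growing graph is to count the autograd nodes reachable from a tensor's `grad_fn` over a few iterations. The sketch below contrasts a running total that keeps the history of every step with one that detaches first; `count_graph_nodes` is a hypothetical helper, and `W` stands in for the network.

```python
import torch


def count_graph_nodes(t):
    """Count distinct autograd nodes reachable from t.grad_fn (0 for leaves/detached)."""
    seen, stack = set(), [t.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        stack.extend(next_fn for next_fn, _ in fn.next_functions)
    return len(seen)


W = torch.randn(3, requires_grad=True)

# Buggy pattern: the running total keeps the history of every iteration.
running = torch.zeros(())
sizes_buggy = []
for step in range(3):
    loss = (W ** 2).sum()
    running = running + loss
    sizes_buggy.append(count_graph_nodes(running))

# Fixed pattern: detach (or use .item()) before accumulating.
running = torch.zeros(())
sizes_fixed = []
for step in range(3):
    loss = (W ** 2).sum()
    running = running + loss.detach()
    sizes_fixed.append(count_graph_nodes(running))
```

If the count keeps rising from one iteration to the next in real training code, some tensor is carrying history across iterations.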


Amazing, thanks for your help!

Hi Stefano,

Sorry for the delayed reply!

It was a long time ago that I experimented with this and I haven't pursued the issue much further, but here is a workaround that helped me. It only covers a 2-task example, though. (By the way, I wouldn't call this a memory "leak": retain_graph=True just requires lots of memory.)

So suppose your total loss is:

    loss = w_1 * loss_1 + w_2 * loss_2

    # step A: calculate the gradient (and its magnitude) of the FIRST loss w.r.t. a certain
    # layer "manually", WITH retain_graph=True
    shared = myNet.<whichever layer is the last common one>.conv
    gw_1 = torch.autograd.grad(loss_1, (shared.weight, shared.bias), retain_graph=True)

    G_W_1 = torch.sqrt(torch.sum(gw_1[0] ** 2) + torch.sum(gw_1[1] ** 2))

    # step B: BACKWARD through the ENTIRE network
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()  # step() does not clear .grad, so it can still be read below

    # step C: now recover the gradient of the SECOND loss from the gradient of the TOTAL loss:
    # shared.weight.grad = w_1 * gw_1 + w_2 * gw_2, so subtract the WEIGHTED first gradient
    # and divide by w_2
    # (an alternative would be another autograd.grad on loss_2 with retain_graph=True before
    # loss.backward(), but that's where memory problems started for me)
    with torch.no_grad():
        gw_2 = [
            (shared.weight.grad - w_1 * gw_1[0]) / w_2,
            (shared.bias.grad - w_1 * gw_1[1]) / w_2,
        ]

    G_W_2 = torch.sqrt(torch.sum(gw_2[0] ** 2) + torch.sum(gw_2[1] ** 2))

    G_W_mean = 0.5 * (G_W_1 + G_W_2)
    G_W_mean_no_grad = G_W_mean.detach()
Then you continue to work with these quantities (G_W_mean, G_W_1 and G_W_2) to follow the loss-weight update strategy for w_1 and w_2 described in the paper.
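Following my reading of the GradNorm paper, the weight update penalizes the L1 distance between each weighted gradient norm G_W^(i) = w_i * ||g_i|| and a target mean(G) * r_i^alpha, where r_i is task i's loss ratio L_i(t)/L_i(0) divided by the mean ratio; the target is treated as a constant, and the weights are renormalized to sum to the number of tasks afterwards. The sketch below assumes `g_norms` holds unweighted per-task norms such as G_W_1 and G_W_2; the function name `gradnorm_weight_loss`, `alpha=1.5`, the learning rate and the loss values are hypothetical choices.

```python
import torch


def gradnorm_weight_loss(G, losses, initial_losses, alpha=1.5):
    """G: (T,) weighted grad norms, differentiable in the task weights.
    losses, initial_losses: (T,) tensors of L_i(t) and L_i(0)."""
    with torch.no_grad():
        loss_ratio = losses / initial_losses   # L~_i(t)
        r = loss_ratio / loss_ratio.mean()     # relative inverse training rate
        target = G.mean() * r ** alpha         # treated as a constant
    return (G - target).abs().sum()


w = torch.nn.Parameter(torch.ones(2))          # task weights w_1, w_2
opt_w = torch.optim.SGD([w], lr=0.025)         # hypothetical learning rate

g_norms = torch.tensor([2.0, 0.5])             # e.g. unweighted G_W_1, G_W_2 (no graph)
G = w * g_norms                                # G_W^(i)(t), differentiable in w
L_grad = gradnorm_weight_loss(G, torch.tensor([0.8, 0.9]), torch.tensor([1.0, 1.0]))

opt_w.zero_grad()
L_grad.backward()
opt_w.step()
with torch.no_grad():
    w.mul_(len(w) / w.sum())                   # renormalize so that sum_i w_i = T
```

In this toy setting the task with the larger gradient norm (task 1) ends up with the smaller weight, which is the balancing effect GradNorm aims for.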

I guess this only works for 2-task problems, though.
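A quick sanity check of the subtraction idea on a toy example: since the total gradient is w_1 * g_1 + w_2 * g_2, subtracting the weighted first gradient and dividing by w_2 recovers g_2 without a second retained pass. `W` stands in for the last shared layer, and the losses and weights are hypothetical. For N tasks the same idea would need N-1 explicit `autograd.grad` calls (all with `retain_graph=True`) and recover only the last gradient by subtraction, so it saves one pass rather than removing the need to retain the graph.

```python
import torch

torch.manual_seed(0)
W = torch.randn(5, requires_grad=True)   # stands in for the last shared layer
w_1, w_2 = 0.7, 1.3                      # current task weights (constants here)

h = torch.tanh(W)
loss_1 = (h ** 2).sum()
loss_2 = (h * torch.arange(5.0)).sum()
loss = w_1 * loss_1 + w_2 * loss_2

# Step A: explicit gradient of the first (unweighted) loss; keep the graph for step B.
(g_1,) = torch.autograd.grad(loss_1, W, retain_graph=True)

# Step B: one backward over the total loss; this frees the graph.
loss.backward()

# Step C: W.grad = w_1 * g_1 + w_2 * g_2, so solve for g_2.
g_2 = (W.grad - w_1 * g_1) / w_2
```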

Anyway, I hope it helped a little.

Cheers,

Bernie


Dear @albanD, after experimenting a bit I still have some doubts:

  • I have used torchviz as suggested, but in my case it is not helping. I have 2 "models": a proper model, which is my network, and a two-dimensional nn.Parameter that stores the learnt loss weights.

You suggested checking with torchviz whether the graph grows in size, but the model works properly and does not grow. When inspecting the nn.Parameter, torchviz only shows a single box labeled None.

However, torchviz seems to be of no help for visualizing the autograd backward graph. The only method I found online is get_grad(), but it is verbose.

How can I visualize the backward graph to understand what is causing the out-of-memory errors? It seems very hard to inspect what is going on under the hood, especially in unconventional applications such as GradNorm.
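A text-only alternative to torchviz is to walk `grad_fn.next_functions` yourself and print each node, which is roughly what torchviz does before drawing. In the sketch below `print_backward_graph` is a hypothetical helper, the toy loss is illustrative, and exact node names (e.g. `MmBackward0`) can differ between PyTorch versions.

```python
import torch


def print_backward_graph(t):
    """Print each autograd node reachable from t.grad_fn, indented by depth."""
    lines, seen = [], set()

    def visit(fn, depth):
        if fn is None:  # input that does not require grad
            return
        name = type(fn).__name__
        if fn in seen:
            lines.append("  " * depth + name + " (already shown)")
            return
        seen.add(fn)
        if hasattr(fn, "variable"):  # AccumulateGrad nodes hold the leaf tensor
            name += f" -> leaf {tuple(fn.variable.shape)}"
        lines.append("  " * depth + name)
        for next_fn, _ in fn.next_functions:
            visit(next_fn, depth + 1)

    visit(t.grad_fn, 0)
    print("\n".join(lines))
    return lines


W = torch.nn.Parameter(torch.randn(2, 3))
x = torch.randn(4, 3)
loss = (x @ W.t()).relu().sum()
lines = print_backward_graph(loss)
```

Running it on the final loss, rather than on a parameter, is what shows the history; comparing the output between iterations reveals whether nodes accumulate.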

Many thanks,
Stefano

Hi,

torchviz (which renders through graphviz) shows you exactly the backward graph.
nn.Parameters are always leaf tensors (with no history), so it is expected that you get a single box with None when plotting their history: they never have any.
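A tiny check of the leaf behaviour, using a hypothetical two-element weight parameter like the learnt loss weights and hypothetical per-task loss values: the parameter itself has no `grad_fn`, while a tensor computed from it does, so that derived tensor (for example with torchviz's `make_dot(weighted)`) is the thing to plot.

```python
import torch
import torch.nn as nn

loss_weights = nn.Parameter(torch.ones(2))   # like the learnt GradNorm weights
task_losses = torch.tensor([0.5, 2.0])       # hypothetical per-task losses

weighted = (loss_weights * task_losses).sum()

# A Parameter is a leaf: no history, hence the single "None" box.
print(loss_weights.is_leaf, loss_weights.grad_fn)  # True None
# Anything computed from it does have a history to plot.
print(weighted.is_leaf, type(weighted.grad_fn).__name__)
```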
