So, after checking other questions here and building on two replies:
@ptrblck How to Detach specific components in the loss? - #8 by ptrblck
and,
@albanD Is Loss.backward() function calculate gradients over mini-batch - #5 by albanD
It seems there are two possible approaches:
import torch
import torch.nn as nn

BS = 5  # batch size
modelA = nn.Linear(5, 10)
modelB = nn.Linear(10, 2)
modelC = nn.Linear(10, 20)
x = torch.randn(BS, 5)
# Approach 1 - clone to duplicate the computation graph!
a = modelA(x)
a_copy = a.clone()                              # b is computed from the clone of a
b = modelB(a_copy)
filter_index = torch.topk(b, 1)[1].squeeze()    # e.g. tensor([0, 1, 1, 1, 0])
a = a[filter_index.nonzero(), :].squeeze(1)     # keep only the rows with a non-zero index
c = modelC(a)                                   # only part of the batch reaches modelC
# All gradients are None before any backward call:
print(modelA.weight.grad)
print(modelB.weight.grad)
print(modelC.weight.grad)

b.mean().backward(retain_graph=True)
# print(modelA.weight.grad)   # populated
# print(modelB.weight.grad)   # populated
# print(modelC.weight.grad)   # still None

c.mean().backward()
# print(modelA.weight.grad)   # accumulated further
# print(modelB.weight.grad)   # unchanged
# print(modelC.weight.grad)   # populated now

# Clear all gradients:
modelA.weight.grad = None
modelB.weight.grad = None
modelC.weight.grad = None
# Approach 2 - no clone, just retain_graph so the shared graph is not freed?
a = modelA(x)
b = modelB(a)
filter_index = torch.topk(b, 1)[1].squeeze()    # e.g. tensor([0, 1, 1, 1, 0])
a = a[filter_index.nonzero(), :].squeeze(1)     # keep only the rows with a non-zero index
c = modelC(a)

# Gradients are None again after clearing them above:
print(modelA.weight.grad)
print(modelB.weight.grad)
print(modelC.weight.grad)

b.mean().backward(retain_graph=True)
# print(modelA.weight.grad)   # populated
# print(modelB.weight.grad)   # populated
# print(modelC.weight.grad)   # still None

c.mean().backward()
# print(modelA.weight.grad)   # accumulated further
# print(modelB.weight.grad)   # unchanged
# print(modelC.weight.grad)   # populated now
And, upon checking the gradients, both seem to work fine! After the first .backward(), modelA's and modelB's .grad are populated, and the second .backward() accumulates into modelA's .grad (further!) and populates modelC's?
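Just to make sure I read the "further" part correctly: backward accumulates into .grad rather than overwriting it. A minimal, self-contained sketch of what I mean (separate toy model, not the ones above):

import torch
import torch.nn as nn

lin = nn.Linear(3, 1)
x = torch.randn(4, 3)

out = lin(x)
out.mean().backward(retain_graph=True)
g1 = lin.weight.grad.clone()                      # gradient after the first backward

out.mean().backward()                             # second backward on the same output
print(torch.allclose(lin.weight.grad, 2 * g1))    # True -> the gradient was accumulated, not overwritten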
So, is there any real difference between these two approaches? It seems both require me to pass retain_graph=True on the first backward. Unless I do something like this:
# Approach 3: (after re-running the forward pass) combine the losses and call backward once
loss = b.mean() + c.mean()
loss.backward()
And this seems to work as well?! Without retain_graph, since the two losses are backpropagated through the combined computation graph in a single call (maybe?).
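For what it's worth, this is a sketch of how the equivalence could be checked: by linearity, two separate backward calls (Approach 2) should accumulate exactly the same gradients as one backward on the summed loss (Approach 3). I hard-code the kept row indices instead of the topk filter purely to keep the snippet deterministic, so that part is only illustrative:

import torch
import torch.nn as nn

torch.manual_seed(0)
modelA = nn.Linear(5, 10)
modelB = nn.Linear(10, 2)
modelC = nn.Linear(10, 20)
x = torch.randn(5, 5)
keep = torch.tensor([0, 2, 4])      # stand-in for the topk-based filter

def forward():
    a = modelA(x)
    b = modelB(a)
    c = modelC(a[keep])             # only part of the batch reaches modelC
    return b, c

params = list(modelA.parameters()) + list(modelB.parameters()) + list(modelC.parameters())

# Approach 2: two backward calls; the first needs retain_graph because the
# a -> modelA part of the graph is shared with c's graph.
b, c = forward()
b.mean().backward(retain_graph=True)
c.mean().backward()
grads_two = [p.grad.clone() for p in params]

for p in params:
    p.grad = None

# Approach 3: one backward on the combined loss, no retain_graph needed.
b, c = forward()
(b.mean() + c.mean()).backward()
grads_one = [p.grad for p in params]

print(all(torch.allclose(g2, g1) for g2, g1 in zip(grads_two, grads_one)))  # True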
So, basically: are all three approaches equivalent, with neither better than the others? And does the a = a[filter_index.nonzero(), :].squeeze(1) step, which discretely and partially filters the output batch of modelA, pose no issue in terms of breaking the computation graph?
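My current understanding (please correct me if wrong) is that indexing with an integer tensor is differentiable with respect to the indexed tensor, so the gradient simply flows back to the selected rows only. A tiny sketch:

import torch

a = torch.randn(5, 3, requires_grad=True)
keep = torch.tensor([1, 3])     # hypothetical indices of the rows we keep
a[keep].sum().backward()
print(a.grad)                   # rows 1 and 3 are all ones, all other rows are zeros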
PS: I checked this reply about memory efficiency via gradient accumulation over sub-batches here - Using sub-batches to avoid busting the memory? - #2 by albanD - and if multiple sub-batches can each go through .backward() to accumulate gradients for a single optimizer step, then a part of the batch output can be used to evaluate the model as well, right?
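This is roughly the pattern I mean: a minimal sketch of accumulating gradients over sub-batches before a single optimizer step (toy model, loss, and data purely for illustration):

import torch
import torch.nn as nn

model = nn.Linear(5, 2)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

full_x = torch.randn(20, 5)
full_y = torch.randn(20, 2)

optimizer.zero_grad()
for xb, yb in zip(full_x.split(5), full_y.split(5)):   # four sub-batches of 5 samples
    loss = criterion(model(xb), yb)
    (loss / 4).backward()       # scale so the accumulated gradient matches the full-batch mean
optimizer.step()                # one parameter update from the accumulated gradients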