Hi Yongjie!
Your process is mostly correct. You need to use `retain_graph = True` in your first `backward()` call:
    loss1.backward(retain_graph=True)
No, one forward pass is enough. (You do, of course, have to calculate both `loss1` and `loss2`.) But you do need two `backward()` passes.
(If you weren't to use `retain_graph = True` in your first `backward()` pass, you would have to perform a second forward pass to rebuild the computation graph for use by the second `backward()` pass.)
Here is an illustrative script:
import torch

print(torch.__version__)

# Weights captured under each update scheme, for comparison at the end:
weights1A = None  # net1 weights after net1 half-update
weights2A = None  # net2 weights after net2 half-update
weights1B = None  # net1 weights after full update
weights2B = None  # net2 weights after full update

# Compare updating the model with two separate "half-updates" (methods 0
# and 1, each with its own forward pass) against one full update that
# reuses a single forward pass for both losses (method 2).
for updateMethod in range(3):
    torch.manual_seed(2021)  # make initialization reproducible
    net1 = torch.nn.Linear(2, 2)  # create new, freshly-initialized models
    net2 = torch.nn.Linear(2, 2)
    input = torch.randn(2)
    target1 = torch.randn(2)
    target2 = torch.randn(2)
    opt1 = torch.optim.SGD(net1.parameters(), lr=0.1)
    # bug fix: opt2 must optimize net2's parameters (original passed
    # net1.parameters(), so net2 was never updated and both "check"
    # comparisons passed vacuously on net2's initial weights)
    opt2 = torch.optim.SGD(net2.parameters(), lr=0.2)
    if updateMethod == 0:  # net1 half-update
        output = net2(torch.tanh(net1(input)))
        loss1 = torch.nn.MSELoss()(output, target1)
        loss1.backward()
        opt1.step()
        weights1A = net1.weight.detach().clone()
    elif updateMethod == 1:  # net2 half-update
        output = net2(torch.tanh(net1(input)))
        loss2 = torch.nn.MSELoss()(output, target2)
        loss2.backward()
        opt2.step()
        weights2A = net2.weight.detach().clone()
    elif updateMethod == 2:  # perform full update with a single forward pass
        output = net2(torch.tanh(net1(input)))
        loss1 = torch.nn.MSELoss()(output, target1)  # net1 update
        loss1.backward(retain_graph=True)  # keep graph for loss2.backward()
        opt1.step()
        weights1B = net1.weight.detach().clone()
        opt2.zero_grad()  # zero out net2 grads populated by loss1.backward()
        loss2 = torch.nn.MSELoss()(output, target2)  # net2 update
        loss2.backward()
        opt2.step()
        weights2B = net2.weight.detach().clone()

print('weights1A = ...')
print(weights1A)
print('weights1B = ...')
print(weights1B)
print('check weights1:', torch.allclose(weights1A, weights1B))
print('weights2A = ...')
print(weights2A)
print('weights2B = ...')
print(weights2B)
print('check weights2:', torch.allclose(weights2A, weights2B))
Here is its output:
1.9.0
weights1A = ...
tensor([[-0.4652, 0.0495],
[ 0.3326, 0.2997]])
weights1B = ...
tensor([[-0.4652, 0.0495],
[ 0.3326, 0.2997]])
check weights1: True
weights2A = ...
tensor([[-0.6444, 0.6545],
[-0.2909, -0.5669]])
weights2B = ...
tensor([[-0.6444, 0.6545],
[-0.2909, -0.5669]])
check weights2: True
Best.
K. Frank