I have two modules, A and B, that both take the output of a parent module C as input.
During loss computation, I simply compute total_loss = loss_A + loss_B
But I got stuck while implementing backward for module C.
Two gradients arrive at C, one from module A and one from module B, and I have two proposals:
1. sum the two gradients, then run backward through module C once
2. run backward through module C with each gradient individually
Both methods seem to have downsides.
How would you guys solve this weird condition?
Any comment or direction would be a big help.
Thanks!
Thanks for your reply!
You’re right, my brain must be broken …
But does simply summing the two gradients work in this case?
I mean, in a high-dimensional space the two gradients may well point in very different directions, so the resulting direction can be very different from either of the original ones.
Or am I overthinking this problem?
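To spell out the reasoning behind this question, here is the chain rule written out as a short sketch. The symbols are assumptions introduced for illustration: $c$ is the output of module C, $\theta$ stands for C's parameters, and $L_A$, $L_B$ are the two losses.

```latex
\begin{aligned}
L &= L_A + L_B \\
\frac{\partial L}{\partial c} &= \frac{\partial L_A}{\partial c} + \frac{\partial L_B}{\partial c} \\
\frac{\partial L}{\partial \theta}
  &= \left(\frac{\partial c}{\partial \theta}\right)^{\top}
     \left(\frac{\partial L_A}{\partial c} + \frac{\partial L_B}{\partial c}\right)
   = \left(\frac{\partial c}{\partial \theta}\right)^{\top}\frac{\partial L_A}{\partial c}
   + \left(\frac{\partial c}{\partial \theta}\right)^{\top}\frac{\partial L_B}{\partial c}
\end{aligned}
```

Backward through C is linear in the incoming gradient, so summing first and backpropagating once gives the same result as backpropagating each gradient and accumulating. The summed vector is, by definition, the gradient of total_loss, even when the two individual gradients point in different directions.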
import torch

# define initial data
a = torch.randn(10, requires_grad=True)
# b is the output of the parent module
b = a * 2
# detach b and re-wrap it as a new leaf so the history is managed manually here
b_ = b.detach().requires_grad_()
c = b_ * b_
d = b_ * 4
e = c + d
# combined way: backward once through the summed output
e.backward(torch.ones_like(e))
b.backward(b_.grad)
agrad_combined = a.grad.clone()
# reset a's grad
a.grad.zero_()
# separate way: backward through each branch individually
b = a * 2
b_ = b.detach().requires_grad_()
c = b_ * b_
d = b_ * 4
c.backward(torch.ones_like(c))
# retain_graph=True because we backward through b a second time below
b.backward(b_.grad, retain_graph=True)
b_.grad.zero_()
d.backward(torch.ones_like(d))
b.backward(b_.grad)
agrad_separate = a.grad.clone()
# print difference between combined method and separate method (should be all zeros)
print(agrad_combined - agrad_separate)
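The same check can also be written without the manual re-wrapping step, letting autograd handle the shared parent directly. This is only a minimal sketch: the tensors `x`, `c_out`, `loss_a` and `loss_b` are hypothetical stand-ins for module C and the two losses from the question, not the original poster's code.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins: c_out plays the role of parent module C's output,
# loss_a and loss_b play the role of the losses from modules A and B.
x = torch.randn(10, requires_grad=True)
c_out = x * 2
loss_a = (c_out * c_out).sum()
loss_b = (c_out * 4).sum()

# Gradient of each loss on its own. retain_graph=True keeps the shared
# graph through c_out alive for the later calls.
(grad_a,) = torch.autograd.grad(loss_a, x, retain_graph=True)
(grad_b,) = torch.autograd.grad(loss_b, x, retain_graph=True)

# Gradient of the summed loss in a single backward pass.
(grad_total,) = torch.autograd.grad(loss_a + loss_b, x)

# The two approaches agree up to floating-point tolerance.
print(torch.allclose(grad_a + grad_b, grad_total))
```

In ordinary training code this means calling `backward()` once on `loss_a + loss_b` is enough, since autograd accumulates the contributions from both branches at the shared node.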