Hi all,

My use case is a little weird:

I have 2 modules, A and B, that both come after a parent module C.
During loss computation, I simply compute `total_loss = loss_A + loss_B`.
But I got stuck while implementing backward for module C.
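A minimal sketch of this setup, assuming C, A and B are ordinary `nn.Module`s (the `nn.Linear` layers, shapes and losses below are hypothetical stand-ins). As long as the graph stays connected, autograd accumulates the contributions of both branches into C's parameters with a single `total_loss.backward()` call:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# hypothetical stand-ins for the parent module C and the two heads A and B
C = nn.Linear(4, 8)
A = nn.Linear(8, 1)
B = nn.Linear(8, 1)

x = torch.randn(16, 4)
features = C(x)  # shared output of the parent module C

loss_A = A(features).pow(2).mean()
loss_B = B(features).abs().mean()
total_loss = loss_A + loss_B

# one backward call; gradients from both branches flow back into C
total_loss.backward()
print(C.weight.grad.shape)  # torch.Size([8, 4])
```

Nothing C-specific has to be written by hand in this case; manual gradient handling is only needed if the graph is deliberately cut between C and the heads.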

I have 2 gradients coming from module A and module B, and I have 2 proposals:

1. sum the 2 gradients, then backward module C once with the summed gradient
2. backward module C with the 2 gradients individually

Both methods seem to have different downsides.
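Both proposals can be sketched with a manual boundary at C's output, assuming the heads are run on a detached copy of that output so their gradients can be collected separately (modules and shapes are hypothetical). `Tensor.backward` accumulates into `.grad`, so two individual calls add up:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# hypothetical modules: parent C and heads A, B
C = nn.Linear(4, 8)
A = nn.Linear(8, 1)
B = nn.Linear(8, 1)
x = torch.randn(16, 4)

def head_grads():
    """Run C, then the heads on a detached copy of C's output, and
    return C's output plus d(loss_A)/d(out) and d(loss_B)/d(out)."""
    out = C(x)
    out_ = out.detach().requires_grad_()
    loss_A = A(out_).pow(2).mean()
    loss_B = B(out_).abs().mean()
    g_A, = torch.autograd.grad(loss_A, out_)
    g_B, = torch.autograd.grad(loss_B, out_)
    return out, g_A, g_B

# proposal 1: sum the two gradients, then backward through C once
C.zero_grad()
out, g_A, g_B = head_grads()
out.backward(g_A + g_B)
grad_summed = C.weight.grad.clone()

# proposal 2: backward through C with each gradient individually
C.zero_grad()
out, g_A, g_B = head_grads()
out.backward(g_A, retain_graph=True)  # keep C's graph for the second call
out.backward(g_B)                     # .grad accumulates across calls
grad_separate = C.weight.grad.clone()

print(torch.allclose(grad_summed, grad_separate, atol=1e-6))  # True
```

Proposal 1 needs one pass through C; proposal 2 needs `retain_graph=True` and a second pass, but ends up with the same gradients up to floating-point rounding.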

How would you guys solve this weird situation?
Any comment or direction would be a big help.
Thanks!


Summing the 2 gradients and then sending the summed gradient to C is mathematically the same as backpropagating the 2 gradients individually.
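The reason is that backpropagation through C is linear in the incoming gradient. Using notation introduced here only for illustration: let $y$ be C's output, $\theta$ C's parameters, $J = \partial y / \partial \theta$, and $g_A = \partial\,\text{loss}_A / \partial y$, $g_B = \partial\,\text{loss}_B / \partial y$. Since both losses depend on $\theta$ only through $y$:

```latex
\frac{\partial\,(\text{loss}_A + \text{loss}_B)}{\partial \theta}
  = J^\top (g_A + g_B)
  = J^\top g_A + J^\top g_B
```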

Hi smth,

You’re right, my brain must be broken …

But does simply summing the two gradients work in this case?
I mean, in a high-dimensional space the 2 gradients may point in very different directions, so the resulting direction can be very different from the original ones.
Or am I overthinking this problem?

Thanks!
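One way to see that the summed direction is not an approximation: the gradient of `total_loss` is by definition the sum of the per-loss gradients, so following it descends `total_loss` even when the two parts disagree. A small sketch with a toy parameter and two hypothetical losses chosen to pull in conflicting directions:

```python
import torch

# hypothetical targets: loss_A pulls theta toward target_A, loss_B toward target_B
target_A = torch.tensor([3.0, 0.0])
target_B = torch.tensor([-3.0, 0.0])

def losses(theta):
    loss_A = ((theta - target_A) ** 2).sum()
    loss_B = ((theta - target_B) ** 2).sum()
    return loss_A, loss_B

# toy parameter standing in for C's parameters
theta = torch.tensor([1.0, -2.0], requires_grad=True)
loss_A, loss_B = losses(theta)

# per-loss gradients; retain_graph so the graphs can be reused for the total
g_A, = torch.autograd.grad(loss_A, theta, retain_graph=True)
g_B, = torch.autograd.grad(loss_B, theta, retain_graph=True)
g_total, = torch.autograd.grad(loss_A + loss_B, theta)

# the two gradients disagree (negative cosine similarity) ...
cos = torch.nn.functional.cosine_similarity(g_A, g_B, dim=0)
print(cos.item())  # about -0.316

# ... yet their sum is exactly the gradient of total_loss
print(torch.allclose(g_total, g_A + g_B))  # True

# and a small step along -g_total still lowers total_loss
with torch.no_grad():
    before = sum(losses(theta)).item()
    after = sum(losses(theta - 0.1 * g_total)).item()
print(before, after)  # before is 28.0, after is about 21.6
```

Whether conflicting gradients hurt training is a separate question about the loss weighting; the summed gradient itself is the correct gradient of `total_loss`.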

Here’s a small test program to verify this:

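```python
import torch

# define initial data (shape chosen arbitrarily for the test)
a = torch.randn(10, requires_grad=True)

# b is the parent module
b = a * 2

# rewrap variable to have manual history management here:
# b_ is a new leaf with b's values but without b's graph
b_ = b.detach().requires_grad_()

c = b_ * b_
d = b_ * 4

e = c + d

# do backward in combined way
e.backward(torch.ones(e.size()))
# send the gradient collected at b_ back through the parent module
b.backward(b_.grad)
combined_grad = a.grad.clone()

# let's do separate way
a.grad.zero_()
b = a * 2
b_ = b.detach().requires_grad_()

c = b_ * b_
d = b_ * 4

c.backward(torch.ones(c.size()))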