My use case is a little weird:
I have two modules, A and B, that both take the output of a parent module C.
During loss computation, I simply compute total_loss = loss_A + loss_B.
But I got stuck while implementing the backward pass for module C.
C receives two gradients, one from module A and one from module B, and I have two proposals for handling them:
- sum the two gradients, then run C's backward once
- run C's backward separately for each of the two gradients
Both methods seem to have their own downsides.
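
To make the two proposals concrete, here is a minimal sketch in plain Python. Everything in it is a hypothetical stand-in and not my real code: C is a toy module y_i = tanh(w * x_i) with a single scalar weight w, loss_A is a squared error against made-up targets, and loss_B is just the sum of C's outputs. The helper backward_C is an illustrative name for C's hand-written backward.

```python
import math

# Hypothetical stand-in for module C: y_i = tanh(w * x_i), one scalar weight w.
w = 0.7
xs = [0.5, -1.2, 2.0]
ys = [math.tanh(w * x) for x in xs]

# Hypothetical stand-ins for A and B, each turning y into a scalar loss.
targets = [0.1, 0.0, -0.3]
# loss_A = 0.5 * sum((y - t)^2)  ->  d loss_A / d y_i = y_i - t_i
grad_from_A = [y - t for y, t in zip(ys, targets)]
# loss_B = sum(y)                ->  d loss_B / d y_i = 1
grad_from_B = [1.0 for _ in ys]

def backward_C(grad_ys):
    """Backward of C: chain rule through tanh, returns d loss / d w."""
    return sum(g * (1.0 - y * y) * x for g, y, x in zip(grad_ys, ys, xs))

# Proposal 1: sum the two gradients, then run C's backward once.
summed = [ga + gb for ga, gb in zip(grad_from_A, grad_from_B)]
grad_w_summed = backward_C(summed)

# Proposal 2: run C's backward once per gradient and accumulate the results.
grad_w_separate = backward_C(grad_from_A) + backward_C(grad_from_B)

# Compare the two proposals up to floating-point rounding.
same = math.isclose(grad_w_summed, grad_w_separate, rel_tol=1e-9, abs_tol=1e-12)
print(same)
```

In this toy case the two proposals agree up to rounding, because the backward pass is linear in the incoming gradient even though C itself is nonlinear. So the difference I care about is mostly cost and bookkeeping: one backward pass through C versus two, plus the accumulation step in the second proposal.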
How would you handle this situation?
Any comment or direction would be a big help.