Gradient behavior with multi-layer loss

I have a loss where each layer's output feeds into the loss. Which approach is correct for making sure all the weights are updated properly?

    # option 1
    x2 = self.layer1(x1)
    x3 = self.layer2(x2)
    x4 = self.layer3(x3)

In the second option, I detach each output before feeding it into the next block:

    # option 2
    x2 = self.layer1(x1.detach())
    x3 = self.layer2(x2.detach())
    x4 = self.layer3(x3.detach())

Both options then go through the same shared ops, which calculate 4 losses and sum them:

    x4 = F.relu(self.bn1(x4))
    loss = some_loss([x1, x2, x3, x4])


I think option 2 with `.detach()` is not the correct way: detaching cuts the graph between blocks, so gradients from the later loss terms never reach the earlier layers, and each layer is only updated by the loss term computed on its own output. Option 1 keeps the full graph, so every layer receives gradients from all of the summed losses.
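You can verify this directly by comparing the gradient that `layer1` receives under each option. The sketch below uses a hypothetical three-`Linear`-layer stand-in for the model and a simple sum of per-layer terms in place of `some_loss` (three terms for brevity), which are assumptions, not the original code:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# toy stand-ins for the question's layer1/layer2/layer3 (illustrative only)
layer1 = nn.Linear(4, 4)
layer2 = nn.Linear(4, 4)
layer3 = nn.Linear(4, 4)
x1 = torch.randn(2, 4)

def summed_loss(x2, x3, x4):
    # stand-in for some_loss: one term per intermediate output, summed
    return x2.pow(2).mean() + x3.pow(2).mean() + x4.pow(2).mean()

# option 1: full graph, gradients from every loss term flow back to layer1
x2 = layer1(x1)
x3 = layer2(x2)
x4 = layer3(x3)
summed_loss(x2, x3, x4).backward()
grad_full = layer1.weight.grad.clone()

# option 2: detach between blocks, so layer1 only sees its own loss term
for layer in (layer1, layer2, layer3):
    layer.zero_grad()
x2 = layer1(x1)
x3 = layer2(x2.detach())
x4 = layer3(x3.detach())
summed_loss(x2, x3, x4).backward()
grad_detached = layer1.weight.grad.clone()

# the two gradients differ: detaching removed the contributions
# that flow back to layer1 through the x3 and x4 loss terms
print(torch.allclose(grad_full, grad_detached))
```

With random weights the two gradients will differ, confirming that `.detach()` changes which loss terms update the earlier layers.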