I have a loss where each layer's output contributes to the loss. Which approach is correct for making sure the weights are updated properly?
```python
# option 1
x2 = self.layer1(x1)
x3 = self.layer2(x2)
x4 = self.layer3(x3)
```
In option 2, I detach each activation before feeding it into the next block:
```python
# option 2
x2 = self.layer1(x1.detach())
x3 = self.layer2(x2.detach())
x4 = self.layer3(x3.detach())
```
Both options are then followed by the same shared ops, which calculate four losses and sum them:

```python
x4 = F.relu(self.bn1(x4))
loss = some_loss([x1, x2, x3, x4])
```
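For context, here is a minimal runnable sketch (using hypothetical toy `nn.Linear` layers, not my actual model) showing how the two options differ: with `.detach()`, gradients from the later loss terms no longer flow back into the earlier layers.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer1 = nn.Linear(4, 4)
layer2 = nn.Linear(4, 4)
x1 = torch.randn(2, 4)

# Option 1: no detach -- layer1 receives gradient from both loss terms
x2 = layer1(x1)
x3 = layer2(x2)
(x2.sum() + x3.sum()).backward()
grad_no_detach = layer1.weight.grad.clone()

layer1.zero_grad()
layer2.zero_grad()

# Option 2: detach -- layer1 only receives gradient from its own loss term,
# because x2.detach() cuts the graph between layer1 and layer2
x2 = layer1(x1)
x3 = layer2(x2.detach())
(x2.sum() + x3.sum()).backward()
grad_detach = layer1.weight.grad.clone()

# The gradients differ: detaching removed the contribution flowing
# back through layer2 into layer1
print(torch.equal(grad_no_detach, grad_detach))  # False
```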