Hi, I am not sure how to train my architecture.
I have an input x, and each batch consists of several parts. Each part goes through a different model. At the end, the parts are merged together and passed through a simple CNN.
My model's forward method looks like this:
```python
def forward(self, x):
    # switch layers and batch
    x = x.permute(1, 0, 2, 3, 4, 5)
    x0 = self.model0(x)
    x1 = self.model1(x)
    x2 = self.model2(x)
    # merge layers together
    x = torch.stack([x0, x1, x2], dim=1)
    # threshold to 0 / 1
    x = (x > 0.5).float()
    xInd = torch.argmin(x, dim=4)
    # calculate merged data based on argmin; self.threshold is a custom value in [0, 1]
    xMerge = ((xInd - 1) * 2 + 1) * self.threshold
    xMerge = self.cnn(xMerge)
    return [x0, x1, x2, xMerge]
```
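To illustrate what I believe is happening, here is a minimal standalone check (with made-up tensor shapes): both the `> 0.5` comparison and `torch.argmin` return tensors that are detached from the autograd graph, so nothing computed from them can backpropagate into the layer models.

```python
import torch

# hypothetical shape, standing in for the stacked layer outputs
x = torch.rand(2, 3, 8, 8, 5, requires_grad=True)

hard = (x > 0.5).float()          # comparison result: requires_grad is False
xInd = torch.argmin(hard, dim=4)  # integer indices: also detached from the graph

print(x.requires_grad, hard.requires_grad, xInd.requires_grad)
# True False False -> no gradient can flow from xInd back to x
```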
I want to train the network not only to predict xMerge, but also to supervise the individual layer outputs. My custom loss function therefore sums multiple BCEWithLogitsLoss terms:
```python
class MultiBceLoss(nn.Module):
    def __init__(self, layersCount):
        nn.Module.__init__(self)
        self.layersCount = layersCount
        self.bce_loss = nn.BCEWithLogitsLoss()

    def forward(self, output, gt):
        # loss on the merged output
        loss = self.bce_loss(output[-1], gt[-1])
        # plus one loss term per individual layer output
        for i in range(self.layersCount):
            loss += self.bce_loss(output[i], gt[i])
        return loss
```
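For completeness, this is roughly how I call the loss during training; the shapes below are made up just to show the expected list structure (one tensor per output, with xMerge last).

```python
import torch
import torch.nn as nn

criterion = MultiBceLoss(layersCount=3)

# dummy logits and binary targets, standing in for [x0, x1, x2, xMerge] and their ground truth
output = [torch.randn(4, 1, 8, 8, requires_grad=True) for _ in range(4)]
gt = [torch.randint(0, 2, (4, 1, 8, 8)).float() for _ in range(4)]

loss = criterion(output, gt)
loss.backward()
print(loss.item())
```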
However, the network does not learn correctly. The gradient from xMerge back to the layer models is lost, since xMerge is a newly generated output derived from argmin, which is not differentiable. How can I train this type of network? Can it be done with a single loss?