Hi, I am not sure how to train my architecture.
I have an input x, and each batch consists of several parts. Each part goes through a different model. At the end, the parts are merged together and passed through a simple CNN.
My model's forward method looks like this:
def forward(self, x):
    # switch the parts axis and the batch axis:
    # (batch, parts, ...) -> (parts, batch, ...)
    x = x.permute(1, 0, 2, 3, 4, 5)
    # each part goes through its own model
    x0 = self.model0(x[0])
    x1 = self.model1(x[1])
    x2 = self.model2(x[2])
    # merge the per-part outputs together
    x = torch.stack([x0, x1, x2], dim=1)
    # threshold to 0 / 1
    x = (x > 0.5).float()
    xInd = torch.argmin(x, dim=4)
    # calculate merged data based on argmin; self.threshold is a custom value in <0, 1>
    xMerge = ((xInd - 1) * 2 + 1) * self.threshold
    xMerge = self.cnn(xMerge)
    return [x0, x1, x2, xMerge]
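For concreteness, this is roughly how I call it (the exact input shape here is just an example; only the six permuted dimensions are fixed by the code above):

import torch

# hypothetical input: batch of 8, 3 parts, single-channel 16x64x64 volumes
x = torch.randn(8, 3, 1, 16, 64, 64)   # assumed (batch, parts, C, D, H, W) layout
# permute(1, 0, 2, 3, 4, 5) turns this into (parts, batch, C, D, H, W),
# so x[0], x[1], x[2] are the per-part batches fed to model0, model1, model2
x0, x1, x2, xMerge = model(x)           # model is an instance of the module above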
I want to train the network to not only predict xMerge, but also to have a loss on the individual layers. In my custom loss function, I use BCEWithLogitsLoss multiple times:
import torch.nn as nn

class MultiBceLoss(nn.Module):
    def __init__(self, layersCount):
        super().__init__()
        self.layersCount = layersCount
        self.bce_loss = nn.BCEWithLogitsLoss()

    def forward(self, output, gt):
        # loss on the merged output
        loss = self.bce_loss(output[-1], gt[-1])
        # plus one loss term per layer output
        for i in range(self.layersCount):
            loss += self.bce_loss(output[i], gt[i])
        return loss
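To illustrate the intended usage (gt here is assumed to be a list of targets in the same order as the model outputs):

criterion = MultiBceLoss(layersCount=3)
output = model(x)              # [x0, x1, x2, xMerge]
loss = criterion(output, gt)   # gt assumed to be [gt0, gt1, gt2, gtMerge]
loss.backward()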
However, the network learns incorrectly: the gradient from xMerge back to the layer models is lost, since xMerge is a newly generated output derived from argmin, which is not differentiable.
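A minimal standalone check (not my actual tensors) shows where the graph breaks: both the thresholding comparison and argmin produce tensors that are detached from autograd:

import torch

t = torch.randn(4, requires_grad=True)
hard = (t > 0.5).float()    # comparison yields a bool tensor; .float() of it has no grad_fn
print(hard.requires_grad)   # False -- the gradient chain stops here
ind = torch.argmin(t)       # argmin returns integer indices
print(ind.requires_grad)    # False -- nothing flows back through it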
How can this type of network be trained? Can it be done with a single loss?