Train model with multiple independent parts and manually modified tensor

Hi, I am not sure how to train my architecture.

I have an input x, and each batch consists of several parts. Each part goes through a different model. At the end, the parts are merged together and passed through a simple CNN.

My model's forward method looks like this:

def forward(self, x):
    # switch layers and batch
    x = x.permute(1, 0, 2, 3, 4, 5)

    x0 = self.model0(x[0])
    x1 = self.model1(x[1])
    x2 = self.model2(x[2])

    # merge layers together
    x = torch.stack([x0, x1, x2], dim=1)

    # threshold to 0 / 1
    x = (x > 0.5).float()
    xInd = torch.argmin(x, dim=4)
    # calculate merged data based on argmin; self.threshold is a custom value in <0, 1>
    xMerge = ((xInd - 1) * 2 + 1) * self.threshold

    xMerge = self.cnn(xMerge)

    return [x0, x1, x2, xMerge]

I want to train the network not only to predict xMerge, but also to have a loss on the individual layer outputs. In my custom loss function, I apply BCEWithLogitsLoss to each output. The calculation looks like this:

import torch.nn as nn

class MultiBceLoss(nn.Module):
    def __init__(self, layersCount):
        super().__init__()
        self.layersCount = layersCount
        self.bce_loss = nn.BCEWithLogitsLoss()

    def forward(self, output, gt):
        # merged output is the last element, layer outputs come first
        loss = self.bce_loss(output[-1], gt[-1])
        for i in range(self.layersCount):
            loss += self.bce_loss(output[i], gt[i])
        return loss
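
For reference, a minimal usage sketch of this loss; the shapes and layersCount below are illustrative assumptions only:

import torch

criterion = MultiBceLoss(layersCount=3)

# output = [x0, x1, x2, xMerge] as returned by the model's forward
output = [torch.randn(2, 1, 8, 8, requires_grad=True) for _ in range(4)]
# gt must follow the same layout: per-layer targets first, merged target last
gt = [torch.randint(0, 2, (2, 1, 8, 8)).float() for _ in range(4)]

loss = criterion(output, gt)
loss.backward()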

However, the network learns incorrectly. The gradient from xMerge back to the layer models is lost, since xMerge is a newly generated tensor produced from argmin. How can I train this type of network? Can it be done with a single loss?

I’m not sure which part exactly you want to train, but note that you are detaching x and xMerge from x0, x1, x2 since you are using non-differentiable operations such as (x > 0.5).float().
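
As a quick illustration (a standalone sketch with made-up shapes), the hard threshold yields a tensor that no longer requires gradients, so backpropagation stops at that point:

import torch

x = torch.randn(2, 3, requires_grad=True)
y = torch.sigmoid(x)          # differentiable: y has a grad_fn
hard = (y > 0.5).float()      # comparison + cast: no grad_fn

print(y.requires_grad)        # True
print(hard.requires_grad)     # False -> gradient flow stops here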

I want to train model0, model1, model2 and cnn together.

This won’t be possible, since the modelX outputs are already detached by the threshold operation as well as by the argmin.
You would thus have to use a differentiable / “soft” threshold operation as described e.g. here by @KFrank.
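
As a rough sketch of that idea (an assumption about one possible implementation, not the exact code from the linked post), a steep sigmoid can serve as a soft threshold, optionally combined with a straight-through trick; the argmin would similarly need a differentiable surrogate such as a softmax over that dimension:

import torch

def soft_threshold(x, threshold=0.5, steepness=50.0):
    # Smooth, differentiable approximation of (x > threshold).float()
    soft = torch.sigmoid(steepness * (x - threshold))
    # Straight-through trick: the forward pass sees the hard 0/1 values,
    # while the backward pass uses the gradient of the soft approximation.
    hard = (x > threshold).float()
    return hard + soft - soft.detach()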
