Multiple model outputs activated from a single model based on label selection

A simple example is given below:

import torch
import torch.nn as nn

class my_model(nn.Module):
    def __init__(self):
        super(my_model, self).__init__()
        self.all_layers = nn.ModuleList()
        for i in range(10):
            layers = []
            layers.append(nn.Linear(10, 10))
            layers.append(nn.BatchNorm1d(10))
            self.main = nn.Sequential(*layers)
            self.all_layers.append(self.main)

    def forward(self, zn, x):
        output = []
        for i in range(10):
            mask = x == i
            temp = zn[mask]
            print(temp.size())
            output.append(self.all_layers[i](temp))

        return output



model = my_model()
print(model)
target = torch.randn(60, 10)
loss = nn.MSELoss()
input = torch.randn(60, 10)
labels = torch.randint(0, 10, (60,))
zn = torch.randn(60, 10)
predict = model(zn, labels)
print(len(predict))
print(predict)
criteria = loss(predict, target)

TypeError: expected Tensor as element 0 in argument 0, but got list

I’m not exactly sure how the code is supposed to work.
Currently you have some issues in your code:

  • self.main is never used
  • you recreate layers = [] inside the loop, and overwrite self.all_layers with the last two layers
  • the returned output list will throw the mentioned error, since the criterion expects tensors, not lists
  • even if you try to torch.stack the output (which would be the usual workflow), you will most likely encounter errors, since you are using a mask to index zn, which will create variable-sized outputs (see the small example after this list)
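To illustrate the last point: torch.stack needs tensors of equal size, while torch.cat in dim0 still works for differently sized chunks. A small artificial example (the sizes are made up):

import torch

out0 = torch.randn(4, 10)   # e.g. 4 samples of class 0
out1 = torch.randn(7, 10)   # e.g. 7 samples of class 1

# torch.stack((out0, out1))            # would raise an error, since the tensors differ in size
out = torch.cat((out0, out1), dim=0)   # works and yields a [11, 10] tensor
print(out.shape)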

Thank you for your kind reply.
The idea is simple.
Let's say a, b, c are models with the same architecture (e.g., a simple nn.Linear layer). These three models are trained on class-specific samples from a single batch. The final output is d = a + b + c (all the individual model outputs are mutually exclusive), i.e. a is trained on the first 15 samples, b is trained on the next 20 samples, and c is trained on the last 25 samples. The final loss (the sum of the individual losses) is backpropagated to update the respective model parameters. In Keras, we can create a single model with 3 different dictionaries. Each dictionary model can be updated by the specific training samples of the current batch, i.e. all the individual dictionary models are trained separately on specific training samples.
For the PyTorch solution, I have created a main model in which the 3 submodels (a, b, c) are registered via nn.ModuleList(). That means all the individual models share the same structure but are mutually exclusive in nature. A mask is then defined to select the submodels from the main model. The final output is passed through the nn.MSELoss criterion. If I take an optimizer step, the models should be updated based on the selected samples. Am I doing something wrong? Please tell me. If I am wrong, how do I solve the problem?
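In rough PyTorch pseudocode, the idea would look something like this (the 15/20/25 split and the plain nn.Linear submodels are only placeholders for illustration):

import torch
import torch.nn as nn

# three submodels with the same architecture (stand-ins for a, b, c)
a, b, c = nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10)
criterion = nn.MSELoss()
params = list(a.parameters()) + list(b.parameters()) + list(c.parameters())
optimizer = torch.optim.SGD(params, lr=0.1)

x = torch.randn(60, 10)
target = torch.randn(60, 10)

# mutually exclusive sample slices: a -> first 15, b -> next 20, c -> last 25
loss = (criterion(a(x[:15]), target[:15]) +
        criterion(b(x[15:35]), target[15:35]) +
        criterion(c(x[35:]), target[35:]))

optimizer.zero_grad()
loss.backward()   # each submodel only receives gradients from its own samples
optimizer.step()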

Thanks for the clarification!
In that case you would have to deal with some edge cases:

  • if the current batch does not contain samples from a specific class, you should skip it. Otherwise you will get an all-zero mask and the code will raise an exception. Maybe you could use torch.unique to get all current class labels and loop over these indices instead of looping over all classes (see the sketch after this list)
  • if the current batch only contains a single sample of a class, your nn.BatchNorm1d layer(s) will throw an exception, since they cannot calculate the batch stats for a single sample (and 1 sequence length).
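A (hypothetical, untested) forward sketch using torch.unique and skipping single-sample classes could look like this:

def forward(self, zn, x):
    output = []
    # only loop over the class labels which are actually present in the batch
    for i in torch.unique(x).tolist():
        mask = x == i
        if mask.sum() < 2:
            # skip classes with a single sample, since nn.BatchNorm1d cannot
            # compute batch statistics from one sample in training mode
            continue
        output.append(self.all_layers[i](zn[mask]))
    return output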

However, if you somehow make sure that your batches are balanced, your code works fine:

model = my_model()
print(model)
target = torch.randn(60, 10)
loss = nn.MSELoss()
input = torch.randn(60, 10)
labels = torch.arange(10).view(10, 1).repeat(1, 6).view(-1)
zn = torch.randn(60, 10)
predict = model(zn, labels)
output = torch.cat(predict)
criteria = loss(output, target)

Thank you for your kind reply.
torch.unique is a good idea to use instead of looping over all class labels. I never thought about that.
"if the current batch only contains a single sample of a class, your nn.BatchNorm1d layer(s) will throw an exception, since they cannot calculate the batch stats for a single sample (and 1 sequence length)." Exactly, but the model will be trained on a balanced dataset.
I have a fundamental question regarding the loss backward. The final loss is the sum of all the individual losses. I think this scalar loss value will update the parameters of the individual models that are connected by the closed graph of their specific input labels. Is this assumption correct?

I think your explanation is correct, although I’m not sure what “closed” graph means exactly.
However, each operation in the forward pass will create a computation graph, which will be used to calculate the gradients for the involved parameters in the backward call.
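As a quick (hypothetical) check, and assuming the forward pass only loops over torch.unique(x) as sketched above so that unused submodels are skipped, you could verify that only the submodels which actually received samples get gradients:

model = my_model()
criterion = nn.MSELoss()

labels = torch.tensor([0] * 6 + [1] * 6)   # batch containing only classes 0 and 1
zn = torch.randn(12, 10)
target = torch.randn(12, 10)

output = torch.cat(model(zn, labels))
loss = criterion(output, target)
loss.backward()

print(model.all_layers[0][0].weight.grad.abs().sum())   # non-zero, submodel 0 was used
print(model.all_layers[9][0].weight.grad)                # None, submodel 9 was never used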

Thank you for your reply.
“closed graph” means that model “a”'s parameters will be updated based on the first 15 samples, model “b”'s parameters based on the next 20 samples, and model “c”'s parameters based on the last 25 samples, respectively. Although the final loss is the sum of all 3 models' losses and its gradient is backpropagated to update the individual model parameters, model “b” would not depend on the first 15 samples or the last 25 samples (because model b's parameters have no closed connection to those two sample spaces). This is my understanding. Maybe I am wrong.