I am trying to build a multilabel model with 5 classes. I would like to experiment with class-specific layers: for example, three fully connected (dense) hidden layers shared by all classes, followed by two hidden layers per class that are not connected to the last two hidden layers of the other classes.
How can I implement this in PyTorch? I tried googling but am not sure what to search for.
My guess would be that you can't express this as a single plain stack of layers. Instead, make the fully connected layers one module, then stack on top of it as many separate modules as you have classes, each reusing the output of the final fully connected layer. Autograd should still be able to figure out how to train everything.
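A minimal sketch of that approach, assuming each per-class head outputs a single logit and the model is trained with `BCEWithLogitsLoss` (one independent sigmoid per class, the usual choice for multilabel). The class name `SharedTrunkMultiHead`, the hidden sizes, the ReLU activations and the random data are illustrative assumptions, not something from the question. Wrapping the shared trunk and the heads in one `nn.Module` lets a single optimizer and a single `backward()` call update everything:

```python
import torch
import torch.nn as nn


class SharedTrunkMultiHead(nn.Module):
    """Hypothetical sketch: shared fully connected trunk + one private head per class."""

    def __init__(self, in_features, num_classes=5, trunk_dim=64, head_dim=32):
        super().__init__()
        # Three fully connected hidden layers shared by every class.
        self.trunk = nn.Sequential(
            nn.Linear(in_features, trunk_dim), nn.ReLU(),
            nn.Linear(trunk_dim, trunk_dim), nn.ReLU(),
            nn.Linear(trunk_dim, trunk_dim), nn.ReLU(),
        )
        # One head per class with two hidden layers of its own. Heads share no
        # weights with each other; nn.ModuleList registers their parameters.
        self.heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(trunk_dim, head_dim), nn.ReLU(),
                nn.Linear(head_dim, head_dim), nn.ReLU(),
                nn.Linear(head_dim, 1),  # single logit for this class
            )
            for _ in range(num_classes)
        ])

    def forward(self, x):
        shared = self.trunk(x)  # computed once, reused by every head
        # Each head returns (batch, 1); concatenate to (batch, num_classes).
        return torch.cat([head(shared) for head in self.heads], dim=1)


torch.manual_seed(0)
model = SharedTrunkMultiHead(in_features=20)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.BCEWithLogitsLoss()  # independent sigmoid per class (multilabel)

x = torch.randn(8, 20)
y = torch.randint(0, 2, (8, 5)).float()  # random 0/1 labels, illustration only

logits = model(x)
loss = loss_fn(logits, y)
optimizer.zero_grad()
loss.backward()  # autograd reaches the trunk and every head
optimizer.step()

print(logits.shape)  # torch.Size([8, 5])
```

Using `nn.ModuleList` rather than a plain Python list matters here: a plain list would hide the heads' parameters from `model.parameters()`, so the optimizer would never update them.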
Awesome, thanks! I had started building something similar with two different classes, but this is essentially what I had planned. Great to know that you can just use torch.cat to combine the outputs of the different subnetworks.
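A small sketch of the combining step on its own, assuming five hypothetical subnetworks that each return one logit per sample (stood in for here by random tensors). It shows the shape bookkeeping for `torch.cat` versus `torch.stack`, and that gradients from the multilabel loss flow back through the concatenation into every subnetwork's output:

```python
import torch

batch, num_classes = 4, 5

# Stand-ins for the per-class subnetwork outputs, each of shape (batch, 1).
head_outputs = [torch.randn(batch, 1, requires_grad=True) for _ in range(num_classes)]

# torch.cat joins along an existing dimension: (batch, 1) x 5 -> (batch, 5).
logits = torch.cat(head_outputs, dim=1)

# If a head returned shape (batch,) instead, torch.stack adds the class dimension.
flat_outputs = [o.squeeze(1) for o in head_outputs]
stacked = torch.stack(flat_outputs, dim=1)  # also (batch, 5)

# Multilabel targets are 0/1 per class, so a per-class sigmoid loss fits.
targets = torch.randint(0, 2, (batch, num_classes)).float()
loss = torch.nn.BCEWithLogitsLoss()(logits, targets)
loss.backward()  # gradients flow through torch.cat back to each head output

print(logits.shape)  # torch.Size([4, 5])
```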