How to include output-specific layers in a multilabel model?

Hey everyone,

I am trying to build a multilabel model with 5 classes. I would like to experiment with class-specific layers, e.g. have 3 fully connected hidden layers shared by all classes, followed by 2 hidden layers per class that are not connected to the class-specific layers of the other classes.

How can I implement this in PyTorch? I tried googling but am not sure what to search for.
My guess is that you cannot implement this as a single stack of layers, but would instead make the shared fully connected layers one module and then attach as many distinct modules as you have classes, each re-using the output of the final shared layer. Autograd should still be able to figure out how to train everything.

As I understand, you would like to have a common net path for all classes, which is then split into separate paths.

Here is a small code snippet, which might help you:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class Net(nn.Module):
    def __init__(self, in_features):
        super(Net, self).__init__()
        # Create "shared" module
        self.main = nn.Sequential(
            nn.Linear(in_features, 10),
            nn.Linear(10, 10))
        # Create class-specific modules
        self.fc_c0 = nn.Linear(10, 1)
        self.fc_c1 = nn.Linear(10, 1)
        self.fc_c2 = nn.Linear(10, 1)
        self.fc_c3 = nn.Linear(10, 1)
        self.fc_c4 = nn.Linear(10, 1)

    def forward(self, x):
        x = F.relu(self.main(x))
        # Every class-specific layer receives the same shared features
        out_c0 = self.fc_c0(x)
        out_c1 = self.fc_c1(x)
        out_c2 = self.fc_c2(x)
        out_c3 = self.fc_c3(x)
        out_c4 = self.fc_c4(x)
        # Concatenate the per-class outputs into shape [batch_size, 5]
        out = torch.cat((out_c0, out_c1, out_c2, out_c3, out_c4), dim=1)
        return F.log_softmax(out, dim=1)

model = Net(10)

criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Random dummy data: 100 samples, 10 features, class indices in [0, 5)
X = np.random.randn(100, 10).astype(np.float32)
y = np.random.randint(0, 5, size=(100, ), dtype=np.int64)
data = torch.from_numpy(X)
target = torch.from_numpy(y)

output = model(data)
loss = criterion(output, target)

# One training step
optimizer.zero_grad()
loss.backward()
optimizer.step()



I hope you can use it as a starter for your net.
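
If you want something closer to what you described (three shared hidden layers, then two hidden layers per class), the per-class paths can be kept in an `nn.ModuleList`. The following is only a rough sketch under my own assumptions: the name `MultiHeadNet`, the sizes `hidden=32` and `head_hidden=16`, and the random data are all made up for illustration. Also note that `log_softmax` with `NLLLoss` treats the five outputs as mutually exclusive classes; for a multilabel setup, where several labels can be active at once, an independent sigmoid per output via `nn.BCEWithLogitsLoss` is the usual choice, so the sketch uses that instead.

```python
import torch
import torch.nn as nn


class MultiHeadNet(nn.Module):
    """Shared trunk followed by one independent head per class (sizes are hypothetical)."""

    def __init__(self, in_features, hidden=32, head_hidden=16, num_classes=5):
        super().__init__()
        # Three shared hidden layers
        self.trunk = nn.Sequential(
            nn.Linear(in_features, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
        )
        # Per class: two hidden layers plus a single-logit output layer.
        # The heads share no weights with each other.
        self.heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden, head_hidden), nn.ReLU(),
                nn.Linear(head_hidden, head_hidden), nn.ReLU(),
                nn.Linear(head_hidden, 1),
            )
            for _ in range(num_classes)
        ])

    def forward(self, x):
        shared = self.trunk(x)
        # Each head returns [batch, 1]; concatenate to [batch, num_classes]
        return torch.cat([head(shared) for head in self.heads], dim=1)


torch.manual_seed(0)
model = MultiHeadNet(in_features=10)
criterion = nn.BCEWithLogitsLoss()  # expects raw logits and float 0/1 targets
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

data = torch.randn(100, 10)
target = torch.randint(0, 2, (100, 5)).float()  # random multi-hot labels

logits = model(data)
loss = criterion(logits, target)

optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Using `nn.ModuleList` rather than a plain Python list matters here, because it registers the heads' parameters so that `model.parameters()` passes them to the optimizer.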

Awesome, thanks! I had started building something similar with two different classes, and this is essentially what I had planned. Great to know that you can just concatenate the outputs of the different subnetworks.
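
For anyone who finds this later, here is a quick sanity check that could be used to confirm that autograd keeps the class-specific parts separate while still training the shared layers. It is only a minimal sketch: `trunk`, `head_a`, `head_b`, `grad_norm`, and the layer sizes are made-up names and values, not part of the snippet above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical tiny setup: one shared layer and two class-specific heads
trunk = nn.Linear(4, 8)
head_a = nn.Linear(8, 1)
head_b = nn.Linear(8, 1)


def grad_norm(module):
    # Sum of absolute gradients over a module's parameters
    # (parameters without a populated gradient contribute nothing)
    return sum(p.grad.abs().sum().item()
               for p in module.parameters() if p.grad is not None)


x = torch.randn(16, 4)
shared = torch.relu(trunk(x))
out = torch.cat([head_a(shared), head_b(shared)], dim=1)  # shape [16, 2]

# Backpropagate a loss that depends only on the first head's column
out[:, 0].sum().backward()

print("head_a:", grad_norm(head_a))
print("head_b:", grad_norm(head_b))
print("trunk: ", grad_norm(trunk))
```

With a loss that only involves the first output column, `head_a` and the shared `trunk` should receive non-zero gradients, while `head_b` should receive none, which is the separation I was after.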