How to add a custom classifier block to a pre-trained network using nn.Module?

I’m still fairly new to PyTorch and am grappling with this problem for my own understanding: I am trying to replace the existing classifier block at the end of DenseNet-121 with a new, untrained one, and to define it as an nn.Module subclass instead of with nn.Sequential.

It works fine if I do something along these lines:

classifier = nn.Sequential(
    nn.Linear(1024, 500),
    nn.ReLU(),
    nn.Linear(500, 1),
    nn.Sigmoid())

I’d like to try defining an nn.Module subclass from scratch instead. In this case I was trying something like:

import torch
import torch.nn as nn
import torch.nn.functional as F

class New_Classifier(nn.Module):
    def __init__(self):
        super(New_Classifier, self).__init__()
        # feature_length, hidden_1 and hidden_2 are assumed to be defined earlier
        # (feature_length is 1024 for DenseNet-121's features)
        self.fc1 = nn.Linear(feature_length, hidden_1)
        self.fc2 = nn.Linear(hidden_1, hidden_2)
        self.fc3 = nn.Linear(hidden_2, 1)   # should output the final score
        # dropout layer
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        x = F.relu(self.fc1(x))   # fully connected layer 1
        x = self.dropout(x)
        x = F.relu(self.fc2(x))   # fully connected layer 2
        x = torch.sigmoid(self.fc3(x))   # return a value between 0 and 1
        return x

model.classifier = New_Classifier

I realise this isn’t very elegant or necessary; I just wanted to understand how one would go about this if the new classifier couldn’t be expressed neatly with nn.Sequential or similar. I get the following error:

TypeError: cannot assign '__main__.New_Classifier' as child module 'classifier' (torch.nn.Module or None expected)

Any suggestions as to what caused this error? I suspect the ‘child module’ part is a clue. Or am I trying to do something that just can’t be done?
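The ‘child module’ hint is indeed the right clue. The sketch below reproduces the error in isolation, using two hypothetical stand-ins: `Backbone` (a model that already has a child module named `classifier`, like DenseNet-121) and `TinyHead` (a small replacement head). `try_assign` is also a hypothetical helper that returns the error text, or None if the assignment succeeds.

```python
import torch.nn as nn


class TinyHead(nn.Module):
    """Hypothetical stand-in for New_Classifier, used only to reproduce the error."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 1)

    def forward(self, x):
        return self.fc(x)


class Backbone(nn.Module):
    """Hypothetical model with a registered child module named 'classifier'."""

    def __init__(self):
        super().__init__()
        self.classifier = nn.Linear(4, 2)


def try_assign(value):
    """Assign value to Backbone().classifier; return the error text, or None on success."""
    model = Backbone()
    try:
        model.classifier = value
    except TypeError as err:
        return str(err)
    return None


print(isinstance(TinyHead, nn.Module))    # False: the class object itself
print(isinstance(TinyHead(), nn.Module))  # True: an instance of the class
print(try_assign(TinyHead))    # a TypeError message mentioning 'child module'
print(try_assign(TinyHead()))  # None: the assignment succeeds
```

Since `classifier` is already registered as a child module, nn.Module only lets you overwrite it with an nn.Module instance or None, and a bare class object is neither.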


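Replace the line model.classifier = New_Classifier with model.classifier = New_Classifier() (just append () at the end) so that you actually create an instance of the module. Without the parentheses you are assigning the class itself, and PyTorch only accepts an nn.Module instance (or None) for an attribute that is already registered as a child module, such as classifier.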

Thanks very much, Mariosasko, that was awesome. I threw in the () and it worked flawlessly! Now I’m a little embarrassed I missed that!