nn.CrossEntropyLoss can be used for multi-class classification and expects raw logits as the model output, so you should remove the last torch.softmax activation. Also, remove the last F.relu and return the output of self.fc2(x) directly.
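Something like this should work; a minimal sketch where the layer names and shapes (fc1, fc2, 10 -> 20 -> 5) are assumptions, not your exact model:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 5)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        # return raw logits: no F.relu and no torch.softmax on the last layer
        return self.fc2(x)

model = Net()
criterion = nn.CrossEntropyLoss()  # applies log_softmax + NLLLoss internally
x = torch.randn(8, 10)
target = torch.randint(0, 5, (8,))
loss = criterion(model(x), target)
loss.backward()
```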
You could use forward hooks as described here, or you could keep this method and reuse it in forward to avoid duplicating code:
```python
def forward(self, x):
    # reuse the existing method so the feature extraction logic lives in one place
    x = self.last_hidden_layer_output(x)
    x = self.fc2(x)
    return x
```
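For reference, the forward-hook alternative could look like this; a minimal sketch reusing the Net example from above and assuming the activation you want is the input to self.fc2 (i.e. the last hidden layer output):

```python
# reuses the Net class, model, and imports from the sketch above
activations = {}

def hook(module, inputs, output):
    # inputs[0] is the tensor fed into fc2, i.e. the last hidden layer output
    activations['last_hidden'] = inputs[0].detach()

handle = model.fc2.register_forward_hook(hook)
out = model(torch.randn(8, 10))
print(activations['last_hidden'].shape)  # torch.Size([8, 20])
handle.remove()  # remove the hook once you no longer need it
```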