I’m having some trouble getting an MLP classifier to train with MSE loss for some reason. Maybe it’s too late in the day and I am overlooking something, but I am wondering how autograd can be made compatible if the model outputs are a [num_examples, num_classes] matrix, i.e., each row corresponds to one example and each column holds the probability that the example belongs to a particular class.
This is not an issue with cross_entropy(), since it takes the integer class labels directly and handles the class dimension internally.
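For instance, something along these lines trains without problems for me (just to show the working baseline, using the same logits and integer targets as in the snippet below):
cost = F.cross_entropy(logits, targets)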
However, when I switch to MSE, I get issues with the MSE loss (probably understandably so, since argmax is not differentiable and blocks the gradient), but I am wondering if there’s some work-around for:
logits, probas = model(features)
cost = F.mse_loss(torch.argmax(probas, 1).float(), targets.float())
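What I have in mind as a possible work-around (not sure whether this is the right approach) is to skip the argmax entirely and compare the full probability matrix against one-hot encoded targets, so the loss stays differentiable:

# candidate work-around: no argmax, one-hot encode the targets instead
y_onehot = torch.zeros(targets.size(0), probas.size(1))
y_onehot.scatter_(1, targets.view(-1, 1).long(), 1)
cost = F.mse_loss(probas, y_onehot)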
If you are wondering why MSE loss: it’s purely for teaching purposes. I want to double-check whether my manually implemented networks train correctly (i.e., get results equivalent to the autograd version).
For reference, I want to compare something like this (an MLP with only 1 hidden layer, sigmoid activations, and MSE loss) to some automatic version in PyTorch:
import torch


class MultilayerPerceptron():

    def __init__(self, num_features, num_hidden, num_classes):
        super(MultilayerPerceptron, self).__init__()
        self.num_classes = num_classes

        # hidden 1
        self.weight_1 = torch.zeros(num_hidden, num_features,
                                    dtype=torch.float).normal_(0.0, 0.1)
        self.bias_1 = torch.zeros(num_hidden, dtype=torch.float)

        # output
        self.weight_o = torch.zeros(num_classes, num_hidden,
                                    dtype=torch.float).normal_(0.0, 0.1)
        self.bias_o = torch.zeros(num_classes, dtype=torch.float)
    def forward(self, x):
        # hidden 1
        # input dim: [n_examples, n_features] dot [n_features, n_hidden]
        # output dim: [n_examples, n_hidden]
        z_1 = torch.mm(x, self.weight_1.t()) + self.bias_1
        a_1 = torch.sigmoid(z_1)

        # output layer
        # input dim: [n_examples, n_hidden] dot [n_hidden, n_classes]
        # output dim: [n_examples, n_classes]
        z_2 = torch.mm(a_1, self.weight_o.t()) + self.bias_o
        a_2 = torch.sigmoid(z_2)
        return a_1, a_2
    def backward(self, x, a_1, a_2, y):

        #########################
        ### Output layer weights
        #########################

        # onehot encoding
        y_onehot = torch.FloatTensor(y.size(0), self.num_classes)
        y_onehot.zero_()
        y_onehot.scatter_(1, y.view(-1, 1).long(), 1)

        # Part 1: dLoss/dOutWeights
        ## = dLoss/dOutAct * dOutAct/dOutNet * dOutNet/dOutWeight
        ## where DeltaOut = dLoss/dOutAct * dOutAct/dOutNet
        ## for convenient re-use

        # input/output dim: [n_examples, n_classes]
        dloss_da2 = 2.*(a_2 - y_onehot) / y.size(0)
        # input/output dim: [n_examples, n_classes]
        da2_dz2 = a_2 * (1. - a_2)  # sigmoid derivative
        # output dim: [n_examples, n_classes]
        delta_out = dloss_da2 * da2_dz2  # "delta (rule) placeholder"

        # gradient for output weights
        # [n_examples, n_hidden]
        dz2__dw_out = a_1
        # input dim: [n_classes, n_examples] dot [n_examples, n_hidden]
        # output dim: [n_classes, n_hidden]
        dloss__dw_out = torch.mm(delta_out.t(), dz2__dw_out)
        dloss__db_out = torch.sum(delta_out, dim=0)
        #################################
        # Part 2: dLoss/dHiddenWeights
        ## = DeltaOut * dOutNet/dHiddenAct * dHiddenAct/dHiddenNet * dHiddenNet/dWeight

        # [n_classes, n_hidden]
        dz2__a1 = self.weight_o
        # output dim: [n_examples, n_hidden]
        dloss_a1 = torch.mm(delta_out, dz2__a1)
        # [n_examples, n_hidden]
        da1__dz1 = a_1 * (1. - a_1)  # sigmoid derivative
        # [n_examples, n_features]
        dz1__dw1 = x
        # output dim: [n_hidden, n_features]
        dloss_dw1 = torch.mm((dloss_a1 * da1__dz1).t(), dz1__dw1)
        dloss_db1 = torch.sum((dloss_a1 * da1__dz1), dim=0)

        return dloss__dw_out, dloss__db_out, dloss_dw1, dloss_db1
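For the "automatic version in PyTorch" I would compare against, I am picturing roughly the following (just a sketch; the class and attribute names are mine and nothing here is settled code):

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultilayerPerceptronAutograd(nn.Module):
    # same architecture as the manual model above: one hidden layer,
    # sigmoid activation on both the hidden and the output layer
    def __init__(self, num_features, num_hidden, num_classes):
        super(MultilayerPerceptronAutograd, self).__init__()
        self.linear_1 = nn.Linear(num_features, num_hidden)
        self.linear_out = nn.Linear(num_hidden, num_classes)

    def forward(self, x):
        z_1 = self.linear_1(x)
        a_1 = torch.sigmoid(z_1)
        logits = self.linear_out(a_1)
        probas = torch.sigmoid(logits)
        return logits, probas

The plan would then be to copy weight_1/bias_1 and weight_o/bias_o into the two Linear layers (the shapes match, since nn.Linear stores its weight as [out_features, in_features]), compute the one-hot MSE loss, call .backward(), and compare the resulting .grad tensors against what the manual backward() returns. One detail to keep in mind: F.mse_loss with the default reduction='mean' averages over all elements (examples and classes), whereas dloss_da2 above only divides by the number of examples, so the gradients would differ by a factor of num_classes unless I use reduction='sum' and divide by y.size(0) myself.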