MSE Loss with one-hot encoded outputs

I’m having some trouble getting an MLP classifier to train with MSE loss for some reason. Maybe it’s too late in the day and I am overlooking something, but I am wondering how autograd can be made compatible with model outputs that are a [num_examples, num_classes] matrix, i.e., where each column holds the probability that a given example belongs to a particular class.

This is no issue with cross_entropy(), since it works with the [num_examples, num_classes] outputs and the integer class labels directly.
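
E.g., something like this trains without any problem (a toy sketch with made-up shapes, just for illustration):

    import torch
    import torch.nn.functional as F

    # toy shapes for illustration: 4 examples, 3 classes
    logits = torch.randn(4, 3, requires_grad=True)  # [num_examples, num_classes]
    targets = torch.tensor([0, 2, 1, 2])            # integer class labels

    loss = F.cross_entropy(logits, targets)
    loss.backward()  # autograd has no problem here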

However, when I switch to MSE, I run into issues with the MSE loss (probably understandably so, because of the argmax), and I am wondering if there is some workaround for:

    logits, probas = model(features)
    cost = F.mse_loss(torch.argmax(probas, 1).float(), targets.float())

If you are wondering why MSE loss: it’s purely for teaching purposes. I want to double-check whether my manually implemented networks train correctly (i.e., get results equivalent to the autograd version).
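
To illustrate the kind of sanity check I have in mind, here is a toy version with just a single linear layer (not the full MLP below):

    import torch

    # compare a manually derived MSE gradient against what autograd computes
    x = torch.randn(5, 3)
    y = torch.randn(5, 2)
    w = torch.randn(3, 2, requires_grad=True)

    out = x.mm(w)
    loss = ((out - y) ** 2).mean()
    loss.backward()

    # manual gradient: dLoss/dW = 2/N * X^T (XW - Y), with N = number of elements
    manual_grad = 2. * x.t().mm(out.detach() - y) / y.numel()

    print(torch.allclose(w.grad, manual_grad))  # should print True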

For reference, I want to compare something like this (an MLP with only 1 hidden layer, sigmoid activations, and MSE loss) to an automatic version in PyTorch:

import torch


class MultilayerPerceptron():

    def __init__(self, num_features, num_hidden, num_classes):
        super(MultilayerPerceptron, self).__init__()
        
        self.num_classes = num_classes
        
        # hidden 1
        self.weight_1 = torch.zeros(num_hidden, num_features, 
                                    dtype=torch.float).normal_(0.0, 0.1)
        self.bias_1 = torch.zeros(num_hidden, dtype=torch.float)
        
        # output
        self.weight_o = torch.zeros(num_classes, num_hidden, 
                                    dtype=torch.float).normal_(0.0, 0.1)
        self.bias_o = torch.zeros(num_classes, dtype=torch.float)
        
    def forward(self, x):
        # hidden 1
        
        # input dim: [n_examples, n_features] dot [n_features, n_hidden]
        # output dim: [n_examples, n_hidden]
        z_1 = torch.mm(x, self.weight_1.t()) + self.bias_1
        a_1 = torch.sigmoid(z_1)

        # output layer
        # input dim: [n_examples, n_hidden] dot [n_hidden, n_classes]
        # output dim: [n_examples, n_classes]
        z_2 = torch.mm(a_1, self.weight_o.t()) + self.bias_o
        a_2 = torch.sigmoid(z_2)
        return a_1, a_2

    def backward(self, x, a_1, a_2, y):  
    
        #########################
        ### Output layer weights
        #########################
        
        # onehot encoding
        y_onehot = torch.FloatTensor(y.size(0), self.num_classes)
        y_onehot.zero_()
        y_onehot.scatter_(1, y.view(-1, 1).long(), 1)
        

        # Part 1: dLoss/dOutWeights
        ## = dLoss/dOutAct * dOutAct/dOutNet * dOutNet/dOutWeight
        ## where DeltaOut = dLoss/dOutAct * dOutAct/dOutNet
        ## for convenient re-use
        
        # input/output dim: [n_examples, n_classes]
        dloss_da2 = 2.*(a_2 - y_onehot) / y.size(0)

        # input/output dim: [n_examples, n_classes]
        da2_dz2 = a_2 * (1. - a_2) # sigmoid derivative

        # output dim: [n_examples, n_classes]
        delta_out = dloss_da2 * da2_dz2 # "delta (rule) placeholder"

        # gradient for output weights
        
        # [n_examples, n_hidden]
        dz2__dw_out = a_1
        
        # input dim: [n_classes, n_examples] dot [n_examples, n_hidden]
        # output dim: [n_classes, n_hidden]
        dloss__dw_out = torch.mm(delta_out.t(), dz2__dw_out)
        dloss__db_out = torch.sum(delta_out, dim=0)
        

        #################################        
        # Part 2: dLoss/dHiddenWeights
        ## = DeltaOut * dOutNet/dHiddenAct * dHiddenAct/dHiddenNet * dHiddenNet/dWeight
        
        # [n_classes, n_hidden]
        dz2__a1 = self.weight_o
        
        # output dim: [n_examples, n_hidden]
        dloss_a1 = torch.mm(delta_out, dz2__a1)
        
        # [n_examples, n_hidden]
        da1__dz1 = a_1 * (1. - a_1) # sigmoid derivative
        
        # [n_examples, n_features]
        dz1__dw1 = x
        
        # output dim: [n_hidden, n_features]
        dloss_dw1 = torch.mm((dloss_a1 * da1__dz1).t(), dz1__dw1)
        dloss_db1 = torch.sum((dloss_a1 * da1__dz1), dim=0)

        return dloss__dw_out, dloss__db_out, dloss_dw1, dloss_db1
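
The autograd-based counterpart I would compare it against could look roughly like this (a sketch; for an exact comparison I would also copy the manually initialized weights over, since nn.Linear uses a different default initialization):

    import torch
    import torch.nn as nn

    class MultilayerPerceptronAutograd(nn.Module):
        # same architecture as above, but autograd handles the gradients
        def __init__(self, num_features, num_hidden, num_classes):
            super().__init__()
            self.linear_1 = nn.Linear(num_features, num_hidden)
            self.linear_out = nn.Linear(num_hidden, num_classes)

        def forward(self, x):
            a_1 = torch.sigmoid(self.linear_1(x))
            a_2 = torch.sigmoid(self.linear_out(a_1))
            return a_1, a_2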

I think you should pass the predictions in the same [num_examples, num_classes] shape as the one-hot encoded targets (i.e., the probabilities, not the argmax).
If you pass the argmax as the prediction, it will be internally broadcast to something like this (line of code):

import torch
import torch.nn.functional as F

num_features, num_hidden, num_classes = 5, 20, 10
batch_size = 1  # a single example, to match the printed tensors below

features = torch.randn(batch_size, num_features, requires_grad=True)
targets = torch.randint(0, num_classes, (batch_size, ))
targets = F.one_hot(targets, num_classes=num_classes)

model = MultilayerPerceptron(num_features, num_hidden, num_classes)

# the manual model is not an nn.Module, so call forward() explicitly
a_1, probas = model.forward(features)

# Internally, this is roughly what F.mse_loss will do
ret = (torch.argmax(probas, 1).float() - targets.float()) ** 2

print(targets)
> tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
print(probas)
> tensor([[0.4905, 0.4872, 0.5295, 0.5871, 0.3829, 0.4792, 0.5026, 0.5062, 0.4536,
         0.5444]], grad_fn=<SigmoidBackward>)
print(torch.argmax(probas, 1).float())
> tensor([3.])
print(ret)
> tensor([[4., 9., 9., 9., 9., 9., 9., 9., 9., 9.]])  # !!!

# Without the argmax it should work, I think
ret = (probas - targets.float()) ** 2
print(ret)
> tensor([[0.2596, 0.2374, 0.2803, 0.3447, 0.1466, 0.2296, 0.2526, 0.2563, 0.2058,
         0.2964]], grad_fn=<PowBackward0>)
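
So the loss calculation itself would then be something like this (sketch, using the one-hot encoded targets from above):

    # compare the probabilities directly against the one-hot targets,
    # so the whole expression stays differentiable
    cost = F.mse_loss(probas, targets.float())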

Also, I’m not sure if you update the parameters manually, but if you would like to use an optimizer, note that I had to wrap all the internal parameters in nn.Parameter.
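
A minimal sketch of what I mean (just the first layer’s tensors and a toy loss):

    import torch
    import torch.nn as nn

    num_features, num_hidden = 5, 20

    # nn.Parameter sets requires_grad=True, so an optimizer can update the tensors
    weight_1 = nn.Parameter(torch.zeros(num_hidden, num_features).normal_(0.0, 0.1))
    bias_1 = nn.Parameter(torch.zeros(num_hidden))

    optimizer = torch.optim.SGD([weight_1, bias_1], lr=0.1)

    x = torch.randn(8, num_features)
    loss = torch.sigmoid(x.mm(weight_1.t()) + bias_1).sum()
    loss.backward()

    optimizer.step()       # updates weight_1 and bias_1 in-place
    optimizer.zero_grad()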

Let me know, if this helps or if I’m misunderstanding something.


Ah, thanks a lot! I was somehow conflating the gradient computations with the prediction accuracy computations :stuck_out_tongue:
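
I.e., the argmax belongs in the accuracy computation, not in the loss; something like this (assuming integer class labels):

    # no gradients need to flow through the accuracy, so argmax is fine here
    accuracy = (torch.argmax(probas, dim=1) == targets).float().mean()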


I’m glad it’s working!
Looking forward to the next lecture :wink:
