Making prediction with argmax

I was following this tutorial and was confused by the way we make the final prediction. Let’s say the model is defined and trained as follows:

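import torch.nn.functional as F
from torch import optim

# Mnist_Logistic, lr, bs, epochs, n, x_train, y_train, xb and yb
# are defined earlier in the tutorial.
loss_func = F.cross_entropy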

def get_model():
    model = Mnist_Logistic()
    return model, optim.SGD(model.parameters(), lr=lr)

model, opt = get_model()
print(loss_func(model(xb), yb))

for epoch in range(epochs):
    for i in range((n - 1) // bs + 1):
        start_i = i * bs
        end_i = start_i + bs
        xb = x_train[start_i:end_i]
        yb = y_train[start_i:end_i]
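        pred = model(xb)
        loss = loss_func(pred, yb)

        loss.backward()    # compute gradients of the loss
        opt.step()         # update the model parameters
        opt.zero_grad()    # clear gradients before the next batch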


print(loss_func(model(xb), yb))

Then the final prediction can be obtained by calling torch.argmax(model(test_data), dim=1). This means that y_pred = model(test_data) returns, for each sample, a vector with a probability value for each digit. My question is: how does PyTorch know that index 0 of y_pred corresponds to the probability of the digit being 0, index 1 to the probability of the digit being 1, and so on? Since our yb is a one-dimensional vector that looks like [5, 0, 4, ..., 8, 4, 8] rather than a one-hot matrix, how does PyTorch know to link each index of y_pred with the correct label?

Thank you!

Hi Hiep!

The short answer is that nn.functional.cross_entropy
one-hots your class labels for you.

The number-one rule is that the output of your network means
whatever you train it to mean.

More directly to your question:

You are using nn.functional.cross_entropy as your loss
function. Your yb are integer class labels, one per sample.

Conceptually, implicitly under the hood, cross_entropy
one-hots your yb, implicitly softmaxes your pred, and then
implicitly calculates the cross-entropy of “softmax (pred)” and
“one-hot (yb)”.
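
To make that concrete, here is a minimal sketch that spells out the
three implicit steps by hand and compares the result with the
built-in call. The tensors pred and yb are made-up stand-ins for one
batch (4 samples, 10 classes), not values from the tutorial.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Made-up stand-ins for one batch: 4 samples, 10 classes.
pred = torch.randn(4, 10)            # raw model outputs (logits)
yb = torch.tensor([5, 0, 4, 8])      # integer class labels, one per sample

# What cross_entropy computes in a single call.
builtin_loss = F.cross_entropy(pred, yb)

# The same computation written out step by step.
one_hot = F.one_hot(yb, num_classes=10).float()   # row n has a 1 in column yb[n]
log_probs = F.log_softmax(pred, dim=1)            # log(softmax(pred))
manual_loss = -(one_hot * log_probs).sum(dim=1).mean()

# The two values agree up to floating-point rounding.
print(builtin_loss.item(), manual_loss.item())
```

Note that the one-hot row for sample n has its 1 in column yb[n].
That is the only place where "index i means class i" enters the
computation.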

The output of your model (the y_pred) should be understood
as logits. They (implicitly) get turned into probabilities when
cross_entropy (implicitly) softmaxes them. So you are training
your model to output (for each sample) a vector of length nClass,
where the value for index i is the logit (an unnormalized
log-probability) of that sample being of class i. (Finally, you take
the argmax of your prediction vector. This finds the index of the
logit with the largest value. Because softmax preserves the ordering
of the logits, that is also the index that your model predicts as
having the highest probability of being the class label, and you
take this as being the predicted class label.)
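
As a small illustration of the argmax step, the sketch below checks
that taking the argmax of the logits gives the same class index as
taking the argmax of the softmax probabilities. The logits tensor is
random made-up data standing in for model(test_data).

```python
import torch

torch.manual_seed(0)

# Made-up stand-in for model(test_data): 3 samples, 10 classes.
logits = torch.randn(3, 10)

# Softmax turns each row of logits into probabilities that sum to 1.
probs = torch.softmax(logits, dim=1)

# Softmax is monotonic within a row, so the largest logit is also
# the largest probability.
pred_from_logits = torch.argmax(logits, dim=1)
pred_from_probs = torch.argmax(probs, dim=1)

print(pred_from_logits, pred_from_probs)
```

So for prediction you can skip the softmax and call argmax on the
raw model output directly.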

“PyTorch knows to link the index of y_pred with the correct label”
because you trained your network to do so.
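
Here is a hedged toy experiment along those lines. It uses synthetic
2-D clusters rather than MNIST, and the names train, centers and perm
are made up for this sketch. The same kind of model is trained twice:
once with the original labels and once with the labels relabeled
through a fixed permutation. Each model ends up predicting whichever
labeling it was trained on.

```python
import torch
import torch.nn.functional as F
from torch import nn, optim

torch.manual_seed(0)

# Synthetic data (not MNIST): three well-separated 2-D clusters,
# 100 points each. y holds the integer class label of each point.
centers = torch.tensor([[0.0, 4.0], [4.0, -4.0], [-4.0, -4.0]])
y = torch.arange(3).repeat_interleave(100)
x = centers[y] + 0.5 * torch.randn(300, 2)

def train(labels):
    """Fit a linear classifier to (x, labels) with full-batch SGD."""
    model = nn.Linear(2, 3)
    opt = optim.SGD(model.parameters(), lr=0.05)
    for _ in range(300):
        loss = F.cross_entropy(model(x), labels)
        loss.backward()
        opt.step()
        opt.zero_grad()
    return model

# Relabel the classes: 0 -> 2, 1 -> 0, 2 -> 1.
perm = torch.tensor([2, 0, 1])

model_a = train(y)          # trained on the original labels
model_b = train(perm[y])    # trained on the permuted labels

with torch.no_grad():
    pred_a = model_a(x).argmax(dim=1)
    pred_b = model_b(x).argmax(dim=1)

# Each model's output index matches the labeling it was trained with.
acc_a = (pred_a == y).float().mean().item()
acc_b = (pred_b == perm[y]).float().mean().item()
print(acc_a, acc_b)
```

Nothing in the model architecture ties output index 0 to a particular
cluster. The labels passed to cross_entropy during training decide
that.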

Good Luck.

K. Frank
