Making predictions with argmax

I was following this tutorial and was confused by the way we make the final prediction. Let’s say the model is defined as

``````
loss_func = F.cross_entropy

def get_model():
    model = Mnist_Logistic()
    return model, optim.SGD(model.parameters(), lr=lr)

model, opt = get_model()
print(loss_func(model(xb), yb))

for epoch in range(epochs):
    for i in range((n - 1) // bs + 1):
        start_i = i * bs
        end_i = start_i + bs
        xb = x_train[start_i:end_i]
        yb = y_train[start_i:end_i]
        pred = model(xb)
        loss = loss_func(pred, yb)

        loss.backward()
        opt.step()
        opt.zero_grad()  # clear gradients so they don't accumulate across batches

print(loss_func(model(xb), yb))
``````

Then, the final prediction can be achieved by calling `torch.argmax(model(test_data), dim=1)`. This means that `y_pred = model(test_data)` will return a vector that has a probability value for each digit. Then, my question is, how can PyTorch know that index `0` of `y_pred` will correspond to the probability of the digit being `0`, index `1` will correspond to the probability of the digit being `1`, and so on? Since our `yb` is a one-dimensional vector that looks like `[5, 0, 4, ..., 8, 4, 8]` instead of a one-hot matrix, how can PyTorch know to link the index of `y_pred` with the correct label?

Thank you!

Hi Hiep!

The short answer is that `nn.functional.cross_entropy`
one-hots your class labels for you.

The number-one rule is that the output of your network means
whatever you train it to mean.

You are using `nn.functional.cross_entropy` as your loss
function. Your `yb` are integer class labels, one per sample.

Conceptually, under the hood, `cross_entropy` implicitly
one-hots your `yb`, implicitly softmaxes your `pred`, and then
calculates the cross-entropy of `softmax(pred)` and
`one-hot(yb)`.
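You can check this equivalence directly. Here is a minimal sketch with a made-up batch of random logits and integer labels (the shapes and values are illustrative, not from the tutorial):

``````python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical batch: 4 samples, 10 classes (e.g., MNIST digits).
pred = torch.randn(4, 10)           # raw logits from the model
yb = torch.tensor([5, 0, 4, 8])     # integer class labels, one per sample

# What cross_entropy does implicitly, spelled out:
one_hot = F.one_hot(yb, num_classes=10).float()
log_probs = F.log_softmax(pred, dim=1)
manual = -(one_hot * log_probs).sum(dim=1).mean()

builtin = F.cross_entropy(pred, yb)
print(torch.allclose(manual, builtin))  # True
``````

(`F.cross_entropy` works with `log_softmax` rather than `softmax` for numerical stability, but mathematically it is the same cross-entropy.)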

The output of your model (the `y_pred`) should be understood
as logits. They (implicitly) get turned into probabilities when
`cross_entropy` (implicitly) softmaxes them. So you are training
your model to output (for each sample) a vector of length nClass,
where the value for index i is the logit (sort of like the probability)
of that sample being of class i. (Finally, you take the `argmax` of
your prediction vector. This finds the index of the logit with the
largest value – that is the index that your model predicts as having
the highest probability of being the class label, and you take this
as being the predicted class label.)
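For concreteness, here is the `argmax` step on some hand-written logits (the values are made up; any tensor of shape `[nBatch, nClass]` works the same way):

``````python
import torch

# Hypothetical logits for 3 test samples over 10 digit classes.
logits = torch.tensor([
    [0.1, 2.5, -1.0, 0.0, 0.3, 0.2, 0.0, 0.1, -0.5, 0.4],  # largest at index 1
    [3.0, 0.0, 0.1, 0.2, 0.0, 0.1, 0.0, 0.3, 0.2, 0.1],    # largest at index 0
    [0.0, 0.1, 0.2, 0.1, 0.0, 0.1, 0.0, 0.3, 0.2, 4.2],    # largest at index 9
])

# argmax over the class dimension returns the predicted labels directly --
# no softmax needed, since softmax is monotonic and preserves the argmax.
y_pred = torch.argmax(logits, dim=1)
print(y_pred)  # tensor([1, 0, 9])
``````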

“PyTorch knows to link the index of `y_pred` with the correct label”
because you trained your network to do so.

Good Luck.

K. Frank
