How does one get the predicted classification label from a PyTorch model?

I am aware the code is (see "How to predict only one test sample in pytorch model?"):

_, pred = torch.max(outputs, 1)

as seen in the CIFAR-10 beginner tutorial, but I'd like to understand why it is correct.

I was reading the docs for torch.max and it's still not entirely clear to me. This is the sentence that doesn't make sense to me:

Returns a namedtuple (values, indices) where values is the maximum value of each row of the input tensor in the given dimension dim.

I don't understand what the max value of a row means given a specific slice of a tensor. Rows only exist in 2D tensors (matrices), at least to me. Can you clarify?

Note that I do understand what dim is supposed to do.

Reference/motivation (I want to understand this clearly): Calculating accuracy of the current minibatch?


PyTorch modules such as Conv or Linear are designed to work on batched data, so if you have a single image you still create a batch of size 1. So, no matter what type of data you are working with, you have at least a 2D tensor as input, and in the simplest case your model outputs a probability/logit for each class for each sample in the batch. With a single instance in the batch, the output has shape [1, number_of_classes]. Taking the max over dim 1 then gives you the max value and the index of the max for each row, and each row corresponds to one sample in the batch.
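A minimal sketch of the single-sample case, assuming a hypothetical toy classifier (nn.Flatten followed by nn.Linear) on a 3x32x32 image with 10 classes. The point is only the shapes: unsqueeze(0) adds the batch dimension, the output is [1, num_classes], and torch.max(..., 1) returns one predicted label per row.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

num_classes = 10
# hypothetical tiny classifier for illustration: [N, 3, 32, 32] -> [N, 10] logits
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, num_classes))

img = torch.randn(3, 32, 32)   # a single image, no batch dimension yet
batch = img.unsqueeze(0)       # add a batch dimension of size 1 -> [1, 3, 32, 32]

logits = model(batch)          # [1, num_classes]: one row per sample
values, indices = torch.max(logits, 1)  # reduce over dim 1, the class dimension

print(logits.shape)    # torch.Size([1, 10])
print(indices.shape)   # torch.Size([1]): one predicted label per sample
label = indices.item() # plain Python int for the single sample
```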



The main thing is that you have to reduce/collapse the dimension that holds the raw classification values/logits with a max, and then take the position of the maximum with .indices. Usually this is dimension 1, since dim 0 holds the batch, e.g. the output is [batch_size, D_classification] while the raw data might be of size [batch_size, C, H, W].
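A sketch of the shapes described above, assuming a hypothetical model that maps [batch_size, C, H, W] to [batch_size, num_classes]. It shows that max(dim=1).indices and argmax(dim=1) give the same labels, and that reducing dim 0 instead would collapse the batch rather than the classes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch_size, C, H, W = 8, 3, 32, 32
num_classes = 10

# hypothetical model: any network mapping [batch_size, C, H, W] -> [batch_size, num_classes]
model = nn.Sequential(nn.Flatten(), nn.Linear(C * H * W, num_classes))

x = torch.randn(batch_size, C, H, W)
logits = model(x)                    # [batch_size, num_classes]

pred = logits.max(dim=1).indices     # collapse dim 1 (classes), keep dim 0 (batch)
pred_argmax = logits.argmax(dim=1)   # same labels when only the indices are needed

# reducing dim 0 is a different question: the best-scoring sample for each class
wrong = logits.max(dim=0).indices    # shape [num_classes], not one label per sample

print(logits.shape)  # torch.Size([8, 10])
print(pred.shape)    # torch.Size([8])
```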

A synthetic example with raw data in 1D follows:

import torch
import torch.nn as nn

# data dimension [batch-size, D]
D, Dout = 1, 5
batch_size = 16
x = torch.randn(batch_size, D)
y = torch.randint(low=0,high=Dout,size=(batch_size,))

mdl = nn.Linear(D, Dout)
logits = mdl(x)
print(f'y.size() = {y.size()}')
# remove the 1th dimension (the classification/logit dimension) with a max,
# which gives the most likely label for each sample. Note that you need .indices, since you want
# the position of the most likely label (not its raw logit value)
pred = logits.max(1).indices

print('--- preds vs truth ---')
print(f'predictions = {pred}')
print(f'y = {y}')

acc = (pred == y).sum().item() / pred.size(0)


y.size() = torch.Size([16])
tensor([3, 1, 1, 3, 4, 1, 4, 3, 1, 1, 4, 4, 4, 4, 3, 1])
--- preds vs truth ---
predictions = tensor([3, 1, 1, 3, 4, 1, 4, 3, 1, 1, 4, 4, 4, 4, 3, 1])
y = tensor([3, 3, 3, 0, 3, 4, 0, 1, 1, 2, 1, 4, 4, 2, 0, 0])
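To see why this is correct, here is a small check on made-up logits: torch.max(logits, 1) returns a namedtuple whose values[i] is the largest entry of row i and whose indices[i] is the column (class) where it sits. Since softmax is monotonic within a row, taking the max of the raw logits picks the same class as taking the max of the probabilities.

```python
import torch

torch.manual_seed(0)

logits = torch.randn(4, 5)   # [batch_size=4, num_classes=5], random values
out = torch.max(logits, 1)   # namedtuple (values, indices)

# values[i] is the largest logit in row i, indices[i] is its column (the class)
for i in range(logits.size(0)):
    assert out.values[i] == logits[i].max()
    assert out.indices[i] == logits[i].argmax()

# softmax preserves the ordering within each row, so the predicted
# class is the same whether you use raw logits or probabilities
probs = torch.softmax(logits, dim=1)
same = torch.equal(probs.argmax(dim=1), out.indices)
print(same)  # True
```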

To understand reductions, see the following:

# dimension
Dimension reduction. A reduction collapses a specific dimension by combining the elements along that dimension into a
single value with some op (e.g. sum, or max, which selects one of them).
Consider x to be a 3D tensor of size [D0, D1, D2]. x.sum(1) converts x into a 2D tensor of size [D0, D2] by combining
the D1 elements along the 1th dimension. Thus:
x.sum(1)[i,k] = op(x[i,:,k]) = op(x[i,0,k], ..., x[i,D1-1,k])   with op = sum
The key is to realize that we need 3 indices to select a single element of x. If we use only 2 (because we are
collapsing one dimension), then those two indices point to a set of D1 possible elements, so we need to specify how
to get one value out of that set. That is what the op is for: it reduces this set to a single value.
In general, to collapse several dimensions, the op has to reduce the whole set of elements indexed by the dimensions
we dropped, so that the remaining, smaller set of indices is enough to address the result.
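The formula above can be checked directly with a small sketch on an arbitrary 3D tensor (torch.arange values, chosen only for illustration): for each remaining index pair (i, k), apply the op to the slice x[i, :, k] and compare with the built-in reduction.

```python
import torch

D0, D1, D2 = 2, 3, 4
x = torch.arange(D0 * D1 * D2).reshape(D0, D1, D2)

# built-in reductions collapse dim 1 -> shape [D0, D2]
reduced_sum = x.sum(1)
reduced_max = x.max(1).values

# same thing written out: for each remaining (i, k), apply op to x[i, 0, k], ..., x[i, D1-1, k]
manual_sum = torch.empty(D0, D2, dtype=x.dtype)
manual_max = torch.empty(D0, D2, dtype=x.dtype)
for i in range(D0):
    for k in range(D2):
        manual_sum[i, k] = x[i, :, k].sum()   # op = sum combines the D1 elements
        manual_max[i, k] = x[i, :, k].max()   # op = max selects one of the D1 elements

print(reduced_sum.shape)                     # torch.Size([2, 4])
print(torch.equal(reduced_sum, manual_sum))  # True
print(torch.equal(reduced_max, manual_max))  # True
```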

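import torch

x = torch.tensor([
     [1, 2, 3],
     [4, 5, 6]
])
print(f'x.size() = {x.size()}')  # x.size() = torch.Size([2, 3])

# sum over the 0th dimension (collapse the rows): the rows are added together,
# leaving one value per column
x0 = x.sum(0)    # tensor([5, 7, 9])

# sum over the 1th dimension (collapse the columns): one value per row
x1 = x.sum(1)    # tensor([ 6, 15])

# dim=-1 is the last dimension, which for a 2D tensor is the same as dim=1
x_1 = x.sum(-1)  # tensor([ 6, 15])

# max over the 0th dimension: for each column, the largest value and the row it came from
x0_max = x.max(0)  # values=tensor([4, 5, 6]), indices=tensor([1, 1, 1])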

y = torch.tensor([[
         [ 1,  2,  3,  4],
         [ 5,  6,  7,  8],
         [ 9, 10, 11, 12]],

        [[13, 14, 15, 16],
         [17, 18, 19, 20],
         [21, 22, 23, 24]]])


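# dim 0, "into the screen": y[:, 0, 0] = [1, 13]
# dim 1, down a column:      y[0, :, 0] = [1, 5, 9]
# dim 2, along a row:        y[0, 0, :] = [1, 2, 3, 4]

# for each remaining index (j, k), select the largest value in the "into the screen" dimension
y0 = y.max(0)  # values = the second 3x4 matrix y[1], indices are all 1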
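This also answers the question about "rows" in more than 2D. One way to read "each row" in the torch.max docs is as each 1D slice obtained by fixing every index except dim. The sketch below rebuilds the same y with torch.arange and reduces it along each of its three dimensions.

```python
import torch

y = torch.arange(1, 25).reshape(2, 3, 4)   # same values as y above

# max over dim 2: the maximum of each row y[i, j, :] of each 2D slice y[i]
row_max = y.max(2)
print(row_max.values.tolist())   # [[4, 8, 12], [16, 20, 24]]
print(row_max.indices.tolist())  # [[3, 3, 3], [3, 3, 3]]

# max over dim 1: the maximum down each column y[i, :, k]
col_max = y.max(1)
print(col_max.values.tolist())   # [[9, 10, 11, 12], [21, 22, 23, 24]]

# max over dim 0: the maximum "into the screen" y[:, j, k]
depth_max = y.max(0)
print(depth_max.values.shape)    # torch.Size([3, 4])
```

In each case the reduced dimension disappears from the shape and the remaining indices address one result per slice.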