How does one get the predicted classification label from a PyTorch model?

The main thing is that you have to reduce (collapse) the dimension that holds the raw classification values (logits) with a max, and then select the position of the maximum with `.indices`. Usually this is dimension 1, since dimension 0 is the batch size: the logits have shape [batch_size, D_classification], even when the raw data has shape [batch_size, C, H, W].
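The reduction works the same way whatever the input shape is, because it is applied to the model's output, not to the raw data. Below is a minimal sketch assuming image-like input of shape [batch_size, C, H, W]; the sizes and the flatten-plus-linear model are hypothetical stand-ins for a real classifier. `logits.argmax(dim=1)` is an equivalent shortcut for `logits.max(dim=1).indices`.

```python
import torch
import torch.nn as nn

# Hypothetical sizes for illustration: a batch of 8 RGB 32x32 images, 10 classes.
batch_size, C, H, W, num_classes = 8, 3, 32, 32, 10
x = torch.randn(batch_size, C, H, W)  # raw data: [batch_size, C, H, W]

# Stand-in classifier (not a real architecture): flatten each image, then one linear layer.
mdl = nn.Sequential(nn.Flatten(), nn.Linear(C * H * W, num_classes))

logits = mdl(x)                   # [batch_size, num_classes]: dim 0 is the batch, dim 1 the classes
pred = logits.max(dim=1).indices  # collapse dim 1, keep the position of the largest logit
print(logits.shape)  # torch.Size([8, 10])
print(pred.shape)    # torch.Size([8])

# argmax returns the same indices directly
assert torch.equal(pred, logits.argmax(dim=1))
```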

Here is a synthetic example with 1D raw data:

```python
import torch
import torch.nn as nn

# data dimension [batch_size, D]
D, Dout = 1, 5
batch_size = 16
x = torch.randn(batch_size, D)
y = torch.randint(low=0, high=Dout, size=(batch_size,))

mdl = nn.Linear(D, Dout)
logits = mdl(x)  # [batch_size, Dout]
print(f'y.size() = {y.size()}')
# Reduce dimension 1 (the class dimension) with a max, which gives the most likely label
# for each example. Note that you need .indices, since you want the position of the most
# likely label, not its raw logit value (that is .values).
pred = logits.max(1).indices
print(pred)

print('--- preds vs truth ---')
print(f'predictions = {pred}')
print(f'y = {y}')

# fraction of examples whose predicted label matches the true label
acc = (pred == y).sum().item() / pred.size(0)
print(acc)
```
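A few related details, sketched with the same kind of untrained `nn.Linear` (the seed, the `eval()`/`torch.no_grad()` calls and the choice of k=2 are additions for illustration, not part of the example above): `max` over a dimension returns a `(values, indices)` pair, applying `softmax` first does not change the predicted label because it preserves the ordering within each row, and `topk` gives the k most likely labels.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # fixed seed only so reruns are repeatable
D, Dout, batch_size = 1, 5, 16
x = torch.randn(batch_size, D)
mdl = nn.Linear(D, Dout)

mdl.eval()             # inference behaviour for dropout/batchnorm (a no-op for a plain Linear)
with torch.no_grad():  # no autograd graph is needed just to read off predictions
    logits = mdl(x)

values, indices = logits.max(dim=1)   # max over dim 1 returns a (values, indices) pair
probs = torch.softmax(logits, dim=1)  # per-row class probabilities; each row sums to 1

# softmax is order-preserving within a row, so the predicted label is unchanged
assert torch.equal(indices, probs.argmax(dim=1))

top2 = logits.topk(k=2, dim=1).indices  # two most likely labels per example, best first
assert torch.equal(top2[:, 0], indices)
print(indices)
print(top2)
```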

Output (from one run; nothing is seeded and the model is untrained, so the exact values differ between runs):

```
y.size() = torch.Size([16])
tensor([3, 1, 1, 3, 4, 1, 4, 3, 1, 1, 4, 4, 4, 4, 3, 1])
--- preds vs truth ---
predictions = tensor([3, 1, 1, 3, 4, 1, 4, 3, 1, 1, 4, 4, 4, 4, 3, 1])
y = tensor([3, 3, 3, 0, 3, 4, 0, 1, 1, 2, 1, 4, 4, 2, 0, 0])
0.25
```
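As a sanity check on the printed accuracy, the two tensors from the output can be compared directly: 4 of the 16 positions match, which gives 0.25. With 5 classes and an untrained model, that is close to chance level (1/5 = 0.2). The `.float().mean()` form at the end is just an equivalent way to write the same accuracy.

```python
import torch

# predictions and labels copied from the printed output
pred = torch.tensor([3, 1, 1, 3, 4, 1, 4, 3, 1, 1, 4, 4, 4, 4, 3, 1])
y = torch.tensor([3, 3, 3, 0, 3, 4, 0, 1, 1, 2, 1, 4, 4, 2, 0, 0])

correct = (pred == y).sum().item()  # elementwise comparison gives a bool tensor; sum counts the Trues
acc = correct / pred.size(0)
print(correct, acc)  # 4 0.25

# equivalent: cast the bool tensor to float and average it
assert (pred == y).float().mean().item() == acc
```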