How to use one-hot vectors with PyTorch's loss functions

I am trying to encode the alphabet into a numerical format, and there appear to be two ways to go about this: using class labels (0 1 2 … 25), or using a one-hot vector with one column for each of the 26 letters. From what I’ve read online, it would make sense to use a one-hot vector because there is no ordinality in letters, e.g. A is not ‘close’ to B. However, I am struggling to use one-hot vectors with PyTorch because all of the loss functions seem to take only class labels. It is easy to convert a one-hot vector to class labels so that I can pass it to the loss functions, but would this hurt the performance of the network? This is the architecture problem I am facing.

It may be worth mentioning my situation. This is a many-to-one sequence network using RNN layers, and at the end the network goes through a linear layer with an output shape of 26.

This shouldn’t hurt the performance of the network. Whether the targets are stored as one-hot vectors or as class labels has nothing to do with ordinality unless the loss function itself explicitly uses the ordering of the labels in its computation, and the typical cross-entropy loss doesn’t do this: it only uses the label as an index that selects which output to score.
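To make this concrete, here is a minimal sketch comparing the cross-entropy loss computed from class labels with the same quantity written out by hand from one-hot targets. The tensor names, batch size, and random logits are made up for illustration and are not from the original network.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical network outputs: 8 samples, 26 classes (one score per letter).
logits = torch.randn(8, 26)
# Integer class labels in [0, 25].
class_labels = torch.randint(0, 26, (8,))
# The same labels as one-hot vectors (float so they can be multiplied with log-probs).
onehot_labels = F.one_hot(class_labels, num_classes=26).float()

# Usual PyTorch usage: cross-entropy with integer class labels.
loss_from_indices = F.cross_entropy(logits, class_labels)

# The same loss written with one-hot targets:
# -sum(onehot * log_softmax(logits)) per sample, averaged over the batch.
log_probs = F.log_softmax(logits, dim=1)
loss_from_onehot = -(onehot_labels * log_probs).sum(dim=1).mean()

print(torch.allclose(loss_from_indices, loss_from_onehot))
```

The two values agree because the one-hot vector just picks out the log-probability of the correct class, which is exactly what the integer label does by indexing. The numeric value of the label never enters the loss as a magnitude.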

If you want to convert one-hot to class labels, a simple way to do this is using the max function.

>>> import torch
>>> onehot_labels = torch.zeros(64, 26, dtype=torch.long)
>>> for i in range(64):
...     onehot_labels[i,i%26] = 1
...
>>> _, class_labels = torch.max(onehot_labels, dim=1)
>>> print(class_labels)
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
       18, 19, 20, 21, 22, 23, 24, 25,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9,
       10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,  0,  1,
        2,  3,  4,  5,  6,  7,  8,  9, 10, 11])
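As an alternative sketch, `argmax` returns the indices directly, and `torch.nn.functional.one_hot` goes in the other direction. The variable names here mirror the example above but are otherwise arbitrary.

```python
import torch
import torch.nn.functional as F

# Class labels 0..25, repeating, for a batch of 64.
class_labels = torch.arange(64) % 26

# Class labels -> one-hot vectors, shape (64, 26).
onehot_labels = F.one_hot(class_labels, num_classes=26)

# One-hot vectors -> class labels: the index of the single 1 in each row.
recovered = onehot_labels.argmax(dim=1)

print(torch.equal(recovered, class_labels))
```

The round trip loses nothing, so either format can be stored and converted right before the loss is computed.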

Note that it is not a problem for the network to produce a one-hot-style output (one score per letter) while the loss function takes a class label. This is just an implementation detail of how the labels are stored and passed to the loss function.
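For the many-to-one setup described in the question, a rough sketch might look like the following. The class name `ManyToOneRNN`, the hidden size, the sequence length, and the use of one-hot letters as inputs are all assumptions for illustration; the original poster's architecture is not shown in the thread.

```python
import torch
import torch.nn as nn

class ManyToOneRNN(nn.Module):
    """Hypothetical many-to-one model: a sequence in, 26 class scores out."""

    def __init__(self, input_size=26, hidden_size=32, num_classes=26):
        super().__init__()
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        out, _ = self.rnn(x)          # out: (batch, seq_len, hidden_size)
        return self.fc(out[:, -1])    # scores from the last time step: (batch, 26)

model = ManyToOneRNN()
criterion = nn.CrossEntropyLoss()

x = torch.randn(4, 10, 26)                # 4 sequences of length 10 (random stand-in data)
targets = torch.tensor([0, 5, 12, 25])    # class labels, not one-hot vectors

logits = model(x)                         # shape (4, 26)
loss = criterion(logits, targets)
loss.backward()
print(logits.shape, loss.item())
```

The linear layer outputs 26 raw scores per sample, and `nn.CrossEntropyLoss` applies the softmax internally, so the targets only need to say which of the 26 positions is correct.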


Thank you for the incredibly helpful answer. It makes me wonder why you would use one-hot encoding in the first place. Storing class labels is much more space efficient.

My guess is for simplicity; as you noticed, the output of the network itself is one-hot encoded, so it can simplify things if this is the format that is used for the labels as well. In practice, storing labels this way on disk also wouldn’t really matter as compression will handle the long strings of zeros in one-hot encoded labels efficiently.
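For a sense of the in-memory difference, here is a small sketch that counts the bytes used by each format for a batch of 64 labels. The helper `nbytes` is defined here for illustration only.

```python
import torch
import torch.nn.functional as F

class_labels = torch.arange(64) % 26                      # int64, shape (64,)
onehot_labels = F.one_hot(class_labels, num_classes=26)   # int64, shape (64, 26)

def nbytes(t):
    # Bytes per element times number of elements.
    return t.element_size() * t.nelement()

print(nbytes(class_labels))                     # 64 * 8 bytes
print(nbytes(onehot_labels))                    # 64 * 26 * 8 bytes
print(nbytes(onehot_labels.to(torch.uint8)))    # 64 * 26 * 1 bytes
```

Uncompressed, the one-hot tensor is 26 times larger at the same dtype, which is the space cost mentioned in the question. A narrower dtype or compression on disk reduces that gap.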