Activation and loss function for multidimensional one-hot encoded output

I have a model whose output has the shape (B, C, T) before the softmax layer. Its target is a row-wise one-hot encoded tensor with the same shape as the model prediction, i.e. (B, C, T). The trouble is that PyTorch's softmax doesn't seem to work for row-wise one-hot encoded values. I wrote this sample code to show that the output after the softmax layer does not sum to one along each row.

import torch

# shape (Batch, Class, Time Step)
target = torch.tensor([[[0,1,0,0],
                        [1,0,0,0],
                        [0,0,0,1]],
                        
                        [[1,0,0,0],
                        [0,0,1,0],
                        [0,1,0,0]],
                         
                        [[0,0,1,0],
                        [0,0,0,1],
                        [1,0,0,0]]],dtype=torch.float)

input = torch.nn.Softmax(1)(target)

assert torch.all(target[0].sum(1).to(torch.float)==input[0].sum(1).to(torch.float)),"Failed: sum is not equal to one"

loss = torch.nn.BCELoss()(input,target)

What is a suitable activation and loss function for this kind of multidimensional one-hot encoded data?

nn.Softmax(dim=1) creates an output where output.sum(dim=1) returns ones, as seen in your code snippet:

input = torch.nn.Softmax(1)(target)
print(input.sum(1))
> tensor([[1.0000, 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000, 1.0000]])
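The assertion in the question compares target[0].sum(1) with input[0].sum(1). After indexing [0], sum(1) reduces over dim 2 of the full (3, 3, 4) tensor, which is where the one-hot rows live, while nn.Softmax(1) normalizes over dim 1. A minimal sketch, reusing the question's target tensor, that compares the two dimensions:

```python
import torch

# The question's target: shape (3, 3, 4), one-hot along the last dim (dim 2)
target = torch.tensor([[[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1]],
                       [[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0]],
                       [[0, 0, 1, 0], [0, 0, 0, 1], [1, 0, 0, 0]]], dtype=torch.float)

probs_dim1 = torch.softmax(target, dim=1)  # normalizes over dim 1 (size 3)
probs_dim2 = torch.softmax(target, dim=2)  # normalizes over dim 2 (size 4)

print(probs_dim1.sum(dim=1))  # all ones, shape (3, 4)
print(probs_dim1.sum(dim=2))  # not ones: this is what the assertion checks
print(probs_dim2.sum(dim=2))  # all ones, shape (3, 3)
```

Whichever dimension you normalize over is the one that sums to one; it has to be the dimension that holds the classes.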

However, nn.BCELoss is used for binary classification/segmentation or multi-label classification/segmentation, where zero, one, or multiple classes can be active per sample.
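For contrast, a minimal multi-label sketch with a made-up target, where each class is scored independently with a sigmoid rather than a softmax. It also shows nn.BCEWithLogitsLoss, which takes raw logits and fuses the sigmoid into the loss:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(2, 4)  # raw scores: (batch, classes)
# Made-up multi-label target: zero, one or several active classes per sample
target = torch.tensor([[1., 0., 1., 0.],
                       [0., 0., 0., 0.]])

# Sigmoid turns each logit into an independent per-class probability
loss_probs = nn.BCELoss()(torch.sigmoid(logits), target)
# Same loss computed directly from logits (numerically more stable)
loss_logits = nn.BCEWithLogitsLoss()(logits, target)
print(torch.allclose(loss_probs, loss_logits))  # True
```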

Based on your target it seems you are dealing with a multi-class classification, where only one class is active per sample.
In that case, you should use nn.CrossEntropyLoss, which expects raw logits as the model output and a target as a LongTensor containing the class indices in the range [0, nb_classes-1].
To create the target, use target = torch.argmax(target, dim=2).
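A short sketch of that conversion on the question's target: argmax over dim 2 picks the position of the single 1 in each row and returns a LongTensor of class indices:

```python
import torch

# The question's target: shape (3, 3, 4), one-hot along dim 2 (4 classes)
target = torch.tensor([[[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1]],
                       [[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0]],
                       [[0, 0, 1, 0], [0, 0, 0, 1], [1, 0, 0, 0]]], dtype=torch.float)

# Replace each one-hot row by the index of its 1
target_idx = torch.argmax(target, dim=2)
print(target_idx.shape)  # torch.Size([3, 3])
print(target_idx.dtype)  # torch.int64
print(target_idx)
# tensor([[1, 0, 3],
#         [0, 2, 1],
#         [2, 3, 0]])
```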


This doesn't work:

t = torch.argmax(target, dim=2)
torch.nn.CrossEntropyLoss()(input,target)

ValueError: Expected target size (3, 4), got torch.Size([3, 3, 4])

The output of torch.argmax is t, while you are still passing target to the criterion.
Could you pass t to it and rerun the code?

Sorry, that was my mistake. The actual code and error are:

input = torch.nn.Softmax(1)(target)
t = torch.argmax(target, dim=2)
torch.nn.CrossEntropyLoss()(input,t)

ValueError: Expected target size (3, 4), got torch.Size([3, 3])

nn.CrossEntropyLoss expects raw logits because it applies log_softmax internally, so you should not apply a softmax to the input tensor yourself.
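A minimal sketch, with random logits and random class indices as stand-ins for real data, showing that cross entropy is log_softmax over the class dimension (dim 1) followed by the negative log-likelihood loss. Applying a softmax beforehand would therefore normalize the scores twice:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(3, 4, 3)             # raw model output: (batch, classes, time steps)
target_idx = torch.randint(0, 4, (3, 3))  # class indices in [0, 3]: (batch, time steps)

ce = F.cross_entropy(logits, target_idx)
# Same loss written out: log_softmax over the class dim, then NLL
manual = F.nll_loss(F.log_softmax(logits, dim=1), target_idx)
print(torch.allclose(ce, manual))  # True
```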

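Since the one-hot encoding is in dim 2 of the target tensor and covers 4 classes, your input should have the shape [3, 4, 3]. nn.CrossEntropyLoss expects the class dimension at dim 1 of the input, i.e. [batch, classes, time steps], together with a target of class indices of shape [batch, time steps].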
This code snippet shows its usage:

import torch
import torch.nn as nn

target = torch.tensor([[[0,1,0,0],
                        [1,0,0,0],
                        [0,0,0,1]],
                        
                        [[1,0,0,0],
                        [0,0,1,0],
                        [0,1,0,0]],
                         
                        [[0,0,1,0],
                        [0,0,0,1],
                        [1,0,0,0]]],dtype=torch.float)

input = torch.randn(3, 4, 3)  # raw logits: [batch, classes, time steps]
criterion = nn.CrossEntropyLoss()
target = torch.argmax(target, dim=2)  # class indices: [batch, time steps]
loss = criterion(input, target)
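If the model produces its logits in the same layout as the one-hot target, with classes in the last dimension, the class dimension can be moved to dim 1 with permute. A minimal sketch, assuming a hypothetical model output logits_btc of shape (3, 3, 4). As an alternative, it also passes the one-hot tensor directly as class probabilities, which nn.CrossEntropyLoss accepts (in PyTorch 1.10 and later) when the float target has the same shape as the input:

```python
import torch
import torch.nn as nn

# The question's target: (batch, time steps, classes), one-hot along the last dim
target = torch.tensor([[[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1]],
                       [[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0]],
                       [[0, 0, 1, 0], [0, 0, 0, 1], [1, 0, 0, 0]]], dtype=torch.float)

criterion = nn.CrossEntropyLoss()

# Hypothetical raw model output laid out like the target: (batch, time steps, classes)
logits_btc = torch.randn(3, 3, 4)
# Move classes to dim 1: (batch, classes, time steps)
logits_bct = logits_btc.permute(0, 2, 1)

# Option 1: class-index target of shape (batch, time steps)
loss_idx = criterion(logits_bct, target.argmax(dim=2))

# Option 2 (PyTorch >= 1.10): probability target with the same shape as the input
loss_prob = criterion(logits_bct, target.permute(0, 2, 1))

# With exact one-hot probabilities and no class weights, both give the same loss
print(torch.allclose(loss_idx, loss_prob))  # True

# Predicted class per time step at inference: (batch, time steps)
pred = logits_bct.argmax(dim=1)
```

Option 1 is the form described above; option 2 mainly becomes useful when the targets are soft labels rather than strict one-hot vectors.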