CrossEntropyLoss crashes on soft targets in various PyTorch versions

I am trying to train MobileNetV2 with soft targets (the only change is that I use 2 classes instead of 1000).
According to the PyTorch 1.13 documentation for CrossEntropyLoss and several similar posts, float (class-probability) targets have been accepted by CrossEntropyLoss since PyTorch 1.10. However, it still fails for me.

The last layer is:

nn.Linear(in_features=1280, out_features=2, bias=True)

However, it crashes if my labels are floats. Here is a minimal failing example using values returned by my network (batch_size = 4):

import torch
import torch.nn as nn

activations = torch.tensor([[-0.3139, -0.0486],
                            [-0.0510,  0.0470],
                            [ 0.0963,  0.0143],
                            [-0.2151, -0.0576]])

targets_float = torch.tensor([1.0, 1.0, 0.0, 0.5])
targets_long = torch.tensor([1, 1, 0, 0])
crit = nn.CrossEntropyLoss()
crit(activations, targets_long)   # returns tensor(0.6606) as expected
crit(activations, targets_float)  # crashes; the traceback ends in
#   return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
# RuntimeError: expected scalar type Long but found Float
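Why the 1-D float tensor is rejected can be sketched as follows, assuming PyTorch 1.10 or later (logits and float_1d_failed are illustrative names). CrossEntropyLoss treats a target whose shape differs from the input's [N, C] shape as class indices, which must be int64, so a float target of shape [N] fails regardless of its values. The exact error text can differ between versions, so the sketch only catches RuntimeError.

```python
import torch
import torch.nn as nn

logits = torch.tensor([[-0.3139, -0.0486],
                       [-0.0510,  0.0470],
                       [ 0.0963,  0.0143],
                       [-0.2151, -0.0576]])
crit = nn.CrossEntropyLoss()

# Shape [N] int64: interpreted as class indices, works.
loss_idx = crit(logits, torch.tensor([1, 1, 0, 0]))

# Shape [N] float: still interpreted as class indices, so the dtype is rejected.
try:
    crit(logits, torch.tensor([1.0, 1.0, 0.0, 0.5]))
    float_1d_failed = False
except RuntimeError as err:
    float_1d_failed = True
    print("1-D float target rejected:", err)

print("loss with class indices:", loss_idx.item())
```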

As described in the docs, "soft" targets are expected to have the same shape as the model output, i.e. [N, C]:

If containing class probabilities, same shape as the input and each value should be between [0,1].
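The quoted rule also determines what the loss computes: with no class weights, each sample contributes -Σ_c p_c · log softmax(x)_c, and the default reduction='mean' averages over the N samples. A sketch checking this by hand on the example logits, with a genuinely soft last row [0.5, 0.5] (an assumption mirroring the 0.5 in targets_float):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.tensor([[-0.3139, -0.0486],
                       [-0.0510,  0.0470],
                       [ 0.0963,  0.0143],
                       [-0.2151, -0.0576]])
# Same [N, C] shape as the logits; each row is a distribution over the C classes.
soft_targets = torch.tensor([[0.0, 1.0],
                             [0.0, 1.0],
                             [1.0, 0.0],
                             [0.5, 0.5]])

builtin = nn.CrossEntropyLoss()(logits, soft_targets)

# Same quantity by hand: -sum_c p_c * log_softmax(x)_c per sample, then mean over N.
manual = -(soft_targets * F.log_softmax(logits, dim=1)).sum(dim=1).mean()

print(builtin.item(), manual.item())
```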

With one-hot encoded float targets of that shape, the example works:

import torch
import torch.nn as nn
import torch.nn.functional as F

activations = torch.tensor([[-0.3139, -0.0486],
                            [-0.0510,  0.0470],
                            [ 0.0963,  0.0143],
                            [-0.2151, -0.0576]])

targets_long = torch.tensor([1, 1, 0, 0])
crit = nn.CrossEntropyLoss()
crit(activations, targets_long)  # returns tensor(0.6606) as expected
# one-hot rows of shape [4, 2], cast to float, are read as class probabilities
crit(activations, F.one_hot(targets_long, num_classes=2).float())
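F.one_hot returns an int64 tensor, which is why the answer calls .float() before the result is used as probabilities. A short sketch (variable names are illustrative) confirming that hard one-hot probability targets give the same loss as the equivalent class indices:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.tensor([[-0.3139, -0.0486],
                       [-0.0510,  0.0470],
                       [ 0.0963,  0.0143],
                       [-0.2151, -0.0576]])
targets_long = torch.tensor([1, 1, 0, 0])

one_hot = F.one_hot(targets_long, num_classes=2)
print(one_hot.dtype)  # torch.int64, so it must be cast before use as probabilities

crit = nn.CrossEntropyLoss()
loss_from_indices = crit(logits, targets_long)
loss_from_one_hot = crit(logits, one_hot.float())
print(loss_from_indices.item(), loss_from_one_hot.item())
```

Without class weights or label smoothing the two forms are mathematically identical, so any difference is only floating-point noise.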

Thanks a lot, it works. At first I thought you meant that the batch size must be 1 (it is 4 in the example), but with targets of shape [4, 2] the loss is computed fine, so for now I am doing something like this:

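# t is the probability of class 1 (1.0 corresponds to label 1 above),
# so column 0 must hold 1 - t and column 1 must hold t
targets_float = torch.tensor([[1 - t, t] for t in targets_float.tolist()])
since F.one_hot fails on float tensors (it only accepts integer class indices).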
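The same conversion can be done without a Python loop. A minimal sketch, assuming (as in the first example, where 1.0 pairs with label 1) that each value in targets_float is the probability of class 1; p_class1 and soft_targets are illustrative names:

```python
import torch
import torch.nn as nn

logits = torch.tensor([[-0.3139, -0.0486],
                       [-0.0510,  0.0470],
                       [ 0.0963,  0.0143],
                       [-0.2151, -0.0576]])
# Assumed: per-sample probability of class 1, as in targets_float.
p_class1 = torch.tensor([1.0, 1.0, 0.0, 0.5])

# Column c holds P(class c): column 0 = 1 - p, column 1 = p. Result has shape [4, 2].
soft_targets = torch.stack([1.0 - p_class1, p_class1], dim=1)

loss = nn.CrossEntropyLoss()(logits, soft_targets)
print(soft_targets)
print(loss.item())
```

Stacking along dim=1 keeps the column order explicit, which makes it harder to swap P(class 0) and P(class 1) by accident.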