Cross Entropy Loss error on image segmentation

I am a new PyTorch user. I am adapting the Unet segmentation model, but I get an error when evaluating the Cross Entropy Loss function during training.

I used torch.utils.data.Dataset to build a custom dataset:

```python
train_data = DataLoaderSegmentation(train_path, mode='train')
train_loader = DataLoader(train_data, batch_size=4, shuffle=True, num_workers=4)
```

This works fine and loads both the image and its mask in the format

[batch, channel, H, W], e.g. [4, 1, 220, 220], for both images and masks. I set channel to 1 because this is background/foreground segmentation.

Training parameters:

```python
model = Unet(n_filters=32, n_class=1, input_channels=1)
device = 'cuda'
LR = 0.001
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
model = model.to(device)
```

During training I get an error on the marked line:

```python
for data, target in train_loader:
    data = data.to(device)
    target = target.to(device)
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)  # <-- ERROR
```



```
File "C:\Anaconda3\envs\envTorch\lib\site-packages\torch\nn\functional.py", line 1840, in nll_loss
    ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)

RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'target' in call to _thnn_nll_loss2d_forward
```


I see that target and output have the same shape, [4,1,220,220], but criterion(output, target) throws an error.

Sorry, but after reading some answers, I still cannot resolve it.

Try output.float() and target.float() and see if it works.

@Megh_Bhalerao, it gives the same error:

```
RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'target' in call to _thnn_nll_loss2d_forward
```

Could the error be coming from the model?


```
loss = criterion(output.float(), target.float())
Traceback (most recent call last):

  File "", line 1, in
    loss = criterion(output.float(), target.float())

  File "C:\Anaconda3\envs\envTorch\lib\site-packages\torch\nn\modules\module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)

  File "C:\Anaconda3\envs\envTorch\lib\site-packages\torch\nn\modules\loss.py", line 916, in forward
    ignore_index=self.ignore_index, reduction=self.reduction)

  File "C:\Anaconda3\envs\envTorch\lib\site-packages\torch\nn\functional.py", line 2009, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)

  File "C:\Anaconda3\envs\envTorch\lib\site-packages\torch\nn\functional.py", line 1840, in nll_loss
    ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)

RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'target' in call to _thnn_nll_loss2d_forward
```


Regarding the log_softmax in the traceback: the last activation function in my model is nn.Sigmoid(), not nn.Softmax(). Could that be the problem?

Sorry, I meant try criterion(output.double(), target.double()).
Let me know if this works.

Sorry, it doesn't work.

I only changed the loss function to

```python
criterion = torch.nn.BCEWithLogitsLoss()
```

and the error disappeared. However, the accuracy shown during training is always low.

If you are dealing with a binary classification use case, a single output channel and nn.BCEWithLogitsLoss should work.
Make sure to pass raw logits to the criterion (no sigmoid at the end).
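A minimal sketch of that binary setup, using random tensors as stand-ins for the real Unet output and the masks (the shapes are taken from the original post; nothing else is). Since the model ends without a sigmoid, predictions can be obtained by thresholding the logits at 0, which is equivalent to sigmoid(logits) > 0.5:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for the model output: raw logits, one channel, [batch, 1, H, W].
# The model itself should NOT end with nn.Sigmoid() here.
logits = torch.randn(4, 1, 220, 220)

# Binary foreground/background mask with the same shape, stored as float 0.0 / 1.0.
mask = (torch.rand(4, 1, 220, 220) > 0.5).float()

criterion = nn.BCEWithLogitsLoss()  # applies the sigmoid internally, in a numerically stable way
loss = criterion(logits, mask)

# Predicted mask: logit > 0 is the same decision as sigmoid(logit) > 0.5.
pred = (logits > 0).float()
pixel_acc = (pred == mask).float().mean()
print(f"loss={loss.item():.4f}, pixel accuracy={pixel_acc.item():.3f}")
```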

nn.CrossEntropyLoss is usually applied for multi class classification/segmentation use cases, where you are dealing with more than two classes.
In this case, your target should be a LongTensor, should not have the channel dimension, and should contain the class indices in [0, nb_classes-1].
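A small sketch of the shapes and dtypes described above, with random data standing in for real model outputs and masks. It also shows one way to turn a float mask loaded as [batch, 1, H, W] into a valid target (drop the channel dimension, cast to long). As a side note, PyTorch 1.10 and later also accept float targets with the same shape as the input, interpreted as per-class probabilities; the class-index form below is the one this thread is about.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, n_classes, H, W = 4, 5, 220, 220

# Model output: raw logits of shape [batch, n_classes, H, W], dtype float.
output = torch.randn(batch, n_classes, H, W)

# Target: class indices in [0, n_classes - 1], shape [batch, H, W]
# (no channel dimension), dtype long.
target = torch.randint(0, n_classes, (batch, H, W))

criterion = nn.CrossEntropyLoss()
loss = criterion(output, target)

# If the DataLoader yields the mask as float [batch, 1, H, W] holding class
# indices, squeeze out the channel dimension and cast to long before the loss.
float_mask = torch.randint(0, n_classes, (batch, 1, H, W)).float()
target_from_mask = float_mask.squeeze(1).long()
loss2 = criterion(output, target_from_mask)
print(loss.item(), loss2.item())
```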

Since your accuracy is low, could you try to overfit a small data sample and see if your model and training routine can successfully overfit it?
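One way to run such an overfitting check is sketched below. The tiny conv net is a hypothetical stand-in for the Unet, and the toy labels are derived from pixel intensity (an assumption, so that there is something learnable); in practice you would use the real model and a handful of real image/mask pairs. If the loss on one fixed batch does not drop clearly, the problem is likely in the model, the loss setup, or the training loop rather than in the amount of data.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_classes = 5

# Hypothetical tiny stand-in for the Unet: [N, 1, H, W] -> [N, n_classes, H, W] logits.
model = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, n_classes, kernel_size=1),  # raw logits, no final activation
)

# One fixed small batch. Toy labels: each pixel's class is its intensity bucket.
data = torch.randn(2, 1, 32, 32)
boundaries = torch.tensor([-0.8, -0.25, 0.25, 0.8])
target = torch.bucketize(data[:, 0], boundaries)  # long, shape [2, 32, 32], values 0..4

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

losses = []
for step in range(200):
    optimizer.zero_grad()
    loss = criterion(model(data), target)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```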


Thank you for your comments. I still can't make it work. (I haven't tested nn.BCEWithLogitsLoss further; I focused on N=5 multi-class segmentation.) Images are [1,220,220] and their masks are [5,220,220]. Each channel is a binary image with values 0 and 1: 1s for the object of interest in the respective channel and 0s for the background. I don't know if this mask format is correct.

The error occurs on the line

loss = criterion(output, target)

during the first pass. The error seems related to the data type, but if I am not mistaken, tensors internally work with float numbers (all images and masks are floats).

Here is my code:

```python
model = UnetModel(n_filters=32, n_class=5, input_channels=1)
device = 'cuda'
model = model.to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

model.train()

epochs = 20
itr = 1
p_itr = 50          # print every p_itr iterations
total_loss = 0
train_loss = []
train_acc = []

for epoch in range(epochs):
    for data, target in train_loader:
        data = data.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)  # <-- Error
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        if itr % p_itr == 0:
            pred = torch.argmax(output, dim=1)
            correct = pred.eq(target)
            acc = torch.mean(correct.float())
            print('[Epoch {}/{}] Iteration {} -> Train Loss: {:.4f}, Accuracy: {:.3f}'.format(
                epoch + 1, epochs, itr, total_loss / p_itr, acc))
            train_loss.append(total_loss / p_itr)
            train_acc.append(acc)
            total_loss = 0
        itr += 1
```

The error is:

```
File "C:\Anaconda3\envs\envTorch\lib\site-packages\torch\nn\functional.py", line 1840, in nll_loss
    ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)

RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'target' in call to _thnn_nll_loss2d_forward
```

Hi Joe!

The error (“scalar type Long”) and the documentation for
your loss function, CrossEntropyLoss, should serve to sort
things out.

As @ptrblck mentioned in one of his posts, above,
CrossEntropyLoss takes a LongTensor for its targets,
not a FloatTensor, hence the error.

CrossEntropyLoss (perhaps better called more explicitly
“categorical cross-entropy loss”) requires integer class labels
in [0, nClass - 1] for its targets. Provided the target in
your code is, indeed, class labels (as compared to probabilities,
or some such), it should suffice to convert it to a LongTensor,
e.g., loss = criterion (output, target.long()).

BCEWithLogitsLoss does take a FloatTensor for its targets.
(It takes probabilities, not just 0,1 class labels.) So that's why
you didn't get the error when you tried BCEWithLogitsLoss.
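To illustrate that BCEWithLogitsLoss accepts float probability targets (not only 0/1), here is a small sketch with made-up numbers that compares it against the binary cross-entropy formula written out by hand:

```python
import torch
import torch.nn as nn

logits = torch.tensor([[2.0, -1.0, 0.0]])
soft_target = torch.tensor([[0.9, 0.2, 0.5]])  # float probabilities, not only 0/1

criterion = nn.BCEWithLogitsLoss()  # default reduction='mean'
loss = criterion(logits, soft_target)

# Same quantity by hand: mean over elements of
# -[y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]
p = torch.sigmoid(logits)
manual = -(soft_target * torch.log(p) + (1 - soft_target) * torch.log(1 - p)).mean()
print(loss.item(), manual.item())
```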

Good luck.

K. Frank

I am getting a new error.

Dear @KFrank,

1. I remade the target as [batch, nClasses, H, W], with nClasses = 5. I split the object labels across layers as follows (but I don't know if it is correct):

Layer 0: all pixels set to 0, label 0 = background
Layer 1: pixels set to 1 for object 1 (e.g. dog), and 0s for the background
Layer 2: pixels set to 2 for object 2 (e.g. cat), and 0s for the background
and so on.

Is this a reasonable way to lay out the layers for semantic segmentation?

2. However, I get a new error:

```
RuntimeError: 1only batches of spatial targets supported (non-empty 3D tensors) but got targets of size: : [4, 5, 224, 224]
```

I use batchsize=5, nClasses=1. The input data is [4,1,224,224], and I am aiming to obtain an output of [4,5,224,224].

Hello Joe!

No, this is not right. Your target ("ground truth") should be a
LongTensor of shape (batch, H, W). For any given sample
within the batch, the values of the "pixels" in your "H x W" "image"
should be the integer class labels in [0, 5).

You don’t use multiple “Layers” to specify which class a given
pixel is in – you just use an integer class label as the value of
that pixel.
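If the masks already exist in a multi-layer form, a sketch like the one below can collapse them into the single-layer index map. The helper name is hypothetical, and it assumes the layout described in the question (layer k holds the value k on pixels of class k and 0 elsewhere, with no overlapping classes), so the per-pixel maximum over the layers is the class label. For a one-hot 0/1 mask with exactly one channel set per pixel, argmax over the channel dimension does the same job.

```python
import torch

def layers_to_index_map(layers: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: [B, C, H, W] layered mask -> [B, H, W] long class indices.

    Assumes layer k is k on pixels of class k and 0 elsewhere, with no overlaps,
    so the maximum over the channel dimension is each pixel's class label.
    """
    return layers.max(dim=1).values.long()

# Toy example: 1 sample, 3 layers (0 = background, 1 = dog, 2 = cat), 2 x 3 pixels.
layers = torch.zeros(1, 3, 2, 3)
layers[0, 1, 0, 0] = 1.0  # one "dog" pixel
layers[0, 2, 1, 2] = 2.0  # one "cat" pixel
target = layers_to_index_map(layers)
# target holds [[1, 0, 0], [0, 0, 2]] for the single sample

# One-hot 0/1 masks (exactly one channel set per pixel): argmax over channels.
one_hot = torch.nn.functional.one_hot(target, num_classes=3).permute(0, 3, 1, 2).float()
target_from_one_hot = one_hot.argmax(dim=1)
print(target, target_from_one_hot)
```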

Note, the input to CrossEntropyLoss is not of the same shape
as the target; the input is of shape
(batch, nClasses, H, W). For any given sample in the batch,
each pixel in the H x W input “image” has nClasses values. These
are the “logits” (“scores” in -infinity to +infinity) that tell you how likely
the model thinks a given pixel is to be in each of the nClasses
classes.
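To make the relation between the logits and the class labels concrete, here is a sketch with random stand-in tensors: a softmax over dim=1 turns each pixel's nClasses scores into probabilities (CrossEntropyLoss applies log_softmax internally, so the model should output raw logits), and argmax over dim=1 gives a predicted label map with the same (batch, H, W) shape as the target, which is what a per-pixel accuracy compares against.

```python
import torch

torch.manual_seed(0)
batch, n_classes, H, W = 4, 5, 224, 224

# Stand-ins: logits from the model and an integer label map as the target.
logits = torch.randn(batch, n_classes, H, W)
target = torch.randint(0, n_classes, (batch, H, W))

# Per-pixel class probabilities: softmax over the class dimension.
probs = torch.softmax(logits, dim=1)

# Predicted class per pixel: index of the largest logit -> shape [batch, H, W].
pred = logits.argmax(dim=1)
pixel_acc = (pred == target).float().mean()
print(pred.shape, pixel_acc.item())
```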

This error message looks reasonable in light of what I said above
(except that it looks like you are actually using batchsize = 4,
not batchsize = 5).

Assuming batchsize = 4, nClasses = 5, H = 224, and
W = 224, CrossEntropyLoss will be expecting the input
(prediction) you give it to be a FloatTensor of shape
(4, 5, 224, 224), and the target (ground truth) to be a
LongTensor of shape (4, 224, 224).
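Those expectations can be checked explicitly before calling the loss. The helper below is a hypothetical sketch (not part of PyTorch) that raises a readable error when the input or target does not match the class-index form described above:

```python
import torch
import torch.nn as nn

def check_ce_args(output: torch.Tensor, target: torch.Tensor) -> None:
    """Hypothetical helper: validate (input, target) for nn.CrossEntropyLoss
    in the class-index segmentation form: float (N, C, H, W) and long (N, H, W)."""
    if output.dim() != 4:
        raise ValueError(f"input must be (N, C, H, W), got {tuple(output.shape)}")
    n, c, h, w = output.shape
    if not output.is_floating_point():
        raise ValueError(f"input must be float logits, got {output.dtype}")
    if target.dtype != torch.long:
        raise ValueError(f"target must be torch.long, got {target.dtype}")
    if target.shape != (n, h, w):
        raise ValueError(f"target must be {(n, h, w)}, got {tuple(target.shape)}")
    if target.min().item() < 0 or target.max().item() >= c:
        raise ValueError(f"target labels must lie in [0, {c - 1}]")

output = torch.randn(4, 5, 224, 224)         # (batch, nClasses, H, W)
target = torch.randint(0, 5, (4, 224, 224))  # (batch, H, W)
check_ce_args(output, target)
loss = nn.CrossEntropyLoss()(output, target)
print(loss.item())
```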

Good luck.

K. Frank


Dear @KFrank, you hit the nail on the head. Thank you!

The target is now a single H x W image, with each pixel labeled with its class in [0, nClasses-1].

With the shapes you specified, loss = criterion(input, target) computes correctly (I had it quite wrong before).

Can you suggest some resources for learning the fundamentals of deep learning? I am just starting out. Thank you again.
