Inconsistent behaviour of cross entropy loss without weights

Hi, I noticed that the output of cross-entropy loss (for the semantic segmentation use case, i.e. the K-dimensional variant) with reduction="mean" is different from what I get when I sum and average the unreduced output myself. This is most visible with a bigger batch size.

A minimal working example:

import torch
import torch.nn as nn
import numpy as np

basic_img = torch.Tensor(np.random.rand(128, 2, 768, 768))                      # float32 logits, N x C x H x W
label_basic = torch.Tensor(np.random.randint(2, size=(128, 768, 768))).long()   # N x H x W class indices

criterion = nn.CrossEntropyLoss(reduction='none', weight=None)
loss = criterion(basic_img, label_basic)
loss.mean()

>>> tensor(0.7137)

criterion = nn.CrossEntropyLoss(reduction='mean', weight=None)
loss = criterion(basic_img, label_basic)
loss

>>> tensor(1.6724)

As far as I understand, these two versions should be equal. Am I missing something?
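For reference, here is a small sanity check of that expectation (just a sketch on tiny, made-up tensors, not the reproduction above): with weight=None, the per-pixel loss is -log_softmax(input) evaluated at the target class, and reduction='mean' is documented to be the plain average of those per-pixel values, so all three numbers below should agree up to round-off.

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
inp = torch.randn(4, 2, 8, 8)         # N x C x H x W logits
tgt = torch.randint(2, (4, 8, 8))     # N x H x W class indices

# per-pixel cross entropy computed by hand
manual = -F.log_softmax(inp, dim=1).gather(1, tgt.unsqueeze(1)).squeeze(1)

print(manual.mean())                                            # manual mean
print(nn.CrossEntropyLoss(reduction='none')(inp, tgt).mean())   # unreduced, then mean
print(nn.CrossEntropyLoss(reduction='mean')(inp, tgt))          # built-in mean reduction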

PyTorch version: 1.8.1

Thank you :slight_smile:


Hi Dominika!

The short answer is that this looks like a bug.

@ptrblck Maybe a bug?

I can reproduce this on PyTorch version 1.7.1.

When I use double precision, the GPU, or smaller tensors, the issue seems to go away. Naively, using reduction = 'none' and then taking the mean() seems to give the “right” answer.

(I don’t think this can be explained away as a legitimate consequence
of floating-point round-off error.)

Here is my test script:

import torch
print (torch.__version__)

_ = torch.manual_seed (2021)

# large tensors: the float32 CPU path shows the discrepancy
nBatch = 128
nClass = 2
height = 768
width = 768

input = torch.randn (nBatch, nClass, height, width)
target = torch.randint (nClass, (nBatch, height, width))

print ('nBatch =', nBatch, '\nnClass =', nClass, '\nheight =', height, '\nwidth =', width)
print ('big, float:')
print ('mean:', torch.nn.CrossEntropyLoss() (input, target))
print ('none:', torch.nn.CrossEntropyLoss (reduction = 'none') (input, target).mean())
print ('sum:', torch.nn.CrossEntropyLoss (reduction = 'sum') (input, target) / target.numel())

print ('big, cuda:')
print ('mean:', torch.nn.CrossEntropyLoss() (input.cuda(), target.cuda()))
print ('none:', torch.nn.CrossEntropyLoss (reduction = 'none') (input.cuda(), target.cuda()).mean())
print ('sum:', torch.nn.CrossEntropyLoss (reduction = 'sum') (input.cuda(), target.cuda()) / target.numel())

print ('big, double:')
print ('mean:', torch.nn.CrossEntropyLoss() (input.double(), target))
print ('none:', torch.nn.CrossEntropyLoss (reduction = 'none') (input.double(), target).mean())
print ('sum:', torch.nn.CrossEntropyLoss (reduction = 'sum') (input.double(), target) / target.numel())


# small tensors: all three reductions agree
nBatch = 8
nClass = 2
height = 16
width = 16

input = torch.randn (nBatch, nClass, height, width)
target = torch.randint (nClass, (nBatch, height, width))

print ('nBatch =', nBatch, '\nnClass =', nClass, '\nheight =', height, '\nwidth =', width)
print ('small, float:')
print ('mean:', torch.nn.CrossEntropyLoss() (input, target))
print ('none:', torch.nn.CrossEntropyLoss (reduction = 'none') (input, target).mean())
print ('sum:', torch.nn.CrossEntropyLoss (reduction = 'sum') (input, target) / target.numel())

print ('small, cuda:')
print ('mean:', torch.nn.CrossEntropyLoss() (input.cuda(), target.cuda()))
print ('none:', torch.nn.CrossEntropyLoss (reduction = 'none') (input.cuda(), target.cuda()).mean())
print ('sum:', torch.nn.CrossEntropyLoss (reduction = 'sum') (input.cuda(), target.cuda()) / target.numel())

print ('small, double:')
print ('mean:', torch.nn.CrossEntropyLoss() (input.double(), target))
print ('none:', torch.nn.CrossEntropyLoss (reduction = 'none') (input.double(), target).mean())
print ('sum:', torch.nn.CrossEntropyLoss (reduction = 'sum') (input.double(), target) / target.numel())

And here is its output:

1.7.1
nBatch = 128
nClass = 2
height = 768
width = 768
big, float:
mean: tensor(2.7632)
none: tensor(0.9028)
sum: tensor(0.6140)
big, cuda:
mean: tensor(0.9028, device='cuda:0')
none: tensor(0.9028, device='cuda:0')
sum: tensor(0.9028, device='cuda:0')
big, double:
mean: tensor(0.9028, dtype=torch.float64)
none: tensor(0.9028, dtype=torch.float64)
sum: tensor(0.9028, dtype=torch.float64)
nBatch = 8
nClass = 2
height = 16
width = 16
small, float:
mean: tensor(0.9148)
none: tensor(0.9148)
sum: tensor(0.9148)
small, cuda:
mean: tensor(0.9148, device='cuda:0')
none: tensor(0.9148, device='cuda:0')
sum: tensor(0.9148, device='cuda:0')
small, double:
mean: tensor(0.9148, dtype=torch.float64)
none: tensor(0.9148, dtype=torch.float64)
sum: tensor(0.9148, dtype=torch.float64)

Best.

K. Frank


Yes, this does indeed look like a bug.
Thank you very much for creating this great code snippet, and @Dominika_Basaj, thanks a lot for reporting this issue.
I’ll create an issue on GitHub and will link it here.

EDIT: Issue created here

Update: a fix is incoming. The root cause seems to be a missing accumulation dtype in the CPU path.
CC @KFrank @Dominika_Basaj
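To illustrate what a missing accumulation dtype can do here (this is only an illustration of the suspected mechanism, not PyTorch's actual kernel code): sequentially accumulating roughly 75 million per-pixel losses in float32 stalls once the running sum is so large that each new term is below half the spacing between representable floats, while accumulating in float64 stays accurate.

import numpy as np

n = 128 * 768 * 768                                # number of per-pixel losses
per_pixel = np.full(n, 0.9028, dtype=np.float32)   # pretend every pixel's loss is 0.9028

# sequential float32 accumulation (np.cumsum accumulates in the given dtype)
running = np.cumsum(per_pixel, dtype=np.float32)
print(running[-1] / n)                       # far from 0.9028: the running sum saturates near 2**24

# float64 accumulation of the same values
print(per_pixel.sum(dtype=np.float64) / n)   # ~0.9028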

Thank you for taking care of it @ptrblck :slight_smile:

Not to steal anyone’s thunder: I’ve just forwarded the issue, and zsef123 created the proposed fix :wink:

Hi @ptrblck!

Just to confirm: I see that the fix is in the latest nightly. Here is the result of my test script posted above:

1.9.0.dev20210416
nBatch = 128
nClass = 2
height = 768
width = 768
big, float:
mean: tensor(0.9028)
none: tensor(0.9028)
sum: tensor(0.9028)
big, cuda:
mean: tensor(0.9028, device='cuda:0')
none: tensor(0.9028, device='cuda:0')
sum: tensor(0.9028, device='cuda:0')
big, double:
mean: tensor(0.9028, dtype=torch.float64)
none: tensor(0.9028, dtype=torch.float64)
sum: tensor(0.9028, dtype=torch.float64)
nBatch = 8
nClass = 2
height = 16
width = 16
small, float:
mean: tensor(0.9148)
none: tensor(0.9148)
sum: tensor(0.9148)
small, cuda:
mean: tensor(0.9148, device='cuda:0')
none: tensor(0.9148, device='cuda:0')
sum: tensor(0.9148, device='cuda:0')
small, double:
mean: tensor(0.9148, dtype=torch.float64)
none: tensor(0.9148, dtype=torch.float64)
sum: tensor(0.9148, dtype=torch.float64)

Best.

K. Frank
