Multidimensional BCEWithLogitsLoss

Hello,
I am trying to use BCEWithLogitsLoss() with additional dimensions, but it crashes. The error messages differ depending on whether the additional axis has a size of 1 or 2:

import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size = 8
groups = 1
num_classes = 10
pos_weight = (num_classes-1.0)*torch.ones([num_classes])
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
x = torch.rand(batch_size, num_classes, groups)
print(x.shape)
# print(x)
y = torch.empty(batch_size, dtype=torch.long).random_(num_classes)
print(y.shape)
y = y.unsqueeze(1).expand(-1,groups)
print(y.shape)
y = F.one_hot(y, num_classes).type_as(x).transpose(1, 2)
print(y.shape)
# print(y)
criterion(x,y)

with output:

torch.Size([8, 10, 1])
torch.Size([8])
torch.Size([8, 1])
torch.Size([8, 10, 1])

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-36-5314aeccd9bc> in <module>
     14 print(y.shape)
     15 # print(y)
---> 16 criterion(x,y)

/srv/storage/irim@storage1.lille.grid5000.fr/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

/srv/storage/irim@storage1.lille.grid5000.fr/anaconda3/lib/python3.7/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
    599                                                   self.weight,
    600                                                   pos_weight=self.pos_weight,
--> 601                                                   reduction=self.reduction)
    602 
    603 

/srv/storage/irim@storage1.lille.grid5000.fr/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py in binary_cross_entropy_with_logits(input, target, weight, size_average, reduce, reduction, pos_weight)
   2124         raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))
   2125 
-> 2126     return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum)
   2127 
   2128 

RuntimeError: output with shape [8, 10, 1] doesn't match the broadcast shape [8, 10, 10]

import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size = 8
groups = 2
num_classes = 10
pos_weight = (num_classes-1.0)*torch.ones([num_classes])
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
x = torch.rand(batch_size, num_classes, groups)
print(x.shape)
# print(x)
y = torch.empty(batch_size, dtype=torch.long).random_(num_classes)
print(y.shape)
y = y.unsqueeze(1).expand(-1,groups)
print(y.shape)
y = F.one_hot(y, num_classes).type_as(x).transpose(1, 2)
print(y.shape)
# print(y)
criterion(x,y)

with output:

torch.Size([8, 10, 2])
torch.Size([8])
torch.Size([8, 2])
torch.Size([8, 10, 2])

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-37-f22703bfb833> in <module>
     14 print(y.shape)
     15 # print(y)
---> 16 criterion(x,y)

/srv/storage/irim@storage1.lille.grid5000.fr/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

/srv/storage/irim@storage1.lille.grid5000.fr/anaconda3/lib/python3.7/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
    599                                                   self.weight,
    600                                                   pos_weight=self.pos_weight,
--> 601                                                   reduction=self.reduction)
    602 
    603 

/srv/storage/irim@storage1.lille.grid5000.fr/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py in binary_cross_entropy_with_logits(input, target, weight, size_average, reduce, reduction, pos_weight)
   2124         raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))
   2125 
-> 2126     return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum)
   2127 
   2128 

RuntimeError: The size of tensor a (10) must match the size of tensor b (2) at non-singleton dimension 2

In both cases, I am unable to understand the error message. The x and y tensors have the same shape, and the documentation says that there can be any number of additional dimensions after the batch one. Can anyone tell me what I am doing wrong?
Thanks.

Best regards,

Georges Quénot.

Hello Georges!

Your use of BCEWithLogitsLoss looks correct to me. I am unable
to reproduce the errors you are getting using pytorch version 0.3.0.
Note: 0.3.0 does not support the pos_weight argument to the
BCEWithLogitsLoss constructor, so maybe there’s a hint there.

Could you tell us what version of pytorch you are using?

Two things you might try:

First, try leaving out the pos_weight argument when you construct
criterion:

criterion = torch.nn.BCEWithLogitsLoss()

Second, does it work for you if you don’t use “additional dimensions”?
That is, try it without the groups dimension (no groups at all, not just
groups = 1).
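
In your pytorch-1.x setup, that no-groups check might look something like this (just a sketch using your own one_hot-based encoding; I cannot test it on 0.3.0):

import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size = 8
num_classes = 10

# no pos_weight (first suggestion) and no groups axis (second suggestion)
criterion = nn.BCEWithLogitsLoss()

x = torch.rand(batch_size, num_classes)                             # [8, 10]
y = torch.empty(batch_size, dtype=torch.long).random_(num_classes)  # [8]
y = F.one_hot(y, num_classes).type_as(x)                            # [8, 10]
print(criterion(x, y))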

Here is my version of your script, tweaked to run with pytorch 0.3.0,
hence no pos_weight (and no built-in one_hot()):

import torch
torch.__version__
torch.manual_seed (2020)

batch_size = 8
groups = 1
num_classes = 10

# pos_weight = (num_classes-1.0)*torch.ones([num_classes])
# criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
criterion = torch.nn.BCEWithLogitsLoss()

x = torch.rand (batch_size, num_classes, groups)
print (x.shape)

y = torch.zeros (batch_size).long().random_ (num_classes)
print (y.shape)
y = y.unsqueeze (1).expand (-1, groups)
print (y.shape)
# one-hot y
y = torch.zeros (batch_size, num_classes, groups).scatter_ (1, y.unsqueeze (1), 1.0)
print (y.shape)

print (x[0])
print (y[0])
criterion (x, y)

groups = 2

x = torch.rand (batch_size, num_classes, groups)
print (x.shape)

y = torch.zeros (batch_size).long().random_ (num_classes)
print (y.shape)
y = y.unsqueeze (1).expand (-1, groups)
print (y.shape)
# one-hot y
y = torch.zeros (batch_size, num_classes, groups).scatter_ (1, y.unsqueeze (1), 1.0)
print (y.shape)

print (x[0])
print (y[0])
criterion (x, y)

And here is the output:

>>> import torch
>>> torch.__version__
'0.3.0b0+591e73e'
>>> torch.manual_seed (2020)
<torch._C.Generator object at 0x000001A55AB46630>
>>>
>>> batch_size = 8
>>> groups = 1
>>> num_classes = 10
>>>
>>> # pos_weight = (num_classes-1.0)*torch.ones([num_classes])
... # criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
... criterion = torch.nn.BCEWithLogitsLoss()
>>>
>>> x = torch.rand (batch_size, num_classes, groups)
>>> print (x.shape)
torch.Size([8, 10, 1])
>>>
>>> y = torch.zeros (batch_size).long().random_ (num_classes)
>>> print (y.shape)
torch.Size([8])
>>> y = y.unsqueeze (1).expand (-1, groups)
>>> print (y.shape)
torch.Size([8, 1])
>>> # one-hot y
... y = torch.zeros (batch_size, num_classes, groups).scatter_ (1, y.unsqueeze (1), 1.0)
>>> print (y.shape)
torch.Size([8, 10, 1])
>>>
>>> print (x[0])

 0.4869
 0.1052
 0.5883
 0.1161
 0.4949
 0.2824
 0.5899
 0.8105
 0.2512
 0.6307
[torch.FloatTensor of size 10x1]

>>> print (y[0])

    1
    0
    0
    0
    0
    0
    0
    0
    0
    0
[torch.FloatTensor of size 10x1]

>>> criterion (x, y)
0.9552074257284403
>>>
>>> groups = 2
>>>
>>> x = torch.rand (batch_size, num_classes, groups)
>>> print (x.shape)
torch.Size([8, 10, 2])
>>>
>>> y = torch.zeros (batch_size).long().random_ (num_classes)
>>> print (y.shape)
torch.Size([8])
>>> y = y.unsqueeze (1).expand (-1, groups)
>>> print (y.shape)
torch.Size([8, 2])
>>> # one-hot y
... y = torch.zeros (batch_size, num_classes, groups).scatter_ (1, y.unsqueeze (1), 1.0)
>>> print (y.shape)
torch.Size([8, 10, 2])
>>>
>>> print (x[0])

 0.0668  0.2524
 0.9783  0.9895
 0.0449  0.4335
 0.2581  0.3636
 0.1338  0.0855
 0.9679  0.7406
 0.4356  0.2211
 0.9607  0.8944
 0.3491  0.3937
 0.3645  0.1764
[torch.FloatTensor of size 10x2]

>>> print (y[0])

    0     0
    0     0
    0     0
    0     0
    0     0
    0     0
    0     0
    1     1
    0     0
    0     0
[torch.FloatTensor of size 10x2]

>>> criterion (x, y)
0.9362296428531408

Maybe someone could try this with pos_weight on pytorch 1.x (or
whenever pos_weight was added).

Best.

K. Frank

Dear K. Frank,

Thanks for your answer.

I am using pytorch version 1.4.0.

No error happens when I don’t use pos_weight and an additional dimension simultaneously, so it was a good guess to suspect an interaction.

Both:

import torch
import torch.nn as nn
import torch.nn.functional as F
print(torch.__version__)
batch_size = 8
groups = 2
num_classes = 10
criterion = nn.BCEWithLogitsLoss()
x = torch.rand(batch_size, num_classes, groups)
print(x.shape)
# print(x)
y = torch.empty(batch_size, dtype=torch.long).random_(num_classes)
print(y.shape)
y = y.unsqueeze(1).expand(-1,groups)
print(y.shape)
y = F.one_hot(y, num_classes).type_as(x).transpose(1, 2)
print(y.shape)
# print(y)
criterion(x,y)

and

import torch
import torch.nn as nn
import torch.nn.functional as F
print(torch.__version__)
batch_size = 8
groups = 2
num_classes = 10
pos_weight = (num_classes-1.0)*torch.ones([num_classes])
print(pos_weight)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
x = torch.rand(batch_size, num_classes)
print(x.shape)
# print(x)
y = torch.empty(batch_size, dtype=torch.long).random_(num_classes)
print(y.shape)
y = F.one_hot(y, num_classes).type_as(x)
print(y.shape)
# print(y)
criterion(x,y)

work just fine. When used, pos_weight is:

tensor([9., 9., 9., 9., 9., 9., 9., 9., 9., 9.])

which looks like what is expected by the function according to the documentation: “pos_weight (Tensor, optional) – a weight of positive examples. Must be a vector with length equal to the number of classes.”
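
Comparing the two tracebacks, my guess (only an assumption drawn from the shapes in the error messages, not from the source) is that pos_weight is broadcast against the trailing axes of input and target, so a plain [num_classes] vector lines up with the groups axis instead of the class axis. Plain tensor broadcasting reproduces both failures:

import torch

pos_weight = torch.ones(10)  # the documented [num_classes] vector

# groups = 1: [10] broadcasts against [8, 10, 1] to [8, 10, 10],
# matching "broadcast shape [8, 10, 10]" in the first error
print(torch.broadcast_tensors(torch.zeros(8, 10, 1), pos_weight)[0].shape)

# groups = 2: trailing sizes 2 and 10 cannot broadcast,
# matching the size mismatch in the second error
try:
    torch.broadcast_tensors(torch.zeros(8, 10, 2), pos_weight)
except RuntimeError as e:
    print(e)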

It happens that the following works:

import torch
import torch.nn as nn
import torch.nn.functional as F
print(torch.__version__)
batch_size = 8
groups = 2
num_classes = 10
pos_weight = (num_classes-1.0)*torch.ones([num_classes, groups])
print(pos_weight)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
x = torch.rand(batch_size, num_classes, groups)
print(x.shape)
# print(x)
y = torch.empty(batch_size, dtype=torch.long).random_(num_classes)
print(y.shape)
y = y.unsqueeze(1).expand(-1,groups)
print(y.shape)
y = F.one_hot(y, num_classes).type_as(x).transpose(1, 2)
print(y.shape)
# print(y)
print(criterion(x,y))

with output:

1.4.0
tensor([[9., 9.],
        [9., 9.],
        [9., 9.],
        [9., 9.],
        [9., 9.],
        [9., 9.],
        [9., 9.],
        [9., 9.],
        [9., 9.],
        [9., 9.]])
torch.Size([8, 10, 2])
torch.Size([8])
torch.Size([8, 2])
torch.Size([8, 10, 2])
tensor(1.3508)

So pos_weight should have the same size as input and target (without the batch axis), but this is not clear in the documentation.
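
If that broadcasting guess is right, a per-class weight can also be kept without tiling it, by giving it a trailing singleton axis so it broadcasts over the groups (a sketch based on the same assumption; the [num_classes, 1] shape is my inference, not something from the documentation):

import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, groups, num_classes = 8, 2, 10

# assumption: [num_classes, 1] broadcasts over the groups axis,
# keeping one weight per class
pos_weight = (num_classes - 1.0) * torch.ones([num_classes, 1])
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

x = torch.rand(batch_size, num_classes, groups)
y = torch.empty(batch_size, dtype=torch.long).random_(num_classes)
y = y.unsqueeze(1).expand(-1, groups)
y = F.one_hot(y, num_classes).type_as(x).transpose(1, 2)
print(criterion(x, y))  # expected to run without error

This would keep the per-class semantics without materializing the full [num_classes, groups] tensor.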

Best regards,

Georges.

Hi Georges!

Thanks. That makes it pretty clear what is going on.

I would call this a bug. It somewhat defeats the purpose of
pos_weight, and, in my view, breaks its semantics. (In any
event, I would say it’s at least a documentation bug.)

Perhaps someone like @albanD could take a look …

Best.

K. Frank