Runtime Error for nn.BCEWithLogitsLoss() with pos_weights

Hi, I run my binary segmentation code with the loss function nn.BCEWithLogitsLoss(). Since the positive and negative targets are imbalanced, I use pos_weight. However, it raises a runtime error and I don't know why.

model = model.cuda()
w_pos = torch.tensor([4.538])
criterion = nn.BCEWithLogitsLoss(pos_weight = w_pos)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

im = Variable(data[0].cuda())
labels = Variable(data[1].cuda())
out = model(im)
out1 = torch.squeeze(out)
loss = criterion(out1, labels.float())
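For context, the 4.538 above is my class-balancing factor. A common way to pick pos_weight for binary segmentation is the negative-to-positive pixel ratio; a minimal sketch with a made-up random label tensor (the tensor shape and threshold are just illustrative):

```python
import torch

torch.manual_seed(0)

# Hypothetical binary label tensor: 1 = foreground, 0 = background
labels = (torch.rand(8, 16, 16) > 0.8).float()

n_pos = labels.sum()
n_neg = labels.numel() - n_pos

# pos_weight = (# negatives) / (# positives) upweights the rare positive class
w_pos = n_neg / n_pos
print(w_pos.item())
```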

Traceback (most recent call last):
  File "D:/DSC/Planet_test/", line 317, in <module>
    loss = criterion(out1, labels.float())
  File "C:\Users\he425\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\nn\modules\", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "C:\Users\he425\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\nn\modules\", line 601, in forward
  File "C:\Users\he425\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\nn\", line 2100, in binary_cross_entropy_with_logits
    return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum)
RuntimeError: expected device cpu and dtype Float but got device cuda:0 and dtype Float

Is there any solution for this issue?

Move your w_pos to the same device as your data and target and it should work. :wink:
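A minimal sketch of what that looks like (the tensor shapes are dummy stand-ins for your model output and labels):

```python
import torch
import torch.nn as nn

# Create pos_weight on the same device as the model output and target,
# so BCEWithLogitsLoss does not mix CPU and CUDA tensors.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

w_pos = torch.tensor([4.538], device=device)
criterion = nn.BCEWithLogitsLoss(pos_weight=w_pos)

# Dummy logits and binary targets in place of out1 and labels
out1 = torch.randn(8, 16, 16, device=device)
labels = torch.randint(0, 2, (8, 16, 16), device=device).float()

loss = criterion(out1, labels)
print(loss.item())
```

Alternatively, `w_pos.to(out1.device)` right before building the criterion works too; the key point is that input, target, and pos_weight all live on the same device.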

Hi, thanks for your friendly reply!

What confuses me is that the error says it expected data on the CPU, but I gave it data from my GPU. As you can see from my code, the model, input image, and ground-truth labels are all on the GPU via .cuda(), but w_pos is not. So couldn't the problem come from those three variables that are already on the GPU?

I tried w_pos = torch.tensor([4.538]).cuda() before asking this question, but it showed a similar error. So maybe it is not related to w_pos after all? Any new ideas?

Thanks again for the reply!

I ran the code with every tensor on the GPU except w_pos and got the same error message, so I assumed that was the cause.
Could you post a code snippet that reproduces your error?

I think I solved the problem following your suggestion! Thank you very much!!!