Why is the loss decreasing very slowly with BCEWithLogitsLoss(), and why are the predictions not correct?

I am experimenting with a toy dataset. I am computing the loss with BCEWithLogitsLoss(), but the loss is decreasing very slowly, and the predictions given by the neural network are also not correct.

import torch as th
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(1, 1)
input_tensor = th.tensor([[5.8], [6.0], [5.5], [4.5], [4.1], [3.5]], requires_grad=True)
target_tensor = th.tensor([[1.0], [1], [1], [0], [0], [0]], requires_grad=False)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
for i in range(500):
    optimizer.zero_grad()
    predict_tensor = model(input_tensor)
    loss = criterion(predict_tensor,target_tensor)
    loss.backward()
    optimizer.step()
    print(loss.data)
Loss values:
tensor(1.7171)
tensor(1.3382)
tensor(1.0275)
tensor(0.8316)
tensor(0.7524)
tensor(0.7327)
tensor(0.7287)
tensor(0.7274)
tensor(0.7265)
tensor(0.7257)
tensor(0.7250)
tensor(0.7242)
tensor(0.7234)
tensor(0.7226)
tensor(0.7219)
tensor(0.7211)
tensor(0.7203)
tensor(0.7195)
tensor(0.7188)
tensor(0.7180)
tensor(0.7172)
tensor(0.7165)
tensor(0.7157)
tensor(0.7149)
tensor(0.7142)
tensor(0.7134)
tensor(0.7127)
tensor(0.7119)
tensor(0.7111)
tensor(0.7104)
tensor(0.7096)
tensor(0.7089)
tensor(0.7081)
tensor(0.7074)
tensor(0.7066)
tensor(0.7059)
tensor(0.7051)
tensor(0.7044)
tensor(0.7036)
tensor(0.7029)
tensor(0.7022)
tensor(0.7014)
tensor(0.7007)
tensor(0.6999)
tensor(0.6992)
tensor(0.6985)
tensor(0.6977)
tensor(0.6970)
tensor(0.6963)
tensor(0.6955)
tensor(0.6948)
tensor(0.6941)
tensor(0.6933)
tensor(0.6926)
tensor(0.6919)
tensor(0.6912)
tensor(0.6904)
tensor(0.6897)
tensor(0.6890)
tensor(0.6883)
tensor(0.6875)
tensor(0.6868)
tensor(0.6861)
tensor(0.6854)
tensor(0.6847)
tensor(0.6840)
tensor(0.6833)
tensor(0.6825)
tensor(0.6818)
tensor(0.6811)
tensor(0.6804)
tensor(0.6797)
tensor(0.6790)
tensor(0.6783)
tensor(0.6776)
tensor(0.6769)
tensor(0.6762)
tensor(0.6755)
tensor(0.6748)
tensor(0.6741)
tensor(0.6734)
tensor(0.6727)
tensor(0.6720)
tensor(0.6713)
tensor(0.6706)
tensor(0.6699)
tensor(0.6693)
tensor(0.6686)
tensor(0.6679)
tensor(0.6672)
tensor(0.6665)
tensor(0.6658)
tensor(0.6651)
tensor(0.6645)
tensor(0.6638)
tensor(0.6631)
tensor(0.6624)
tensor(0.6617)
tensor(0.6611)
tensor(0.6604)
tensor(0.6597)
tensor(0.6591)
tensor(0.6584)
tensor(0.6577)
tensor(0.6570)
tensor(0.6564)
tensor(0.6557)
tensor(0.6550)
tensor(0.6544)
tensor(0.6537)
tensor(0.6530)
tensor(0.6524)
tensor(0.6517)
tensor(0.6511)
tensor(0.6504)
tensor(0.6497)
tensor(0.6491)
tensor(0.6484)
tensor(0.6478)
tensor(0.6471)
tensor(0.6465)
tensor(0.6458)
tensor(0.6452)
tensor(0.6445)
tensor(0.6439)
tensor(0.6432)
tensor(0.6426)
tensor(0.6419)
tensor(0.6413)
tensor(0.6407)
tensor(0.6400)
tensor(0.6394)
tensor(0.6387)
tensor(0.6381)
tensor(0.6375)
tensor(0.6368)
tensor(0.6362)
tensor(0.6355)
tensor(0.6349)
tensor(0.6343)
tensor(0.6336)
tensor(0.6330)
tensor(0.6324)
tensor(0.6318)
tensor(0.6311)
tensor(0.6305)
tensor(0.6299)
tensor(0.6293)
tensor(0.6286)
tensor(0.6280)
tensor(0.6274)
tensor(0.6268)
tensor(0.6262)
tensor(0.6255)
tensor(0.6249)
tensor(0.6243)
tensor(0.6237)
tensor(0.6231)
tensor(0.6225)
tensor(0.6218)
tensor(0.6212)
tensor(0.6206)
tensor(0.6200)
tensor(0.6194)
tensor(0.6188)
tensor(0.6182)
tensor(0.6176)
tensor(0.6170)
tensor(0.6164)
tensor(0.6158)
tensor(0.6152)
tensor(0.6146)
tensor(0.6140)
tensor(0.6134)
tensor(0.6128)
tensor(0.6122)
tensor(0.6116)
tensor(0.6110)
tensor(0.6104)
tensor(0.6098)
tensor(0.6092)
tensor(0.6086)
tensor(0.6080)
tensor(0.6074)
tensor(0.6069)
tensor(0.6063)
tensor(0.6057)
tensor(0.6051)
tensor(0.6045)
tensor(0.6039)
tensor(0.6033)
tensor(0.6028)
tensor(0.6022)
tensor(0.6016)
tensor(0.6010)
tensor(0.6005)
tensor(0.5999)
tensor(0.5993)
tensor(0.5987)
tensor(0.5982)
tensor(0.5976)
tensor(0.5970)
tensor(0.5964)
tensor(0.5959)
tensor(0.5953)
tensor(0.5947)
tensor(0.5942)
tensor(0.5936)
tensor(0.5930)
tensor(0.5925)
tensor(0.5919)
tensor(0.5913)
tensor(0.5908)
tensor(0.5902)
tensor(0.5897)
tensor(0.5891)
tensor(0.5885)
tensor(0.5880)
tensor(0.5874)
tensor(0.5869)
tensor(0.5863)
tensor(0.5858)
tensor(0.5852)
tensor(0.5847)
tensor(0.5841)
tensor(0.5836)
tensor(0.5830)
tensor(0.5825)
tensor(0.5819)
tensor(0.5814)
tensor(0.5808)
tensor(0.5803)
tensor(0.5797)
tensor(0.5792)
tensor(0.5786)
tensor(0.5781)
tensor(0.5776)
tensor(0.5770)
tensor(0.5765)
tensor(0.5759)
tensor(0.5754)
tensor(0.5749)
tensor(0.5743)
tensor(0.5738)
tensor(0.5733)
tensor(0.5727)
tensor(0.5722)
tensor(0.5717)
tensor(0.5711)
tensor(0.5706)
tensor(0.5701)
tensor(0.5696)
tensor(0.5690)
tensor(0.5685)
tensor(0.5680)
tensor(0.5675)
tensor(0.5669)
tensor(0.5664)
tensor(0.5659)
tensor(0.5654)
tensor(0.5648)
tensor(0.5643)
tensor(0.5638)
tensor(0.5633)
tensor(0.5628)
tensor(0.5623)
tensor(0.5617)
tensor(0.5612)
tensor(0.5607)
tensor(0.5602)
tensor(0.5597)
tensor(0.5592)
tensor(0.5587)
tensor(0.5582)
tensor(0.5576)
tensor(0.5571)
tensor(0.5566)
tensor(0.5561)
tensor(0.5556)
tensor(0.5551)
tensor(0.5546)
tensor(0.5541)
tensor(0.5536)
tensor(0.5531)
tensor(0.5526)
tensor(0.5521)
tensor(0.5516)
tensor(0.5511)
tensor(0.5506)
tensor(0.5501)
tensor(0.5496)
tensor(0.5491)
tensor(0.5486)
tensor(0.5481)
tensor(0.5476)
tensor(0.5472)
tensor(0.5467)
tensor(0.5462)
tensor(0.5457)
tensor(0.5452)
tensor(0.5447)
tensor(0.5442)
tensor(0.5437)
tensor(0.5432)
tensor(0.5428)
tensor(0.5423)
tensor(0.5418)
tensor(0.5413)
tensor(0.5408)
tensor(0.5403)
tensor(0.5399)
tensor(0.5394)
tensor(0.5389)
tensor(0.5384)
tensor(0.5380)
tensor(0.5375)
tensor(0.5370)
tensor(0.5365)
tensor(0.5361)
tensor(0.5356)
tensor(0.5351)
tensor(0.5346)
tensor(0.5342)
tensor(0.5337)
tensor(0.5332)
tensor(0.5328)
tensor(0.5323)
tensor(0.5318)
tensor(0.5313)
tensor(0.5309)
tensor(0.5304)
tensor(0.5300)
tensor(0.5295)
tensor(0.5290)
tensor(0.5286)
tensor(0.5281)
tensor(0.5276)
tensor(0.5272)
tensor(0.5267)
tensor(0.5263)
tensor(0.5258)
tensor(0.5253)
tensor(0.5249)
tensor(0.5244)
tensor(0.5240)
tensor(0.5235)
tensor(0.5231)
tensor(0.5226)
tensor(0.5222)
tensor(0.5217)
tensor(0.5213)
tensor(0.5208)
tensor(0.5204)
tensor(0.5199)
tensor(0.5195)
tensor(0.5190)
tensor(0.5186)
tensor(0.5181)
tensor(0.5177)
tensor(0.5172)
tensor(0.5168)
tensor(0.5163)
tensor(0.5159)
tensor(0.5155)
tensor(0.5150)
tensor(0.5146)
tensor(0.5141)
tensor(0.5137)
tensor(0.5133)
tensor(0.5128)
tensor(0.5124)
tensor(0.5120)
tensor(0.5115)
tensor(0.5111)
tensor(0.5106)
tensor(0.5102)
tensor(0.5098)
tensor(0.5093)
tensor(0.5089)
tensor(0.5085)
tensor(0.5081)
tensor(0.5076)
tensor(0.5072)
tensor(0.5068)
tensor(0.5063)
tensor(0.5059)
tensor(0.5055)
tensor(0.5051)
tensor(0.5046)
tensor(0.5042)
tensor(0.5038)
tensor(0.5034)
tensor(0.5029)
tensor(0.5025)
tensor(0.5021)
tensor(0.5017)
tensor(0.5013)
tensor(0.5008)
tensor(0.5004)
tensor(0.5000)
tensor(0.4996)
tensor(0.4992)
tensor(0.4988)
tensor(0.4983)
tensor(0.4979)
tensor(0.4975)
tensor(0.4971)
tensor(0.4967)
tensor(0.4963)
tensor(0.4959)
tensor(0.4954)
tensor(0.4950)
tensor(0.4946)
tensor(0.4942)
tensor(0.4938)
tensor(0.4934)
tensor(0.4930)
tensor(0.4926)
tensor(0.4922)
tensor(0.4918)
tensor(0.4914)
tensor(0.4910)
tensor(0.4906)
tensor(0.4902)
tensor(0.4898)
tensor(0.4894)
tensor(0.4890)
tensor(0.4886)
tensor(0.4882)
tensor(0.4878)
tensor(0.4874)
tensor(0.4870)
tensor(0.4866)
tensor(0.4862)
tensor(0.4858)
tensor(0.4854)
tensor(0.4850)
tensor(0.4846)
tensor(0.4842)
tensor(0.4838)
tensor(0.4834)
tensor(0.4830)
tensor(0.4826)
tensor(0.4822)
tensor(0.4818)
tensor(0.4814)
tensor(0.4811)
tensor(0.4807)
tensor(0.4803)
tensor(0.4799)
tensor(0.4795)
tensor(0.4791)
tensor(0.4787)
tensor(0.4784)
tensor(0.4780)
tensor(0.4776)
tensor(0.4772)
tensor(0.4768)
tensor(0.4764)
tensor(0.4761)
tensor(0.4757)
tensor(0.4753)
tensor(0.4749)
tensor(0.4745)
tensor(0.4742)
tensor(0.4738)
tensor(0.4734)
tensor(0.4730)
tensor(0.4726)
tensor(0.4723)
tensor(0.4719)
tensor(0.4715)
tensor(0.4711)
tensor(0.4708)
tensor(0.4704)
tensor(0.4700)
tensor(0.4697)
tensor(0.4693)
tensor(0.4689)
tensor(0.4685)
tensor(0.4682)
tensor(0.4678)
tensor(0.4674)
tensor(0.4671)
tensor(0.4667)
tensor(0.4663)
tensor(0.4660)
tensor(0.4656)
tensor(0.4652)
tensor(0.4649)
tensor(0.4645)
tensor(0.4641)
tensor(0.4638)
tensor(0.4634)
tensor(0.4631)
tensor(0.4627)
tensor(0.4623)
tensor(0.4620)
tensor(0.4616)
tensor(0.4613)

print(model(th.tensor([80.5]))) gives tensor([139.4498], grad_fn=<AddBackward0>)
My model outputs logits, but I want it to give me probabilities. If I add a sigmoid activation at the end, though, BCEWithLogitsLoss() would be thrown off, because it expects logits as inputs.
Colab notebook for troubleshooting: https://colab.research.google.com/drive/1WjCcSv5nVXf-zD1mCEl17h5jp7V2Pooz


Hi Mnauf!

I suspect that you are misunderstanding how to interpret the
predictions made by this network.

First, you are using, as you say, BCEWithLogitsLoss. Therefore you
are training your predictions to be “logits.” These are “raw scores,”
if you will, that are real numbers ranging from -infinity to +infinity.
Values less than 0 predict class “0” and values greater than 0
predict class “1”.

(When pumped through a sigmoid function, they become predicted
probabilities of the sample in question being in the “1” class. We
generally convert that to a non-probabilistic prediction by saying
P < 0.5 → class “0”, and P > 0.5 → class “1”.)
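
As a quick illustration (using the current torch API rather than the 0.3.0-style code further down, and made-up logit values), converting logits to probabilities and then to hard class predictions might look like this:

import torch

logits = torch.tensor([-2.3, 0.1, 4.7])   # raw scores produced by a model
probs = torch.sigmoid(logits)             # predicted probability of class "1"
classes = (probs > 0.5).long()            # same result as (logits > 0).long()
print(probs)                              # roughly [0.09, 0.52, 0.99]
print(classes)                            # tensor([0, 1, 1])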

Second, your model is a simple (one-dimensional) linear function.
Therefore it can’t cluster predictions together – it can only get the
boundary between class “0” and class “1” right. (Because of this,
you will not ever be able to drive your loss to zero, even if your
prediction accuracy is perfect.) From your six data points that
boundary is somewhere around 5.0.
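
(As a rough check of where that boundary lands: for a one-dimensional linear model the logit, weight * x + bias, changes sign at x = -bias / weight. A small sketch, plugging in the final weight and bias printed at the end of the run below:)

# hypothetical check using the trained parameters from the run below
weight, bias = 0.7067, -3.2146
boundary = -bias / weight
print(boundary)   # about 4.55 -- above the largest "0" sample (4.5) and below the smallest "1" sample (5.5)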

Anyway, your model works for me.

Note, I’ve run the below test using pytorch version 0.3.0, so I had
to tweak your code a little bit.

Here is my complete script:

import torch
print (torch.__version__)

torch.manual_seed (2019)

model = torch.nn.Linear (1, 1)
print ('model.weight =', model.weight.data[0][0])
print ('model.bias =', model.bias.data[0])

input_tensor = torch.autograd.Variable (torch.Tensor([[5.8],[6.0],[5.5],[4.5],[4.1],[3.5]]))
target_tensor = torch.autograd.Variable (torch.Tensor([[1.0],[1],[1],[0],[0],[0]]))
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(model.parameters(), lr =0.1)

for  i in range (500):
    optimizer.zero_grad()
    predict_tensor = model(input_tensor)
    loss = criterion(predict_tensor,target_tensor)
    loss.backward()
    optimizer.step()
    if  (i + 1) % 100 == 0:
        print (loss.data[0])
        pd = predict_tensor.data
        print (pd[0][0], pd[1][0], pd[2][0], pd[3][0], pd[4][0], pd[5][0])

print ('model.weight =', model.weight.data[0][0])
print ('model.bias =', model.bias.data[0])

And here is the complete output:

>>> import torch
>>> print (torch.__version__)
0.3.0b0+591e73e
>>>
>>> torch.manual_seed (2019)
<torch._C.Generator object at 0x000001FFD89F5630>
>>>
>>> model = torch.nn.Linear (1, 1)
>>> print ('model.weight =', model.weight.data[0][0])
model.weight = -0.41710567474365234
>>> print ('model.bias =', model.bias.data[0])
model.bias = 0.1750929355621338
>>>
>>> input_tensor = torch.autograd.Variable (torch.Tensor([[5.8],[6.0],[5.5],[4.5],[4.1],[3.5]]))
>>> target_tensor = torch.autograd.Variable (torch.Tensor([[1.0],[1],[1],[0],[0],[0]]))
>>> criterion = torch.nn.BCEWithLogitsLoss()
>>> optimizer = torch.optim.SGD(model.parameters(), lr =0.1)
>>>
>>> for  i in range (500):
...     optimizer.zero_grad()
...     predict_tensor = model(input_tensor)
...     loss = criterion(predict_tensor,target_tensor)
...     loss.backward()
...     optimizer.step()
...     if  (i + 1) % 100 == 0:
...         print (loss.data[0])
...         pd = predict_tensor.data
...         print (pd[0][0], pd[1][0], pd[2][0], pd[3][0], pd[4][0], pd[5][0])
...
0.6329450607299805
0.4733131527900696 0.5083065032958984 0.42082297801971436 0.2458558976650238 0.17586903274059296 0.07088879495859146
0.5748059749603271
0.5860824584960938 0.650752604007721 0.489077091217041 0.16572602093219757 0.03638556972146034 -0.15762503445148468
0.5252929329872131
0.6919865608215332 0.7841049432754517 0.5538088083267212 0.09321644902229309 -0.09102052450180054 -0.3673758804798126
0.48296624422073364
0.7910601496696472 0.9085955619812012 0.6147568225860596 0.027079343795776367 -0.20799170434474945 -0.5605981349945068
0.4465940296649933
0.8835896849632263 1.024709701538086 0.6719093322753906 -0.03369140625 -0.3159317672252655 -0.7392921447753906
>>> print ('model.weight =', model.weight.data[0][0])
model.weight = 0.7067369818687439
>>> print ('model.bias =', model.bias.data[0])
model.bias = -3.2145912647247314

The loss goes down systematically (but, as noted above, doesn’t
go to zero). And at the end of the run the prediction accuracy is
perfect on your set of six samples (with the predictions understood
as described above).

Good luck.

K. Frank

[Edit:]

Please let me correct an incorrect statement I made. I said that
you can’t drive the loss all the way to zero, but in fact you can.
As the “weight” in the model – the multiplicative factor in the linear
function – becomes larger and larger, the logits predicted by the
model get pushed out towards -infinity and +infinity. This will cause
the sigmoid (that is implicit in BCEWithLogitsLoss) to saturate at
0 and 1, so the predictions will become (increasingly close to) exactly
correct (provided the bias is adjusted accordingly, which the training
algorithm does), and the loss approaches zero.
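
As a rough sketch of this effect, you can scale a fixed set of correctly-signed logits and watch BCEWithLogitsLoss fall toward zero (the logit values here are made up for illustration):

import torch

target = torch.tensor([1., 1., 1., 0., 0., 0.])
logits = torch.tensor([1.3, 1.5, 1.0, -0.5, -0.9, -1.5])   # correct signs, modest magnitudes
criterion = torch.nn.BCEWithLogitsLoss()
for scale in (1.0, 5.0, 25.0, 125.0):
    print(scale, criterion(scale * logits, target).item())
# the loss keeps shrinking as the logits are pushed out toward +/- infinity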

Here are the last twenty loss values obtained by running Mnauf’s
training loop for 10,000 iterations:

0.06364566087722778
0.06364051252603531
0.06363534182310104
0.06363023072481155
0.06362506002187729
0.06361991167068481
0.06361475586891174
0.06360962986946106
0.06360447406768799
0.06359934061765671
0.06359419226646423
0.06358903646469116
0.06358392536640167
0.0635787770152092
0.06357365846633911
0.06356850266456604
0.06356338411569595
0.06355825066566467
0.06355313956737518
0.0635480210185051

So the loss does approach zero, although very slowly. Note, as the
sigmoid saturates, its gradients go to zero, so (with a fixed learning
rate) the training slows way down.
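
(A minimal sketch of why: the gradient of binary cross entropy with logits with respect to a single logit is sigmoid(logit) - target, which vanishes as the logit grows and the sigmoid saturates.)

import torch

target = torch.tensor([1.0])
for value in (1.0, 5.0, 25.0):
    logit = torch.tensor([value], requires_grad=True)
    loss = torch.nn.functional.binary_cross_entropy_with_logits(logit, target)
    loss.backward()
    print(value, logit.grad.item())   # equals sigmoid(value) - 1, shrinking toward zero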
