Predictions are all zeros in binary classification problem

Sorry, I'm a newbie.
I'm trying to solve a classification problem with PyTorch, but the predictions always turn into all zeros.
Can someone please help?

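import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import Linear, Module, Tanh

num_classes = 2
input_size = 3
hidden_size = 4
num_epochs = 100

class Net(Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()

        self.l1 = Linear(input_size, hidden_size)
        self.l2 = Linear(hidden_size, num_classes)
        self.act = Tanh()  # tanh activation (not a sigmoid)

    def forward(self, x):
        out = self.l1(x)
        out = self.act(out)
        out = self.l2(out)
        # log-probabilities over the class dimension, paired with F.nll_loss below
        return F.log_softmax(out, dim=1)


# named `model` so it does not shadow the torch.nn module
model = Net(input_size, hidden_size, num_classes)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# x (float features) and y (long class labels) are built from the CSV further down the thread
for epoch in range(num_epochs):
    output = model(x)
    _, y_pred = output.max(1)  # predicted class = index of the largest log-probability

    print('target', y[:4].tolist())
    print('predct', y_pred[:4].tolist())

    loss = F.nll_loss(output, y)
    print(loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

for i, j in zip(y[:20], y_pred[:20]):
    print(i.item(), j.item())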

That’s strange. I ran your code with random input and targets and the model seems to learn.

After some epochs I get:

('target', ' 1  1  0  0')
('predct', ' 1  1  0  0')
Loss: 0.496161192656

What kind of input data are you using? Is the loss moving at all?
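
A minimal sketch of a similar sanity check, reusing the question's Linear -> Tanh -> Linear layout. Unlike the reply, which used random targets, this version derives the labels from random inputs (an assumption, so there is a real signal to learn); a correctly wired training loop should then push the loss well below ln 2 ≈ 0.693, the loss of a coin-flip prediction. The seed, sample count, learning rate, and epoch count are arbitrary choices.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(0)  # arbitrary seed so the run is repeatable


class Net(nn.Module):
    """Same layout as the question: Linear -> Tanh -> Linear -> log_softmax."""

    def __init__(self, input_size=3, hidden_size=4, num_classes=2):
        super().__init__()
        self.l1 = nn.Linear(input_size, hidden_size)
        self.act = nn.Tanh()
        self.l2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        return F.log_softmax(self.l2(self.act(self.l1(x))), dim=1)


# Synthetic data (assumption): the label is 1 when the first feature is positive.
x = torch.randn(512, 3)
y = (x[:, 0] > 0).long()

model = Net()
optimizer = optim.Adam(model.parameters(), lr=0.05)

for epoch in range(500):
    loss = F.nll_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

with torch.no_grad():
    output = model(x)
    final_loss = F.nll_loss(output, y).item()
    accuracy = (output.argmax(dim=1) == y).float().mean().item()

print(f"final loss {final_loss:.3f}, accuracy {accuracy:.3f}")
```

If this converges but the real data does not, the problem is more likely in the data than in the model or the loop.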

I'm extracting the data from a CSV file with pandas and converting it to PyTorch tensors.

from torch.autograd import Variable  # deprecated wrapper; plain tensors work the same way now
Y = torch.from_numpy(np.array(cdf['class'][:4000], dtype=np.int64))
X = torch.from_numpy(np.column_stack([
    np.array(cdf[col][:4000], dtype=np.float32)
    for col in ('cap-shape', 'cap-surface', 'cap-color')]))
x = Variable(X, requires_grad=True).float()
y = Variable(Y)

Looks fine. Usually you don’t need requires_grad=True for your input, but that shouldn’t make a difference.
Could you post the data file?
What does your loss do? Is it constant, or is it fluctuating a bit?

The loss is decreasing slowly, but the predictions don't change.

[Command: python3 -u /root/TorchTest/nn.py]
/root/TorchTest/nn.py:56: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.
  return F.log_softmax(out)
target  1  0  0  1
predct  1  1  1  1
1.072535753250122
target  1  0  0  1
predct  1  1  1  1
1.066185474395752
target  1  0  0  1
predct  1  1  1  1
1.0599080324172974
target  1  0  0  1
predct  1  1  1  1
1.0536561012268066
target  1  0  0  1
predct  1  1  1  1
1.0474143028259277
target  1  0  0  1
predct  1  1  1  1
1.041246771812439
target  1  0  0  1
predct  1  1  1  1
1.0351084470748901
target  1  0  0  1
predct  1  1  1  1
1.0290038585662842
target  1  0  0  1
predct  1  1  1  1
1.0229482650756836
target  1  0  0  1
predct  1  1  1  1
1.0169259309768677
target  1  0  0  1
predct  1  1  1  1
1.0109413862228394
target  1  0  0  1
predct  1  1  1  1
1.0050331354141235
target  1  0  0  1
predct  1  1  1  1
0.9991320967674255
target  1  0  0  1
predct  1  1  1  1
0.993306577205658
target  1  0  0  1
predct  1  1  1  1
0.9875351190567017
target  1  0  0  1
predct  1  1  1  1
0.9817762970924377
target  1  0  0  1
predct  1  1  1  1
0.9761199951171875
target  1  0  0  1
predct  1  1  1  1
0.9704711437225342
target  1  0  0  1
predct  1  1  1  1
0.9648805260658264
target  1  0  0  1
predct  1  1  1  1
0.9593492150306702
target  1  0  0  1
predct  1  1  1  1
0.9538455009460449
target  1  0  0  1
predct  1  1  1  1
0.9484182596206665
target  1  0  0  1
predct  1  1  1  1
0.943033754825592
target  1  0  0  1
predct  1  1  1  1
0.9377087354660034
target  1  0  0  1
predct  1  1  1  1
0.9324311017990112
target  1  0  0  1
predct  1  1  1  1
0.9272052645683289
target  1  0  0  1
predct  1  1  1  1
0.9220643639564514
target  1  0  0  1
predct  1  1  1  1
0.9169411063194275
target  1  0  0  1
predct  1  1  1  1
0.9118647575378418
target  1  0  0  1
predct  1  1  1  1
0.9068688154220581
target  1  0  0  1
predct  1  1  1  1
0.9019254446029663
target  1  0  0  1
predct  1  1  1  1
0.897035539150238
target  1  0  0  1
predct  1  1  1  1
0.8921868205070496
target  1  0  0  1
predct  1  1  1  1
0.8873911499977112
target  1  0  0  1
predct  1  1  1  1
0.882676899433136
target  1  0  0  1
predct  1  1  1  1
0.8779977560043335
target  1  0  0  1
predct  1  1  1  1
0.8733696341514587
target  1  0  0  1
predct  1  1  1  1
0.8688199520111084
target  1  0  0  1
predct  1  1  1  1
0.86427903175354
target  1  0  0  1
predct  1  1  1  1
0.8598242998123169
target  1  0  0  1
predct  1  1  1  1
0.8554247617721558
target  1  0  0  1
predct  1  1  1  1
0.8510544896125793
target  1  0  0  1
predct  1  1  1  1
0.8467731475830078
target  1  0  0  1
predct  1  1  1  1
0.842523992061615
target  1  0  0  1
predct  1  1  1  1
0.838318943977356
target  1  0  0  1
predct  1  1  1  1
0.8341671228408813
target  1  0  0  1
predct  1  1  1  1
0.8300833702087402
target  1  0  0  1
predct  1  1  1  1
0.8260287642478943
target  1  0  0  1
predct  1  1  1  1
0.8220428824424744
target  1  0  0  1
predct  1  1  1  1
0.8180790543556213
target  1  0  0  1
predct  1  1  1  1
0.8141903877258301
target  1  0  0  1
predct  1  1  1  1
0.8103442788124084
target  1  0  0  1
predct  1  1  1  1
0.806536853313446
target  1  0  0  1
predct  1  1  1  1
0.80278080701828
target  1  0  0  1
predct  1  1  1  1
0.7990645170211792
target  1  0  0  1
predct  1  1  1  1
0.7954161167144775
target  1  0  0  1
predct  1  1  1  1
0.7917777895927429
target  1  0  0  1
predct  1  1  1  1
0.7882110476493835
target  1  0  0  1
predct  0  1  1  1
0.7846689820289612
target  1  0  0  1
predct  0  1  1  1
0.7811885476112366
target  1  0  0  1
predct  0  1  1  1
0.777743935585022
target  1  0  0  1
predct  0  1  1  1
0.7743515968322754
target  1  0  0  1
predct  0  1  1  1
0.7709869146347046
target  1  0  0  1
predct  0  1  1  1
0.7676637768745422
target  1  0  0  1
predct  0  1  1  1
0.7643994092941284
target  1  0  0  1
predct  0  1  1  1
0.7611634731292725
target  1  0  0  1
predct  0  1  1  1
0.7579619288444519
target  1  0  0  1
predct  0  1  1  1
0.7547981142997742
target  1  0  0  1
predct  0  1  1  1
0.7516692876815796
target  1  0  0  1
predct  0  1  1  1
0.7485985159873962
target  1  0  0  1
predct  0  1  1  1
0.745551347732544
target  1  0  0  1
predct  0  1  1  1
0.7425572872161865
target  1  0  0  1
predct  0  1  1  1
0.7395659685134888
target  1  0  0  1
predct  0  1  1  1
0.7366408109664917
target  1  0  0  1
predct  0  1  1  1
0.7337499856948853
target  1  0  0  1
predct  0  1  1  1
0.7308931946754456
target  1  0  0  1
predct  0  1  1  1
0.7280638813972473
target  1  0  0  1
predct  0  1  1  1
0.7252659797668457
target  1  0  0  1
predct  0  1  1  1
0.7225061655044556
target  1  0  0  1
predct  0  1  1  1
0.7197932004928589
target  1  0  0  1
predct  0  1  1  1
0.7170913219451904
target  1  0  0  1
predct  0  1  1  1
0.7144457101821899
target  1  0  0  1
predct  0  1  1  1
0.7118056416511536
target  1  0  0  1
predct  0  1  1  1
0.7092230319976807
target  1  0  0  1
predct  0  1  1  1
0.7066526412963867
target  1  0  0  1
predct  0  1  1  1
0.7041279077529907
target  1  0  0  1
predct  0  1  1  1
0.7016199827194214
target  1  0  0  1
predct  0  1  1  1
0.6991593837738037
target  1  0  0  1
predct  0  1  1  1
0.6967200040817261
target  1  0  0  1
predct  0  1  1  1
0.6943032145500183
target  1  0  0  1
predct  0  1  1  1
0.6919167637825012
target  1  0  0  1
predct  0  1  1  1
0.689564049243927
target  1  0  0  1
predct  0  1  1  1
0.6872462630271912
target  1  0  0  1
predct  0  1  1  1
0.68495112657547
target  1  0  0  1
predct  0  1  1  1
0.6826843619346619
target  1  0  0  1
predct  0  1  1  1
0.6804326176643372
target  1  0  0  1
predct  0  1  1  1
0.6782170534133911
target  1  0  0  1
predct  0  1  1  1
0.67603600025177
target  1  0  0  1
predct  0  1  1  1
0.6738657355308533
target  1  0  0  1
predct  0  1  1  1
0.6717345118522644
target  1  0  0  1
predct  0  1  1  1
0.6696066856384277
target  1  0  0  1
predct  0  1  1  1
0.6675252914428711
target  1  0  0  1
predct  0  1  1  1
0.6654491424560547
target  1  0  0  1
predct  0  1  1  1
0.663428783416748
target  1  0  0  1
predct  0  1  1  1
0.6614077687263489
target  1  0  0  1
predct  0  1  1  1
0.6593979597091675
target  1  0  0  1
predct  0  1  1  1
0.6574442982673645
target  1  0  0  1
predct  0  1  1  1
0.655487060546875
target  1  0  0  1
predct  0  1  1  1
0.6535556316375732
target  1  0  0  1
predct  0  1  1  1
0.6516560912132263
target  1  0  0  1
predct  0  1  1  1
0.6497719287872314
target  1  0  0  1
predct  0  1  1  1
0.6479018330574036
target  1  0  0  1
predct  0  1  1  1
0.6460625529289246
target  1  0  0  1
predct  0  1  1  1
0.644256055355072
target  1  0  0  1
predct  0  1  1  1
0.6424468755722046
target  1  0  0  1
predct  0  1  1  1
0.6406769156455994
target  1  0  0  1
predct  0  1  1  1
0.6389025449752808
target  1  0  0  1
predct  0  1  1  1
0.6371592879295349
target  1  0  0  1
predct  0  1  1  1
0.635432779788971
target  1  0  0  1
predct  0  1  1  1
0.6337351202964783
target  1  0  0  1
predct  0  1  1  1
0.6320592164993286
target  1  0  0  1
predct  0  1  1  1
0.6303884387016296
target  1  0  0  1
predct  0  1  1  1
0.6287346482276917
target  1  0  0  1
predct  0  1  1  1
0.6271059513092041
target  1  0  0  1
predct  0  1  1  1
0.6254917979240417
target  1  0  0  1
predct  0  1  1  1
0.6238887906074524
target  1  0  0  1
predct  0  1  1  1
0.6223136782646179
target  1  0  0  1
predct  0  1  1  1
0.6207422018051147
target  1  0  0  1
predct  0  1  1  1
0.6192011833190918
target  1  0  0  1
predct  0  1  1  1
0.6176725625991821
target  1  0  0  1
predct  0  1  1  0
0.6161558032035828
target  1  0  0  1
predct  0  1  1  0
0.6146573424339294
target  1  0  0  1
predct  0  1  1  0
0.6131836175918579
target  1  0  0  1
predct  0  1  1  0
0.6117121577262878
target  1  0  0  1
predct  0  1  1  0
0.6102636456489563
target  1  0  0  1
predct  0  1  1  0
0.6088240742683411
target  1  0  0  1
predct  0  1  1  0
0.6073999404907227
target  1  0  0  1
predct  0  1  1  0
0.6059949994087219
target  1  0  0  1
predct  0  1  1  0
0.604609489440918
target  1  0  0  1
predct  0  1  1  0
0.6032212972640991
target  1  0  0  1
predct  0  1  1  0
0.6018573641777039
target  1  0  0  1
predct  0  1  1  0
0.6005046367645264
target  1  0  0  1
predct  0  0  1  0
0.5991700291633606
target  1  0  0  1
predct  0  0  1  0
0.5978490710258484
target  1  0  0  1
predct  0  0  1  0
0.5965204834938049
target  1  0  0  1
predct  0  0  1  0
0.5952356457710266
target  1  0  0  1
predct  0  0  1  0
0.593949556350708
target  1  0  0  1
predct  0  0  1  0
0.5926786661148071
target  1  0  0  1
predct  0  0  1  0
0.5914202332496643
target  1  0  0  1
predct  0  0  1  0
0.5901700854301453
target  1  0  0  1
predct  0  0  0  0
0.5889377593994141
target  1  0  0  1
predct  0  0  0  0
0.5877289175987244
target  1  0  0  1
predct  0  0  0  0
0.5865069627761841
target  1  0  0  1
predct  0  0  0  0
0.5853143930435181
target  1  0  0  1
predct  0  0  0  0
0.5841359496116638
target  1  0  0  1
predct  0  0  0  0
0.582938015460968
target  1  0  0  1
predct  0  0  0  0
0.5817910432815552
target  1  0  0  1
predct  0  0  0  0
0.5806416273117065
target  1  0  0  1
predct  0  0  0  0
0.5795055031776428
target  1  0  0  1
predct  0  0  0  0
0.5783639550209045
target  1  0  0  1
predct  0  0  0  0
0.5772513747215271
target  1  0  0  1
predct  0  0  0  0
0.5761514902114868
target  1  0  0  1
predct  0  0  0  0
0.5750468373298645
target  1  0  0  1
predct  0  0  0  0
0.5739665627479553
target  1  0  0  1
predct  0  0  0  0
0.5728904604911804
target  1  0  0  1
predct  0  0  0  0
0.5718275904655457
target  1  0  0  1
predct  0  0  0  0
0.5707767009735107
target  1  0  0  1
predct  0  0  0  0
0.5697274208068848
target  1  0  0  1
predct  0  0  0  0
0.5686996579170227
target  1  0  0  1
predct  0  0  0  0
0.5676725506782532
target  1  0  0  1
predct  0  0  0  0
0.5666656494140625
target  1  0  0  1
predct  0  0  0  0
0.5656634569168091
target  1  0  0  1
predct  0  0  0  0
0.5646677613258362
target  1  0  0  1
predct  0  0  0  0
0.5636837482452393
target  1  0  0  1
predct  0  0  0  0
0.5627092123031616
target  1  0  0  1
predct  0  0  0  0
0.5617409348487854
target  1  0  0  1
predct  0  0  0  0
0.5607889294624329
target  1  0  0  1
predct  0  0  0  0
0.559839129447937
target  1  0  0  1
predct  0  0  0  0
0.5588983297348022
target  1  0  0  1
predct  0  0  0  0
0.5579784512519836
target  1  0  0  1
predct  0  0  0  0
0.5570580363273621
target  1  0  0  1
predct  0  0  0  0
0.5561417937278748
target  1  0  0  1
predct  0  0  0  0
0.5552458167076111
target  1  0  0  1
predct  0  0  0  0
0.5543524622917175
target  1  0  0  1
predct  0  0  0  0
0.5534700155258179
target  1  0  0  1
predct  0  0  0  0
0.5525945425033569
target  1  0  0  1
predct  0  0  0  0
0.5517288446426392
target  1  0  0  1
predct  0  0  0  0
0.550872266292572
target  1  0  0  1
predct  0  0  0  0
0.5500140190124512
target  1  0  0  1
predct  0  0  0  0
0.5491811633110046
target  1  0  0  1
predct  0  0  0  0
0.5483520030975342
target  1  0  0  1
predct  0  0  0  0
0.5475126504898071
target  1  0  0  1
predct  0  0  0  0
0.5467065572738647
target  1  0  0  1
predct  0  0  0  0
0.5458909869194031
target  1  0  0  1
predct  0  0  0  0
0.5450935959815979
target  1  0  0  1
predct  0  0  0  0
0.5442981123924255
target  1  0  0  1
predct  0  0  0  0
0.543514609336853
target  1  0  0  1
predct  0  0  0  0
0.5427330732345581
target  1  0  0  1
predct  0  0  0  0
0.5419690012931824
target  1  0  0  1
predct  0  0  0  0
0.5412043333053589
target  1  0  0  1
predct  0  0  0  0
0.5404438972473145
target  1  0  0  1
predct  0  0  0  0
0.5396968126296997
target  1  0  0  1
predct  0  0  0  0
0.5389549732208252
target  1  0  0  1
predct  0  0  0  0
0.5382329225540161
target  1  0  0  1
predct  0  0  0  0
0.5375010967254639
target  1  0  0  1
predct  0  0  0  0
0.5367866158485413
target  1  0  0  1
predct  0  0  0  0
0.5360691547393799
target  1  0  0  1
predct  0  0  0  0
0.5353735089302063
target  1  0  0  1
predct  0  0  0  0
0.5346764326095581
target  1  0  0  1
predct  0  0  0  0
0.5339795351028442
target  1  0  0  1
predct  0  0  0  0
0.5333051681518555
target  1  0  0  1
predct  0  0  0  0
0.5326249599456787
target  1  0  0  1
predct  0  0  0  0
0.5319448709487915
target  1  0  0  1
predct  0  0  0  0
0.5312869548797607
target  1  0  0  1
predct  0  0  0  0
0.5306397080421448
target  1  0  0  1
predct  0  0  0  0
0.529991626739502
target  1  0  0  1
predct  0  0  0  0
0.5293443202972412
target  1  0  0  1
predct  0  0  0  0
0.5287067890167236
target  1  0  0  1
predct  0  0  0  0
0.5280739665031433
target  1  0  0  1
predct  0  0  0  0
0.5274503827095032
target  1  0  0  1
predct  0  0  0  0
0.5268409848213196
target  1  0  0  1
predct  0  0  0  0
0.5262224078178406
target  1  0  0  1
predct  0  0  0  0
0.5256248116493225
target  1  0  0  1
predct  0  0  0  0
0.5250262022018433
target  1  0  0  1
predct  0  0  0  0
0.5244249701499939
target  1  0  0  1
predct  0  0  0  0
0.5238385200500488
target  1  0  0  1
predct  0  0  0  0
0.5232646465301514
target  1  0  0  1
predct  0  0  0  0
0.5227017998695374
target  1  0  0  1
predct  0  0  0  0
0.5221250057220459
target  1  0  0  1
predct  0  0  0  0
0.5215622186660767
target  1  0  0  1
predct  0  0  0  0
0.5209983587265015
target  1  0  0  1
predct  0  0  0  0
0.5204557180404663
target  1  0  0  1
predct  0  0  0  0
0.5199100375175476
target  1  0  0  1
predct  0  0  0  0
0.5193754434585571
target  1  0  0  1
predct  0  0  0  0
0.5188430547714233
target  1  0  0  1
predct  0  0  0  0
0.5183134078979492
target  1  0  0  1
predct  0  0  0  0
0.5177967548370361
target  1  0  0  1
predct  0  0  0  0
0.5172780752182007
target  1  0  0  1
predct  0  0  0  0
0.5167653560638428
target  1  0  0  1
predct  0  0  0  0
0.516261637210846
target  1  0  0  1
predct  0  0  0  0
0.5157607197761536
target  1  0  0  1
predct  0  0  0  0
0.5152709484100342
target  1  0  0  1
predct  0  0  0  0
0.5147755742073059
target  1  0  0  1
predct  0  0  0  0
0.5142912864685059
target  1  0  0  1
predct  0  0  0  0
0.5138179659843445
target  1  0  0  1
predct  0  0  0  0
0.5133406519889832
target  1  0  0  1
predct  0  0  0  0
0.5128713250160217
target  1  0  0  1
predct  0  0  0  0
0.5124103426933289
target  1  0  0  1
predct  0  0  0  0
0.5119565725326538
target  1  0  0  1
predct  0  0  0  0
0.5114997625350952
target  1  0  0  1
predct  0  0  0  0
0.511050820350647
target  1  0  0  1
predct  0  0  0  0
0.5106053352355957
target  1  0  0  1
predct  0  0  0  0
0.5101768970489502
target  1  0  0  1
predct  0  0  0  0
0.5097408890724182
target  1  0  0  1
predct  0  0  0  0
0.5093106031417847
target  1  0  0  1
predct  0  0  0  0
0.5088900923728943
target  1  0  0  1
predct  0  0  0  0
0.5084632039070129
target  1  0  0  1
predct  0  0  0  0
0.5080502033233643
target  1  0  0  1
predct  0  0  0  0
0.5076354146003723
target  1  0  0  1
predct  0  0  0  0
0.5072289109230042
target  1  0  0  1
predct  0  0  0  0
0.5068353414535522
target  1  0  0  1
predct  0  0  0  0
0.5064445734024048
target  1  0  0  1
predct  0  0  0  0
0.5060452222824097
target  1  0  0  1
predct  0  0  0  0
0.5056560635566711
target  1  0  0  1
predct  0  0  0  0
0.5052724480628967
target  1  0  0  1
predct  0  0  0  0
0.5048882365226746
target  1  0  0  1
predct  0  0  0  0
0.5045223832130432
target  1  0  0  1
predct  0  0  0  0
0.5041466951370239
target  1  0  0  1
predct  0  0  0  0
0.5037842392921448
target  1  0  0  1
predct  0  0  0  0
0.5034222602844238
target  1  0  0  1
predct  0  0  0  0
0.5030651688575745
target  1  0  0  1
predct  0  0  0  0
0.5027077198028564
target  1  0  0  1
predct  0  0  0  0
0.5023578405380249
target  1  0  0  1
predct  0  0  0  0
0.5020185112953186
target  1  0  0  1
predct  0  0  0  0
0.5016751289367676
target  1  0  0  1
predct  0  0  0  0
0.5013354420661926
target  1  0  0  1
predct  0  0  0  0
0.5009928941726685
target  1  0  0  1
predct  0  0  0  0
0.5006711483001709
target  1  0  0  1
predct  0  0  0  0
0.5003464221954346
target  1  0  0  1
predct  0  0  0  0
0.500017523765564
target  1  0  0  1
predct  0  0  0  0
0.4996969699859619
target  1  0  0  1
predct  0  0  0  0
0.49938568472862244
target  1  0  0  1
predct  0  0  0  0
0.4990651607513428
target  1  0  0  1
predct  0  0  0  0
0.49876299500465393
target  1  0  0  1
predct  0  0  0  0
0.4984591007232666
target  1  0  0  1
predct  0  0  0  0
0.49816420674324036
target  1  0  0  1
predct  0  0  0  0
0.4978582561016083
target  1  0  0  1
predct  0  0  0  0
0.4975682497024536
target  1  0  0  1
predct  0  0  0  0
0.4972694516181946
target  1  0  0  1
predct  0  0  0  0
0.49699223041534424
target  1  0  0  1
predct  0  0  0  0
0.49670177698135376
target  1  0  0  1
predct  0  0  0  0
0.49641716480255127
target  1  0  0  1
predct  0  0  0  0
0.49614137411117554
target  1  0  0  1
predct  0  0  0  0
0.4958652853965759
target  1  0  0  1
predct  0  0  0  0
0.4955948293209076
target  1  0  0  1
predct  0  0  0  0
0.49532991647720337
target  1  0  0  1
predct  0  0  0  0
0.4950646162033081
target  1  0  0  1
predct  0  0  0  0
0.4948062598705292
target  1  0  0  1
predct  0  0  0  0
0.49454039335250854
target  1  0  0  1
predct  0  0  0  0
0.4942845404148102
target  1  0  0  1
predct  0  0  0  0
0.49403390288352966
target  1  0  0  1
predct  0  0  0  0
0.4937843680381775
target  1  0  0  1
predct  0  0  0  0
0.4935406744480133
target  1  0  0  1
predct  0  0  0  0
0.49329304695129395
target  1  0  0  1
predct  0  0  0  0
0.4930446147918701
target  1  0  0  1
predct  0  0  0  0
0.49280276894569397
target  1  0  0  1
predct  0  0  0  0
0.49257752299308777
target  1  0  0  1
predct  0  0  0  0
0.4923416078090668
target  1  0  0  1
predct  0  0  0  0
0.4921063482761383
target  1  0  0  1
predct  0  0  0  0
0.49189022183418274
target  1  0  0  1
predct  0  0  0  0
0.49165549874305725
target  1  0  0  1
predct  0  0  0  0
0.4914386570453644
target  1  0  0  1
predct  0  0  0  0
0.49121975898742676
target  1  0  0  1
predct  0  0  0  0
0.4910048246383667
target  1  0  0  1
predct  0  0  0  0
0.4907832443714142
target  1  0  0  1
predct  0  0  0  0
0.4905613362789154
target  1  0  0  1
predct  0  0  0  0
0.49036550521850586
target  1  0  0  1
predct  0  0  0  0
0.49015337228775024
target  1  0  0  1
predct  0  0  0  0
0.48995381593704224
target  1  0  0  1
predct  0  0  0  0
0.48974236845970154
target  1  0  0  1
predct  0  0  0  0
0.48955023288726807
target  1  0  0  1
predct  0  0  0  0
0.489345520734787
target  1  0  0  1
predct  0  0  0  0
0.48915165662765503
target  1  0  0  1
predct  0  0  0  0
0.48895832896232605
target  1  0  0  1
predct  0  0  0  0
0.4887630343437195
target  1  0  0  1
predct  0  0  0  0
0.48857519030570984
target  1  0  0  1
predct  0  0  0  0
0.4883871078491211
target  1  0  0  1
predct  0  0  0  0
0.4882015883922577
target  1  0  0  1
predct  0  0  0  0
0.48802652955055237
target  1  0  0  1
predct  0  0  0  0
0.48783639073371887
target  1  0  0  1
predct  0  0  0  0
0.48766133189201355
target  1  0  0  1
predct  0  0  0  0
0.4874873757362366
target  1  0  0  1
predct  0  0  0  0
0.4873100221157074
target  1  0  0  1
predct  0  0  0  0
0.48713961243629456
target  1  0  0  1
predct  0  0  0  0
0.4869716167449951
target  1  0  0  1
predct  0  0  0  0
0.48679661750793457
target  1  0  0  1
predct  0  0  0  0
0.48663145303726196
target  1  0  0  1
predct  0  0  0  0
0.4864644706249237
target  1  0  0  1
predct  0  0  0  0
0.4862975478172302
target  1  0  0  1
predct  0  0  0  0
0.4861358106136322
target  1  0  0  1
predct  0  0  0  0
0.4859757423400879
target  1  0  0  1
predct  0  0  0  0
0.4858146607875824
target  1  0  0  1
predct  0  0  0  0
0.4856683611869812
target  1  0  0  1
predct  0  0  0  0
0.4855043590068817
target  1  0  0  1
predct  0  0  0  0
0.4853549897670746
target  1  0  0  1
predct  0  0  0  0
0.48521220684051514
target  1  0  0  1
predct  0  0  0  0
0.48505428433418274
target  1  0  0  1
predct  0  0  0  0
0.48490849137306213
target  1  0  0  1
predct  0  0  0  0
0.4847695827484131
target  1  0  0  1
predct  0  0  0  0
0.4846273362636566
target  1  0  0  1
predct  0  0  0  0
0.48447442054748535
target  1  0  0  1
predct  0  0  0  0
0.4843389391899109
target  1  0  0  1
predct  0  0  0  0
0.48419222235679626
target  1  0  0  1
predct  0  0  0  0
0.4840541183948517
target  1  0  0  1
predct  0  0  0  0
0.4839138090610504
target  1  0  0  1
predct  0  0  0  0
0.4837850332260132
target  1  0  0  1
predct  0  0  0  0
0.48364800214767456
target  1  0  0  1
predct  0  0  0  0
0.48351430892944336
target  1  0  0  1
predct  0  0  0  0
0.4833889305591583
target  1  0  0  1
predct  0  0  0  0
0.4832632541656494
target  1  0  0  1
predct  0  0  0  0
0.4831262230873108
target  1  0  0  1
predct  0  0  0  0
0.4829978942871094
target  1  0  0  1
predct  0  0  0  0
0.482873797416687
target  1  0  0  1
predct  0  0  0  0
0.48275026679039
target  1  0  0  1
predct  0  0  0  0
0.48263147473335266
target  1  0  0  1
predct  0  0  0  0
0.482506662607193
target  1  0  0  1
predct  0  0  0  0
0.48239046335220337
target  1  0  0  1
predct  0  0  0  0
0.48226967453956604
target  1  0  0  1
predct  0  0  0  0
0.4821521043777466
target  1  0  0  1
predct  0  0  0  0
0.4820314645767212
target  1  0  0  1
predct  0  0  0  0
0.48191606998443604
target  1  0  0  1
predct  0  0  0  0
0.48180779814720154
target  1  0  0  1
predct  0  0  0  0
0.4816933870315552
target  1  0  0  1
predct  0  0  0  0
0.48157915472984314
target  1  0  0  1
predct  0  0  0  0
0.48147621750831604
target  1  0  0  1
predct  0  0  0  0
0.48135995864868164
target  1  0  0  1
predct  0  0  0  0
0.48124760389328003
target  1  0  0  1
predct  0  0  0  0
0.4811495542526245
target  1  0  0  1
predct  0  0  0  0
0.481037974357605
target  1  0  0  1
predct  0  0  0  0
0.4809342324733734
target  1  0  0  1
predct  0  0  0  0
0.4808306396007538
target  1  0  0  1
predct  0  0  0  0
0.48072388768196106
target  1  0  0  1
predct  0  0  0  0
0.4806179702281952
target  1  0  0  1
predct  0  0  0  0
0.4805232584476471
target  1  0  0  1
predct  0  0  0  0
0.48042404651641846
target  1  0  0  1
predct  0  0  0  0
0.4803297817707062
target  1  0  0  1
predct  0  0  0  0
0.480228453874588
target  1  0  0  1
predct  0  0  0  0
0.48012712597846985
target  1  0  0  1
predct  0  0  0  0
0.4800313115119934
target  1  0  0  1
predct  0  0  0  0
0.4799353778362274
target  1  0  0  1
predct  0  0  0  0
0.4798438549041748
target  1  0  0  1
predct  0  0  0  0
0.47974681854248047
target  1  0  0  1
predct  0  0  0  0
0.4796607196331024
target  1  0  0  1
predct  0  0  0  0
0.47956815361976624
target  1  0  0  1
predct  0  0  0  0
0.4794769585132599
target  1  0  0  1
predct  0  0  0  0
0.47938820719718933
target  1  0  0  1
predct  0  0  0  0
0.47929292917251587
target  1  0  0  1
predct  0  0  0  0
0.4792112112045288
target  1  0  0  1
predct  0  0  0  0
0.4791245758533478
target  1  0  0  1
predct  0  0  0  0
0.47903499007225037
target  1  0  0  1
predct  0  0  0  0
0.4789494574069977
target  1  0  0  1
predct  0  0  0  0
0.47886666655540466
target  1  0  0  1
predct  0  0  0  0
0.47878414392471313
target  1  0  0  1
predct  0  0  0  0
0.4786991775035858
target  1  0  0  1
predct  0  0  0  0
0.4786171317100525
target  1  0  0  1
predct  0  0  0  0
0.47853395342826843
target  1  0  0  1
predct  0  0  0  0
0.4784524738788605
target  1  0  0  1
predct  0  0  0  0
0.47837379574775696
target  1  0  0  1
predct  0  0  0  0
0.4782923758029938
target  1  0  0  1
predct  0  0  0  0
0.4782158136367798
target  1  0  0  1
predct  0  0  0  0
0.47813358902931213
target  1  0  0  1
predct  0  0  0  0
0.47806212306022644
target  1  0  0  1
predct  0  0  0  0
0.4779840111732483
target  1  0  0  1
predct  0  0  0  0
0.4779060482978821
target  1  0  0  1
predct  0  0  0  0
0.47783413529396057
target  1  0  0  1
predct  0  0  0  0
0.47775301337242126
target  1  0  0  1
predct  0  0  0  0
0.4776827096939087
target  1  0  0  1
predct  0  0  0  0
0.4776135981082916
target  1  0  0  1
predct  0  0  0  0
0.477541446685791
target  1  0  0  1
predct  0  0  0  0
0.4774642288684845
target  1  0  0  1
predct  0  0  0  0
0.47739312052726746
target  1  0  0  1
predct  0  0  0  0
0.4773232936859131
target  1  0  0  1
predct  0  0  0  0
0.47725412249565125
target  1  0  0  1
predct  0  0  0  0
0.47718265652656555
target  1  0  0  1
predct  0  0  0  0
0.4771106541156769
target  1  0  0  1
predct  0  0  0  0
0.47704818844795227
target  1  0  0  1
predct  0  0  0  0
0.4769759178161621
target  1  0  0  1
predct  0  0  0  0
0.47690895199775696
target  1  0  0  1
predct  0  0  0  0
0.4768458902835846
target  1  0  0  1
predct  0  0  0  0
0.4767790734767914
target  1  0  0  1
predct  0  0  0  0
0.47670602798461914
target  1  0  0  1
predct  0  0  0  0
0.47664564847946167
target  1  0  0  1
predct  0  0  0  0
0.47658368945121765
target  1  0  0  1
predct  0  0  0  0
0.4765118658542633
target  1  0  0  1
predct  0  0  0  0
0.4764525294303894
target  1  0  0  1
predct  0  0  0  0
0.47639000415802
target  1  0  0  1
predct  0  0  0  0
0.47633469104766846
target  1  0  0  1
predct  0  0  0  0
0.47626447677612305
target  1  0  0  1
predct  0  0  0  0
0.47619569301605225
target  1  0  0  1
predct  0  0  0  0
0.47614338994026184
target  1  0  0  1
predct  0  0  0  0
0.4760791063308716
target  1  0  0  1
predct  0  0  0  0
0.4760131239891052
target  1  0  0  1
predct  0  0  0  0
0.4759587347507477
target  1  0  0  1
predct  0  0  0  0
0.4758968651294708
target  1  0  0  1
predct  0  0  0  0
0.47584041953086853
target  1  0  0  1
predct  0  0  0  0
0.47578147053718567
target  1  0  0  1
predct  0  0  0  0
0.4757228195667267
target  1  0  0  1
predct  0  0  0  0
0.4756667912006378
target  1  0  0  1
predct  0  0  0  0
0.4756101071834564
target  1  0  0  1
predct  0  0  0  0
0.4755575656890869
target  1  0  0  1
predct  0  0  0  0
0.47549840807914734
target  1  0  0  1
predct  0  0  0  0
0.4754444360733032
target  1  0  0  1
predct  0  0  0  0
0.47538483142852783
target  1  0  0  1
predct  0  0  0  0
0.4753339886665344
target  1  0  0  1
predct  0  0  0  0
0.4752730131149292
target  1  0  0  1
predct  0  0  0  0
0.47522231936454773
target  1  0  0  1
predct  0  0  0  0
0.4751650393009186
target  1  0  0  1
predct  0  0  0  0
0.4751185178756714
target  1  0  0  1
predct  0  0  0  0
0.47505828738212585
target  1  0  0  1
predct  0  0  0  0
0.4750101864337921
target  1  0  0  1
predct  0  0  0  0
0.47495216131210327
target  1  0  0  1
predct  0  0  0  0
0.47489818930625916
target  1  0  0  1
predct  0  0  0  0
0.474847674369812
target  1  0  0  1
predct  0  0  0  0
0.47480419278144836
target  1  0  0  1
predct  0  0  0  0
0.47475260496139526
target  1  0  0  1
predct  0  0  0  0
0.47469890117645264
target  1  0  0  1
predct  0  0  0  0
0.47464293241500854
target  1  0  0  1
predct  0  0  0  0
0.4745933711528778
target  1  0  0  1
predct  0  0  0  0
0.474545955657959
target  1  0  0  1
predct  0  0  0  0
0.47449618577957153
target  1  0  0  1
predct  0  0  0  0
0.47445112466812134
target  1  0  0  1
predct  0  0  0  0
0.47439447045326233
target  1  0  0  1
predct  0  0  0  0
0.47434571385383606
target  1  0  0  1
predct  0  0  0  0
0.47430285811424255
target  1  0  0  1
predct  0  0  0  0
0.47425493597984314
target  1  0  0  1
predct  0  0  0  0
0.4742031991481781
target  1  0  0  1
predct  0  0  0  0
0.4741591513156891
target  1  0  0  1
predct  0  0  0  0
0.4741100072860718
target  1  0  0  1
predct  0  0  0  0
0.4740647077560425
target  1  0  0  1
predct  0  0  0  0
0.4740241765975952
target  1  0  0  1
predct  0  0  0  0
0.4739645719528198
target  1  0  0  1
predct  0  0  0  0
0.47392866015434265
target  1  0  0  1
predct  0  0  0  0
0.4738769233226776
target  1  0  0  1
predct  0  0  0  0
0.47383734583854675
target  1  0  0  1
predct  0  0  0  0
0.47378942370414734
target  1  0  0  1
predct  0  0  0  0
0.4737418293952942
target  1  0  0  1
predct  0  0  0  0
0.47369328141212463
target  1  0  0  1
predct  0  0  0  0
0.4736546277999878
target  1  0  0  1
predct  0  0  0  0
0.4736134111881256
target  1  0  0  1
predct  0  0  0  0
0.4735722839832306
target  1  0  0  1
predct  0  0  0  0
0.47352778911590576
target  1  0  0  1
predct  0  0  0  0
0.4734835624694824
target  1  0  0  1
predct  0  0  0  0
0.47343602776527405
target  1  0  0  1
predct  0  0  0  0
0.47339510917663574
target  1  0  0  1
predct  0  0  0  0
0.4733451306819916
target  1  0  0  1
predct  0  0  0  0
0.47330760955810547
target  1  0  0  1
predct  0  0  0  0
0.47326594591140747
target  1  0  0  1
predct  0  0  0  0
0.4732208251953125
target  1  0  0  1
predct  0  0  0  0
0.473182737827301
target  1  0  0  1
predct  0  0  0  0
0.4731367528438568
target  1  0  0  1
predct  0  0  0  0
0.4730971157550812
target  1  0  0  1
predct  0  0  0  0
0.4730556607246399
target  1  0  0  1
predct  0  0  0  0
0.4730161428451538
target  1  0  0  1
predct  0  0  0  0
0.47297734022140503
target  1  0  0  1
predct  0  0  0  0
0.4729330241680145
target  1  0  0  1
predct  0  0  0  0
0.47289586067199707
target  1  0  0  1
predct  0  0  0  0
0.4728546738624573
target  1  0  0  1
predct  0  0  0  0
0.4728182852268219
target  1  0  0  1
predct  0  0  0  0
0.4727668762207031
target  1  0  0  1
predct  0  0  0  0
0.4727354943752289
target  1  0  0  1
predct  0  0  0  0
0.47269299626350403
target  1  0  0  1
predct  0  0  0  0
0.47265028953552246
target  1  0  0  1
predct  0  0  0  0
0.4726113975048065
target  1  0  0  1
predct  0  0  0  0
0.4725760519504547
target  1  0  0  1
predct  0  0  0  0
0.4725426137447357
target  1  0  0  1
predct  0  0  0  0
0.4725054204463959
target  1  0  0  1
predct  0  0  0  0
0.4724593758583069
target  1  0  0  1
predct  0  0  0  0
0.47242286801338196
target  1  0  0  1
predct  0  0  0  0
0.472385436296463
target  1  0  0  1
predct  0  0  0  0
0.4723495841026306
target  1  0  0  1
predct  0  0  0  0
0.47231051325798035
target  1  0  0  1
predct  0  0  0  0
0.4722701609134674
target  1  0  0  1
predct  0  0  0  0
0.4722347855567932
target  1  0  0  1
predct  0  0  0  0
0.472199022769928
target  1  0  0  1
predct  0  0  0  0
0.47216013073921204
target  1  0  0  1
predct  0  0  0  0
0.4721272885799408
target  1  0  0  1
predct  0  0  0  0
0.472087025642395
target  1  0  0  1
predct  0  0  0  0
0.4720458686351776
target  1  0  0  1
predct  0  0  0  0
0.4720190465450287
target  1  0  0  1
predct  0  0  0  0
0.4719778001308441
target  1  0  0  1
predct  0  0  0  0
0.47194600105285645
target  1  0  0  1
predct  0  0  0  0
0.4719027578830719
target  1  0  0  1
predct  0  0  0  0
0.4718661308288574
target  1  0  0  1
predct  0  0  0  0
0.47183364629745483
target  1  0  0  1
predct  0  0  0  0
0.4718021750450134
target  1  0  0  1
predct  0  0  0  0
0.47176775336265564
target  1  0  0  1
predct  0  0  0  0
0.4717245101928711
target  1  0  0  1
predct  0  0  0  0
0.4716980457305908
target  1  0  0  1
predct  0  0  0  0
0.4716557264328003
target  1  0  0  1
predct  0  0  0  0
0.47162339091300964
target  1  0  0  1
predct  0  0  0  0
0.471585214138031
target  1  0  0  1
predct  0  0  0  0
0.471550852060318
target  1  0  0  1
predct  0  0  0  0
0.4715191423892975
target  1  0  0  1
predct  0  0  0  0
0.47148099541664124

And this is the dataset I’m using: mushrooms.csv.

Also, this is how I’m converting it into categorical data:

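import pandas as pd

dataset = pd.read_csv("mushrooms.csv")

# Drop columns 4..22, keeping only the first four columns
# ('class', 'cap-shape', 'cap-surface', 'cap-color')
cols = list(range(4, 23))
dataset.drop(dataset.columns[cols], axis=1, inplace=True)

# Encode every remaining column as integer category codes
cdf = pd.DataFrame()
for i in list(dataset.columns):
    dataset[i] = pd.Categorical(dataset[i])
    cdf[i] = dataset[i].cat.codes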
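To see what .cat.codes produces, here is a small sketch on a hypothetical toy column (not the real mushroom data). Pandas infers the categories in sorted order and replaces each value with its integer position, so the resulting features are unscaled integers from 0 to K-1 rather than values centered around zero.

```python
import pandas as pd

# Hypothetical toy column of single-letter codes, similar in spirit to the mushroom features
col = pd.Categorical(["x", "b", "x", "f"])

print(col.categories.tolist())  # ['b', 'f', 'x'] (inferred categories, sorted)
print(col.codes.tolist())       # [2, 0, 2, 1]   (each value replaced by its category index)
```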

Thank you so much for your help, sir.

I looked at the code and think the dataset might be the issue: the features are unnormalized and the classes are imbalanced.

There are several ways to deal with imbalanced datasets.
A good option is torch.utils.data.WeightedRandomSampler.
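A minimal sketch of the WeightedRandomSampler approach, assuming hypothetical toy data (900 samples of class 0, 100 of class 1) instead of the mushroom tensors. Each sample gets a weight equal to the inverse frequency of its class, so both classes are drawn roughly equally often when sampling with replacement.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)

# Hypothetical imbalanced toy data: 900 samples of class 0, 100 of class 1
X = torch.randn(1000, 3)
y = torch.cat([torch.zeros(900, dtype=torch.long), torch.ones(100, dtype=torch.long)])

# Weight each sample by the inverse frequency of its class
class_counts = torch.bincount(y)               # tensor([900, 100])
sample_weights = 1.0 / class_counts[y].float()

sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)
loader = DataLoader(TensorDataset(X, y), batch_size=64, sampler=sampler)

# Collect the labels drawn during one pass over the loader
drawn = torch.cat([yb for _, yb in loader])
print(torch.bincount(drawn, minlength=2))      # roughly equal counts for both classes
```

Note that the sampler replaces shuffle=True in the DataLoader; the two options are mutually exclusive.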

Alternatively, since you are slicing the first 4000 samples, let’s just take 2000 samples of each class.

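import numpy as np
import torch

# Positional indices of the first 2000 samples of each class, shuffled together
y_one_idx = np.where(np.array(cdf['class'] == 1))[0][:2000]
y_zero_idx = np.where(np.array(cdf['class'] == 0))[0][:2000]
y_idx = np.hstack((y_one_idx, y_zero_idx))
np.random.shuffle(y_idx)

# .iloc selects by position, matching the indices returned by np.where;
# all features use float32 (np.float is removed in recent NumPy versions)
Y = torch.from_numpy(np.array(cdf['class'].iloc[y_idx], dtype=np.int64))
X = torch.from_numpy(np.column_stack((np.array(cdf['cap-shape'].iloc[y_idx], dtype=np.float32),
                                      np.array(cdf['cap-surface'].iloc[y_idx], dtype=np.float32),
                                      np.array(cdf['cap-color'].iloc[y_idx], dtype=np.float32))))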

We should also normalize the data, because the raw category codes are unscaled integers, which can cause problems for your model. Add these lines after loading the data:

# Standardize each feature (column) to zero mean and unit variance
X = X - torch.mean(X, dim=0)
X = X / torch.std(X, dim=0)
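A small self-contained sketch of the same per-column standardization on a hypothetical 4x3 matrix of category codes. The clamp_min guard is an addition not present in the lines above; it avoids dividing by zero if a column happens to be constant.

```python
import torch

# Hypothetical category codes for three features (integers, as produced by .cat.codes)
X = torch.tensor([[5., 2., 4.],
                  [5., 3., 9.],
                  [0., 2., 8.],
                  [2., 0., 4.]])

mean = X.mean(dim=0)
std = X.std(dim=0).clamp_min(1e-8)  # guard against a constant column (std == 0)
X_norm = (X - mean) / std

print(X_norm.mean(dim=0))  # approximately 0 for every column
print(X_norm.std(dim=0))   # approximately 1 for every column
```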

After some epochs, the model should start to learn. Printing the number of 1s and 0s in the predictions shows that the output is balanced as well.
However, the accuracy is still not great, so there is some work to do on the architecture. :wink:
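One way to check both the prediction balance and the accuracy, sketched on hypothetical log-probabilities for six samples (standing in for the output of the Net above). Since log_softmax is monotonic, argmax over its output gives the same predicted class as output.max(1) in the original loop.

```python
import torch
import torch.nn.functional as F

# Hypothetical logits for 6 samples and 2 classes, turned into log-probabilities
output = F.log_softmax(torch.tensor([[2.0, 0.5], [0.1, 1.0], [1.5, -0.5],
                                     [-1.0, 0.3], [0.2, 0.1], [0.0, 2.0]]), dim=1)
y = torch.tensor([0, 1, 0, 0, 1, 1])

y_pred = output.argmax(dim=1)                   # tensor([0, 1, 0, 1, 0, 1])
counts = torch.bincount(y_pred, minlength=2)    # number of predicted 0s and 1s
accuracy = (y_pred == y).float().mean().item()  # fraction of correct predictions

print(counts.tolist(), accuracy)  # [3, 3] and about 0.667
```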


Hi ptrblck, sir,
It’s me again. Can you please explain why the network works so much better after normalizing the data?
And could you tell me where I can learn more about this technique of “normalizing the data”? Also, why is an imbalanced dataset a problem? Shouldn’t the model learn the mapping regardless of the order of the data?
Sorry, I’m a newbie; I would love your help.
Thank you.

When the input values are roughly in the range (-1, 1), the gradients are well scaled and training usually goes much faster.

In theory, almost any model should be able to adjust to the data and make decent predictions, but in practice it often doesn’t work that well. The building blocks of neural nets are designed to work best for inputs roughly in the range (-1, 1). Most building blocks also work best when the weights are initialised carefully. Now if your inputs are in the range (80, 120), your model first has to learn to adjust to the scale of the data, and once it has finished adjusting there is no guarantee that the weights are still in a state conducive to further learning.

More info on data normalisation
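One concrete illustration of why large inputs hurt (a sketch, not necessarily the whole story): the Net in this thread uses Tanh, whose derivative is 1 - tanh(x)^2. For large pre-activations tanh saturates and the gradient flowing back through it nearly vanishes. The input values below are hypothetical.

```python
import torch

# Hypothetical pre-activation values: small, moderate and very large
x = torch.tensor([0.5, 5.0, 100.0], requires_grad=True)
torch.tanh(x).sum().backward()

# d/dx tanh(x) = 1 - tanh(x)^2
print(x.grad)  # about 0.79 at 0.5, about 1.8e-4 at 5.0, and 0 at 100.0 (tanh saturates at 1)
```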


@jpeg729 answered the first question in detail.

As for your second question:
Imbalanced data can be a problem because your model might get too few “signals” from the rare samples and thus overfit to the majority class.
Measuring accuracy can also be misleading, since it can be quite high just from predicting the majority class (see the accuracy paradox).
In the end, your model might reach an accuracy of 99% just by predicting one class.
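A quick sketch of the accuracy paradox with hypothetical labels (990 negatives, 10 positives) and a "model" that always predicts the majority class. Per-class recall and their average (often called balanced accuracy) expose what plain accuracy hides.

```python
import torch

# Hypothetical, heavily imbalanced labels: 990 of class 0, 10 of class 1
y = torch.cat([torch.zeros(990, dtype=torch.long), torch.ones(10, dtype=torch.long)])
y_pred = torch.zeros_like(y)  # always predict the majority class

accuracy = (y_pred == y).float().mean().item()

# Recall per class: fraction of samples of class c that were predicted as c
recall = [((y_pred == c) & (y == c)).sum().item() / (y == c).sum().item() for c in (0, 1)]
balanced_accuracy = sum(recall) / 2

print(accuracy, recall, balanced_accuracy)  # about 0.99, [1.0, 0.0], 0.5
```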


Thanks a lot for the discussion. It helped a lot in solving the issue as a beginner.