Neural network not learning anything (output from each layer mostly zeros) - yes, the loss fn and optimizer are correctly written!

I am using a custom dataset with the following model:

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(35, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, 6)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return x

The training function is:

self.criterion = nn.CrossEntropyLoss()
self.model = torch.nn.DataParallel(self.model, device_ids=range(torch.cuda.device_count()))
    def train(self):
        valid_running_loss = 0.0
        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda:0" if use_cuda else "cpu")
        torch.backends.cudnn.benchmark = True
        self.model.to(device)
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.000_000_000_001)
        criterion = nn.CrossEntropyLoss()
        running_loss = 0
        loss_values = []
        for epoch in range(self.epochs):
            self.model.train()
            for batch_idx, (target, dat) in enumerate(self.train_loader):
                target, data =  Variable(
                    target.cuda()), Variable(dat.cuda())
                optimizer.zero_grad()
                output = self.model(dat)
                loss = criterion(output, target.flatten().to(device).long())
                loss.backward()
                optimizer.step()

                loss_values.append(running_loss/20)
                running_loss += loss.item()
                if batch_idx % 20 == 19:
                    print('Training [%d, %5d] loss: %.3f' %
                          (epoch + 1, batch_idx + 1, running_loss / 20))
                    running_loss = 0.0
                    torch.save(self.model.state_dict(), 'model.pt')
plt.plot(loss_values)
plt.xlabel("Batches")
plt.ylabel("Loss")
plt.show()

Output loss graph is as follows:
[loss plot image - the loss stays roughly flat and never decreases]
And the output from each layer mostly follows the same pattern:


Layer #1 

tensor([[34.0111, 12.7092, 16.9817, 14.5254,  0.0000,  2.9979, 34.6398, 28.0957,
         24.6492,  0.0000,  0.0000, 20.6735, 55.1613,  0.0000, 43.1130, 38.0405,
          8.2961,  0.0000,  0.0000,  0.0000, 24.0659, 13.8045,  4.9426,  0.1992,
         34.9194,  0.0000,  0.0000, 19.0288,  0.0000, 22.8025,  0.0000,  0.0000,
          0.7127,  0.0000, 48.1853, 33.9753,  0.0000,  0.0000, 13.2362,  0.0000,
         44.4921,  9.9233, 38.6005, 35.1910, 19.6216, 37.5149,  1.4502,  0.0000,
         12.4940, 72.6231,  0.0000,  0.0000, 27.0568,  0.0000, 25.1401,  0.0000,
         11.9196, 19.1825,  0.0000, 21.7884,  0.0000, 27.3251, 12.9470, 16.1076,
         20.6509,  0.0000, 19.3813, 13.8918, 14.7036,  0.0000, 43.9978,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         35.3095,  0.0000,  0.0000, 12.3495,  0.0000, 17.0042,  0.0000, 50.3415,
          0.7187,  0.0000,  0.0000,  0.0000, 19.5300,  0.0000, 12.6171,  0.0000,
          0.0000, 40.4328, 63.1309, 25.9717,  9.9090,  0.0000,  0.0000,  7.4069,
          0.0000,  0.0000,  0.0000, 48.7994, 22.2934,  8.2384,  0.0000,  0.0000,
         16.6730,  0.0000, 21.5223, 36.6605, 37.9476,  0.0000,  0.0000, 30.9383,
          0.0000,  5.8143, 30.3140, 31.1911, 45.5149,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000, 57.5292, 21.0106,  0.0000,  0.0000, 11.2866,
          0.0000,  8.8533,  0.0000,  0.0000, 16.1030,  0.0000, 31.2619,  5.7975,
         12.3731, 14.3904,  0.0000,  0.0000,  0.0000,  0.0000,  1.4914,  0.0000,
          0.0000, 16.2442,  0.0000, 39.2010, 43.2472,  0.0000,  0.0000,  0.0000,
          0.0000,  8.3958,  0.0000, 13.6056,  0.0000,  0.0000, 86.4618, 31.2490,
          0.0000,  0.0000,  2.6972,  0.0000,  0.0000, 26.5139,  0.0000, 23.3579,
          0.0000,  0.0000,  0.0000, 10.0080,  0.0000,  0.0000,  0.0000, 15.1532,
          3.9325, 35.7198,  0.0000,  0.0000,  0.0000, 21.8514,  0.8783,  0.0000,
          0.0000, 11.6154,  0.0000, 32.9982,  4.7520, 28.7346,  0.0000,  0.0000,
         31.4094,  0.0000,  3.6026, 32.6338,  0.3227,  0.0000,  0.3136,  0.0000,
          0.0000,  9.4382,  0.0000,  0.0000, 17.7246,  0.0000, 23.2691,  0.0000,
         27.7171, 14.8556, 58.0410,  0.0000,  7.1684,  4.9152,  0.0000, 35.3398,
         26.2738,  0.0000, 25.8247,  0.0000,  0.0000,  0.0000,  0.0000, 25.2728,
         35.7325,  7.9791, 42.1267, 38.2015, 13.0649,  7.1808, 16.6197,  0.0000,
         25.6002,  0.2276,  0.0000, 28.3883, 11.9394, 41.1464, 11.5944,  0.0000,
          0.0000,  0.0000, 23.2719, 16.6102, 38.2222, 32.7788, 15.7401, 58.2293,
          1.6106,  0.0000,  0.0000, 35.0814,  0.0000,  0.0000,  3.4356,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000, 45.6004, 14.8991,  6.0531, 51.6026,
         16.8593,  0.0000,  0.0000,  0.0000, 55.9559,  0.0000,  0.0000, 36.4741,
         21.1376,  6.3189, 19.6905,  0.0000,  4.5537, 18.8644, 37.8007,  2.9587,
          0.0000,  0.0000, 66.9925,  2.1472,  0.0000, 49.2316,  0.0000, 41.7871,
         44.0987,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 36.4081,  0.0000,
          0.0000, 40.1042,  0.0000,  0.0000,  0.0000,  6.7286, 18.3722, 27.6182,
          0.0000, 67.7467,  9.7763, 18.0995,  8.0758,  7.9573,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  7.1732,  0.0000, 10.2361,  6.5490, 27.3207,
         95.2971, 17.5390, 43.5235,  0.0000,  0.0000,  0.0000, 27.7248,  0.0000,
         11.7532,  0.0000, 24.5198, 62.1982,  4.9184,  0.0000,  0.0000, 12.8589,
          0.0000,  0.0000,  0.0000,  3.4953, 50.0316, 22.7615,  0.0000,  0.7946,
          5.9959, 27.8512,  0.0000, 17.3078, 11.5306,  0.0000, 10.6378, 10.7233,
          0.0000,  4.2215, 20.2768,  0.0000,  0.0000,  0.0000,  0.0000, 30.5890,
         22.3280,  0.0000, 41.7865,  9.4994,  0.0000,  0.0000,  0.0000, 21.8536,
          0.0000, 12.1628, 26.2739,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000, 15.3123, 32.7893,  0.0000,  0.0000, 13.0238,  0.0000,
          0.0000, 73.2508,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         17.9185,  0.0000,  0.0000,  0.0000,  0.0000, 21.6723,  0.0000, 56.8059,
         21.3461,  0.0000,  3.5318, 42.6378,  0.0000,  0.0000,  8.7609,  9.1071,
         19.7198, 29.4656, 12.3245,  0.0000,  0.0000,  6.2201,  0.0000,  0.0000,
          0.0000, 17.2576,  0.0000,  5.4993,  0.0000, 22.0809, 42.4508,  7.6554,
          0.0000, 14.8032,  8.5307, 17.6682,  0.0000,  4.4538,  3.2548, 31.2332,
          0.0000,  0.0000,  0.0000, 53.7162, 17.6550, 13.2346,  0.0000, 10.2985,
         11.6230,  0.0000, 45.5657,  7.3497,  0.0000,  1.8219,  0.0000, 70.1144,
          0.0000,  0.0000, 35.7366, 30.5250,  0.0000, 51.1208, 26.5028, 16.8218,
         28.1218,  0.0000, 36.6832, 12.0090,  0.7716,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  8.5589,  0.0000, 33.9515, 27.2925,  0.7363, 64.0930,
         50.3558, 54.7302, 62.5942,  0.0000,  0.0000,  1.9293,  0.0000, 25.3646,
          4.8337,  0.0000, 20.5091,  0.0000, 56.1850,  0.9221, 19.4342,  0.0000,
         66.2355,  0.0000,  0.0000,  0.0000,  0.0000,  7.8311,  0.0000,  0.0000,
         70.5462,  0.0000,  0.0000,  0.0000, 15.6784, 31.4369,  0.0000, 14.5324]],
       device='cuda:0', grad_fn=<ReluBackward0>)

Layer #2 

tensor([[17.1249,  0.0000,  0.0000,  9.0818,  6.6402,  0.0000, 10.8230,  0.0000,
          0.0000,  0.0000, 19.1912,  0.0000,  5.1139, 22.4080,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000, 11.7231,  8.7982, 13.2230,  6.9078,  0.0000,
          0.0000,  0.0000,  3.2706,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000, 15.7249,  7.4991,  0.0000,  0.0000,  9.1082,  7.6445,  9.3315,
          1.8748,  0.0000,  7.9360,  0.0000,  0.0000,  0.0000,  0.0000, 17.4604,
         11.8764,  6.1336,  0.0000,  9.5193,  0.0000,  0.0000,  0.0000,  0.0000,
         22.4725,  0.0000,  0.0000,  0.0000, 17.5471,  0.0000,  8.5708, 11.9288,
          7.7524,  0.0000,  0.0000,  4.8687,  6.4660, 13.4811,  6.5080,  5.4127,
          0.0000,  0.0000,  0.0000,  0.0000, 19.6467,  0.0000,  0.0000, 26.5991,
          0.0000,  0.0000,  0.0000,  1.4286, 22.5212,  0.0000,  2.9779, 12.6172,
         17.1694,  0.0000,  0.0000,  0.0000,  0.0000,  3.8671,  8.2908,  0.0000,
          0.0000,  5.3878,  0.0000,  0.0000,  2.6435,  0.0000,  0.0000,  0.0000,
         31.1878, 16.2891,  3.2600,  0.6124,  0.0000, 15.0056,  0.0000,  0.0000,
          0.0000, 19.2099,  4.1495,  4.1315,  0.0514, 13.5338,  1.0504,  0.0000,
          0.0000,  0.0000,  0.0000, 11.8101,  0.0000,  0.0000,  0.0000, 13.4565,
         17.4728,  0.0000, 10.1245, 11.5265, 18.2403,  2.5224, 25.3509,  0.0000,
         11.8381,  0.0000,  0.0000,  2.9400,  0.0000, 23.7288,  5.2541,  0.0000,
          0.0000, 12.0983, 12.3099, 15.6219,  0.0000,  4.6333,  2.1624,  2.1363,
          2.8176,  3.7855,  0.0000, 10.6023, 28.2926,  7.8620,  0.0000,  0.0000,
          0.0000, 15.7391,  0.0000, 10.9450,  0.0000, 11.1348, 16.8085, 20.6935,
          0.0000,  0.0000,  0.0000, 11.9673,  0.0000, 10.4149,  0.0000,  0.0000,
          0.0000,  0.0000,  0.5347,  0.0000, 13.8853,  0.0000,  0.0000, 23.8881,
          5.6834,  0.0000, 14.2632, 15.1108,  0.0000,  0.0000,  0.0000, 12.5634,
         12.2963,  0.9804,  0.0000, 18.2825, 10.9668,  8.7040,  0.1205,  8.8042,
         11.9092,  5.7311, 12.1467,  0.0000,  0.0000,  9.8295,  5.2199,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  4.1317,  6.3019, 27.4451,  0.0000,
         19.1557,  0.0000, 11.3723, 13.3361,  6.6892,  0.0000,  0.0000,  1.9528,
          0.0000,  4.0795,  0.0000,  0.0000,  0.0000,  3.5099,  0.0000,  0.0000,
          0.0000, 16.9386,  0.0000,  0.0000,  7.6447,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  4.5518,  0.0000,  1.5412,
          0.0000,  0.0000,  0.0000,  1.8520,  0.0000,  0.0000,  0.0000,  1.8567,
          0.0000, 12.5432, 17.2627, 12.1782,  0.0000,  0.0000,  6.6105,  2.9548,
          5.1203,  0.0000,  0.6381,  8.4258,  0.0000,  0.0000, 16.4389, 19.8055,
          3.1554,  0.0000,  0.0000,  0.0000, 13.3941,  0.0000,  4.2483,  1.6484,
          0.0000,  0.0000,  0.0000,  1.0757,  0.0000, 12.4581, 16.7086, 14.6670,
         11.1585,  6.4158,  0.0000, 16.0432,  1.8949, 12.8711,  0.0000,  0.0000,
          0.0000,  1.9205,  0.0000, 16.2584,  8.1967,  4.1390,  5.6682,  0.0000,
          1.7621,  0.0000,  0.0000,  8.2641,  0.0000,  0.0000,  0.0000,  5.4671,
          9.1355,  0.0000, 23.0471,  0.0000, 14.0210,  0.0000,  3.1476,  0.0000,
          0.0000,  0.0000,  0.0000,  7.4607, 23.0116, 15.9541,  0.0000,  0.0000,
          2.8278,  8.5759, 10.3721,  0.0000, 17.0605, 31.8535,  6.9635,  0.0000,
          8.3336, 10.9779,  2.8000,  0.0000,  0.0000,  0.0000,  0.0000, 26.5629,
          0.0000,  2.4045,  0.0000,  0.0000,  0.5652,  0.0000, 13.7610,  9.7107,
          5.7010,  0.0000,  0.0000, 11.1692,  0.0000,  0.0000,  4.8460, 13.3004,
          0.0542,  4.6617,  0.4143,  0.0000,  0.0000, 15.8988,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000, 14.1985,  0.0000,  0.0000, 12.7816, 21.5568,
          4.8282,  3.6445,  6.9795,  7.4458,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000, 12.9888, 16.1650, 10.7322, 12.3018,  0.0000,  0.0000,  7.6508,
          0.0000, 11.0397,  0.0000,  0.0000,  9.9304,  0.0000, 14.4211, 13.6603,
          7.4515,  0.0000,  0.0000,  8.8561,  0.0000, 10.5816,  0.0000, 13.6793,
          8.5912, 19.0544, 29.1780,  0.0000,  0.0000, 10.9369,  0.0000,  0.0000,
          6.0852,  0.3852,  1.5841,  0.0000,  0.0000, 12.2816, 13.9922,  9.5033,
          0.0000,  0.0000,  1.7762, 21.8631, 13.5751,  3.1538,  5.9521,  9.8035,
          2.2409,  0.0000,  7.9797,  0.0000, 11.7742,  0.0000,  0.0000,  0.0000,
         14.5848, 17.5528,  8.7062,  0.0000,  0.0000,  0.0000,  0.0000,  2.5330,
          0.0000,  3.4727,  1.3243,  1.6350,  0.0000,  2.3310, 16.4998,  9.0668,
          0.0000,  2.4043,  0.0000, 13.6607, 16.5940,  0.0000,  7.8046,  0.0000,
         20.6234, 10.3480,  0.0000, 16.7120,  0.0000, 11.0978,  4.4286,  0.0000,
          4.2564,  5.7708,  0.0000,  0.0000,  0.0000, 12.9946, 24.6138,  5.6800,
         34.5168,  1.7719, 17.0013, 13.6552,  0.0000,  0.0000,  5.6398, 10.8171,
          8.4392,  0.0000,  7.2894,  3.4434,  2.5727, 16.9701, 16.1560,  0.0000,
          0.0000, 13.8336,  0.0000,  0.0000,  0.0000, 18.6943,  0.1918,  1.5877,
          0.0000,  6.1855,  5.8500,  0.8094, 10.7234,  0.0000,  4.7110,  4.7390]],
       device='cuda:0', grad_fn=<ReluBackward0>)

Layer #3 

tensor([[1.6433, 1.1880, 1.4009, 0.0000, 4.4047, 4.4773]], device='cuda:0',
       grad_fn=<ReluBackward0>)

There is very little difference between the outputs of Layer #3 from successive batches.

I have also played around with the learning rate and the number of units in the FC layers. The loss magnitude usually increases, with the shape of the curve staying mostly the same (but never decreasing :().

Let me know if any other information is needed… Any help is appreciated :)!

Could you remove the last F.relu and rerun the training?
I don’t think it’s a good idea to limit the logits to [0, +Inf] for multi-class classification.
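I.e. something like this (a minimal sketch of the model you posted, with only the last activation removed - nn.CrossEntropyLoss expects raw, unbounded logits and applies log-softmax internally):

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(35, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, 6)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # no ReLU on the output layer: return the raw logits
        return self.fc3(x)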

@ptrblck, Thanks for the reply. I did try that - but no luck.

The loss still remains high. Other things I have tried: changing the architecture, changing the learning rate, and changing the batch size / number of epochs - but still the same result.

I am absolutely dumbstruck as to what to try next. To be clear though, there is something wrong in the architecture, right?

Isn’t it just the loss plotting?
You start with 0, append running_loss/20 before adding the new loss, and only reset running_loss every 20 batches - so you’re effectively plotting a cumulative sum of the losses divided by 20, not an average (it’s badly explained, but I don’t know how to put it better). Maybe store your losses in a list until 20 batches are reached and plot the mean of those, or write

running_loss += loss.item() / 20

and only append it to your loss_values every 20 batches:

loss_values.append(running_loss)
running_loss  = 0
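
Put together, the bookkeeping would look roughly like this (just a sketch based on your posted training loop, with only the loss logging changed):

for batch_idx, (target, dat) in enumerate(self.train_loader):
    optimizer.zero_grad()
    output = self.model(dat)
    loss = criterion(output, target.flatten().to(device).long())
    loss.backward()
    optimizer.step()

    # accumulate the mean loss over the current window of 20 batches
    running_loss += loss.item() / 20
    if batch_idx % 20 == 19:
        loss_values.append(running_loss)  # one point per 20-batch window
        print('Training [%d, %5d] loss: %.3f' %
              (epoch + 1, batch_idx + 1, running_loss))
        running_loss = 0.0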

@Caruso 20 is the batch size, sorry for the poor code I wrote there! Even if I do it the way you suggest, the values don’t change much…

I might be wrong here, but I remember running into a similar problem before. I see you are wrapping your data in Variable. Could you try removing that and passing the data to the model as plain tensors, and see if anything changes?
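
For example (a rough sketch of just the data-handling part of your loop; Variable has been deprecated since PyTorch 0.4, plain tensors carry the autograd information themselves):

for batch_idx, (target, dat) in enumerate(self.train_loader):
    # move the plain tensors to the GPU, no Variable wrapper needed
    data = dat.to(device)
    target = target.to(device)

    optimizer.zero_grad()
    output = self.model(data)
    loss = criterion(output, target.flatten().long())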

@a_d, sadly this didn’t work either… Thanks for the reply though!

Just fixed it - all I had to do was increase the number of epochs to an absurdly large number!
That’s deep learning for you, haha. Thanks everyone :)!