I am building a model on mixed data (continuous and categorical variables), using embeddings for the categorical variables followed by two linear layers. It's a binary classification problem, and my training loss isn't changing. It would be great if someone could help.
I assume you are not shuffling the data in your DataLoader, since the losses look quite deterministic.
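For reference, shuffling is enabled directly on the loader; here is a minimal sketch (train_dataset stands in for your actual Dataset):

```python
from torch.utils.data import DataLoader

# shuffle=True reshuffles the samples at the start of every epoch
loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
```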
Are you accidentally detaching the computation graph in your model's forward method?
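An accidental .detach() (or a round trip through NumPy) silently cuts the graph, so every layer before it stops receiving gradients. A minimal sketch of that failure mode (the model here is made up, not your code):

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(10, 5)
        self.outp = nn.Linear(5, 1)

    def forward(self, x):
        x = torch.relu(self.lin1(x))
        x = x.detach()  # BUG: cuts the graph; lin1 never gets gradients
        return self.outp(x)
```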
Also, are you using a custom loss function or a PyTorch one?
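If it's custom, that would be the first thing I'd double-check. For binary classification, the built-in nn.BCEWithLogitsLoss applied to raw logits is a safe baseline (a quick sketch with dummy tensors):

```python
import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()             # applies sigmoid internally
logits = torch.randn(8, 1)                     # dummy model outputs
targets = torch.randint(0, 2, (8, 1)).float()  # dummy binary labels
loss = criterion(logits, targets)
```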
Skimming through the code I cannot see anything obviously wrong.
Could you check some gradients after calling loss.backward()?
You could print them using:

```python
print(model.outp.weight.grad)
print(model.embs[0].weight.grad)
```
If you see some None values, something is broken. Otherwise we would have to dig a bit deeper.
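More generally, you could sweep over all parameters right after the backward pass (assuming model and loss from your training loop):

```python
loss.backward()
for name, param in model.named_parameters():
    grad = param.grad
    # prints None for parameters that are disconnected from the loss
    print(name, None if grad is None else grad.abs().mean().item())
```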
PS: You can post code snippets by wrapping them in three backticks ```. This makes it easier to copy your code and debug it.
These are the values I see when I print the gradients of the output and embedding layers:
```
tensor([[-3.1623e-02, -2.6350e-02, -1.7309e-02, -6.1525e-03, -2.7841e-02,
-2.6789e-02, -4.0986e-02, 1.1163e-02, 4.6091e-02, 2.4598e-02,
-2.6607e-02, -2.6708e-03, 2.0545e-03, 1.3116e-02, -1.9225e-02,
1.8009e-02, -7.0806e-02, -1.9697e-02, -1.6139e-02, 5.6194e-02,
-8.7657e-02, -1.0188e-02, -1.5533e-02, -1.6141e-02, 2.0929e-02,
-6.0323e-02, 5.4425e-02, -1.4522e-02, -1.2154e-02, 7.9580e-04,
-9.7717e-03, -2.8365e-02, 1.0887e-02, -3.1096e-02, -7.4927e-03,
-8.5017e-03, -3.7971e-03, -1.1083e-03, -2.4572e-02, 4.2590e-03,
-4.3712e-02, -4.1331e-02, -2.7154e-02, -5.2349e-02, -5.9448e-02,
3.4661e-02, -4.4277e-02, -1.3610e-02, 2.9865e-05, -1.3078e-02,
-1.5869e-02, -5.1470e-03, 5.3503e-03, 5.3278e-02, -5.2516e-02,
-1.4310e-02, 6.2403e-03, 6.2356e-03, -2.6851e-02, 5.1220e-03,
-5.9708e-02, -6.0736e-02, 1.2995e-02, -2.0492e-02, -2.8461e-02,
-3.2872e-02, -3.5934e-02, 2.7493e-02, -2.5135e-02, -3.2079e-02,
-2.5752e-03, 4.8966e-03, -3.5285e-02, -3.0157e-02, -3.8634e-02,
2.8084e-02, -4.8500e-02, 3.0333e-03, -5.3200e-02, -1.5738e-02,
-2.8506e-02, -3.4487e-03, 4.0806e-02, 2.0268e-02, -1.5578e-02,
-1.2922e-02, 3.8184e-03, 2.1328e-02, 2.1187e-02, 4.2885e-02,
-7.1759e-02, -1.4990e-02, 7.2830e-02, -2.4386e-02, -4.7653e-03,
-2.1303e-02, -6.4600e-02, 4.3422e-02, 1.4864e-02, -4.8731e-02,
3.9190e-02, 4.2304e-03, 1.5949e-02, -1.5027e-02, -2.4129e-02,
-1.9980e-02, -3.4105e-02, 1.3001e-02, -1.7983e-02, 3.3862e-02,
-4.7518e-02, -9.4020e-03, -2.7877e-02, -4.5408e-02, 1.3022e-02,
1.0548e-02, -1.8343e-02, -1.6114e-02, -1.3189e-02, 1.1662e-03,
1.2975e-02, 5.7757e-02, 1.3618e-02, -2.8639e-02, -3.9885e-02,
-5.5618e-02, -2.3285e-02, -6.9030e-05, -5.9235e-02, -3.2400e-02,
-2.9226e-02, -6.5258e-02, -3.5231e-02, -3.4279e-02, 2.3250e-02,
1.7370e-03, -9.1888e-03, -1.1221e-02, 1.5109e-02, -5.0923e-02,
4.2830e-02, -3.7013e-03, -5.5487e-02, -1.8025e-02, 3.4591e-02,
7.0318e-02, 3.1747e-02, 3.8931e-02, -1.8351e-02, -9.3918e-03,
-3.4855e-02, -2.8251e-02, 3.6465e-03, 3.7266e-02, -9.6779e-02,
5.0921e-03, -2.8991e-02, 6.8520e-03, 9.7878e-03, -2.8903e-02,
-1.6398e-02, -1.5751e-03, -1.8113e-02, -4.0756e-03, -1.7850e-02,
-1.0662e-02, -7.5500e-02, -8.3943e-02, -1.1067e-02, -2.7160e-02,
-2.2912e-02, -2.0004e-03, 2.8608e-02, -4.0171e-02, 1.7596e-02,
-2.1385e-02, 2.4988e-02, -1.4092e-02, -1.5201e-02, 4.9917e-02,
1.0833e-02, 8.0092e-03, 4.3952e-02, 3.6990e-02, 4.9020e-03,
-2.5192e-02, -2.1895e-02, 2.8979e-02, 9.2480e-04, -3.5609e-02,
-1.1323e-02, 2.3988e-02, -6.2246e-03, 1.1821e-02, -2.8954e-02,
8.4986e-04, 9.8768e-03, 1.9290e-02, -3.8933e-02, 3.4948e-03,
3.5462e-02, -3.0178e-02, -5.5254e-03, 1.4226e-03, -5.1838e-03,
-2.4739e-02, -1.9127e-02, -4.4689e-03, -6.8476e-02, -2.7002e-02,
-6.2134e-02, -2.6698e-02, -2.2844e-02, -2.1762e-02, -1.5398e-02,
-3.7795e-02, -5.8228e-03, 1.8611e-02, -4.9444e-03, -2.4235e-02,
3.1925e-04, 4.2813e-02, -6.5582e-03, -4.9425e-02, 2.1836e-02,
-2.9703e-02, -9.8694e-02, -4.1492e-02, -2.1444e-02, 5.7452e-03,
-3.1103e-02, 7.3129e-03, 1.7087e-03, 3.5748e-02, 1.8767e-02,
1.7446e-02, -1.8929e-02, -7.9848e-02, 7.3374e-04, 2.8640e-02,
2.9640e-03, -8.9208e-03, 7.3254e-03, -1.5448e-02, -1.9306e-03,
-6.7582e-02, 2.0013e-02, -2.0523e-02, -2.6685e-03, -5.7062e-02,
4.8704e-03, 5.6583e-02, 8.2884e-02, 2.9788e-02, -8.3444e-04,
-2.8515e-02, -7.6275e-03, 5.2626e-03, 5.4316e-03, -3.3708e-02,
-1.2808e-02, 1.4598e-02, -4.6423e-02, 3.5671e-02, 2.8519e-02,
3.9008e-02, -1.2556e-02, -1.3970e-02, -1.8020e-02, -1.3051e-02,
2.3781e-03, -1.7507e-02, -5.3407e-02, 6.4795e-03, -6.4268e-03,
-2.6923e-02, 3.6155e-02, 9.6456e-03, -2.1730e-02, -1.3292e-02,
-8.0084e-03, -5.0160e-03, -1.4873e-02, -6.1040e-02, -2.1741e-02,
3.1036e-02, -2.3044e-03, 9.3973e-04, 1.0799e-02, 1.3042e-02,
-1.1986e-02, 1.9876e-02, 1.9210e-02, -5.8544e-02, -5.8250e-02,
2.0442e-02, 1.5879e-02, -1.5362e-02, -2.3519e-02, -6.4135e-02,
-1.0467e-02, 3.9432e-03, 2.2478e-02, -1.4194e-02, -2.5212e-02,
-3.7196e-02, -1.1732e-02, 2.6087e-02, -6.3475e-03, 1.1559e-02,
4.3142e-03, -9.9321e-03, 2.0739e-02, -4.9002e-02, 1.8722e-03,
-5.0559e-02, -1.5068e-02, -4.1370e-02, 3.7180e-02, 6.4168e-02,
1.0479e-02, -2.0402e-03, 1.2502e-02, -3.5288e-02, -5.9934e-04,
1.5053e-02, -6.1099e-03, -2.6714e-02, 9.9228e-03, 2.1643e-02,
4.9815e-03, -7.8759e-04, -2.0461e-02, -3.7384e-02, -1.5176e-02,
-3.6167e-02, 6.8868e-02, -4.4166e-03, -3.4864e-02, -3.3638e-02,
1.9990e-02, 6.2415e-03, -2.7875e-02, -2.7577e-02, -4.2171e-02,
3.8956e-02, 6.3638e-02, 2.7637e-03, 1.9848e-02, 2.3079e-02,
3.6044e-02, -1.7952e-02, -1.7918e-02, 6.9048e-02, 3.5189e-02,
1.9060e-02, 4.7532e-02, -6.4268e-02, -3.3150e-02, -2.4915e-02,
-1.1871e-02, 7.7971e-03, 1.6802e-02, -6.7766e-03, -6.0042e-02,
6.7299e-02, 1.8261e-02, -3.1467e-02, 1.8212e-03, 3.1181e-03,
-5.2383e-02, -2.1561e-02, -3.2257e-02, 5.0996e-02, -1.5949e-02,
-3.0735e-02, 6.4452e-02, -9.1218e-03, -1.4836e-02, 4.4239e-02,
-3.9941e-02, -4.5105e-03, -1.2341e-02, -3.3473e-02, -8.3403e-03,
-3.4777e-02, -1.1807e-02, -4.1881e-02, -2.6477e-02, 8.2249e-03,
-1.3171e-02, 3.9644e-02, -2.6130e-02, -1.0104e-02, -1.8376e-02,
-2.3920e-02, -4.3445e-02, 2.3253e-02, -2.4548e-02, -5.0419e-02,
2.6859e-02, -3.2329e-03, -2.5221e-02, -1.6766e-03, -3.2432e-02,
-3.7509e-02, -5.2284e-02, -9.5738e-02, -4.9757e-02, -3.4016e-02,
6.3242e-02, -4.0427e-02, 1.0408e-02, -1.0845e-02, 5.4411e-02,
6.9872e-02, 7.1449e-03, 1.8818e-02, 1.1320e-02, -2.1045e-02,
2.0577e-02, 3.1659e-02, -1.6194e-03, -2.5389e-02, -4.5406e-02,
-4.8093e-03, 1.1047e-03, -1.5277e-02, -2.2355e-02, 1.1083e-02,
-5.2939e-02, -5.4026e-02, 7.7325e-03, -3.7525e-02, -4.8978e-03,
2.6263e-03, -4.6975e-02, 1.0279e-02, 3.5905e-02, -1.6772e-02,
4.8599e-03, 3.5422e-02, -6.1347e-02, 1.3385e-02, 1.8079e-02,
-1.1148e-02, 3.1458e-03, -5.7805e-02, -4.9902e-02, -4.0210e-02,
-1.7585e-02, 1.9289e-02, -6.7882e-03, 3.2867e-02, -2.5522e-02,
5.1807e-03, -1.2839e-03, -4.5885e-03, 4.4152e-02, -5.5047e-02,
-4.6421e-02, -3.1257e-02, 5.3186e-02, 4.6530e-02, 3.6265e-02,
-5.1960e-02, -1.4269e-02, -8.0705e-03, -2.9653e-02, -5.6376e-02,
-3.9184e-02, 6.1064e-02, 7.2560e-04, -4.7297e-03, -4.9365e-02,
8.8795e-03, 3.3561e-02, 3.2388e-02, -2.4099e-02, 9.2907e-04,
-1.6416e-02, 2.3246e-02, 8.7383e-02, -2.3976e-03, -1.9337e-02,
-3.6302e-02, -4.7792e-03, -3.4619e-02, -3.7661e-02, 7.1462e-04,
-8.2786e-02, -2.0263e-02, 1.2210e-02, 1.0964e-02, -1.2815e-02,
4.0427e-02, -4.1053e-02, -7.4282e-02, -1.5011e-02, -2.0446e-02]],
dtype=torch.float64)
tensor([[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0081, 0.0107, 0.0280, -0.0123, 0.0071],
        [-0.0076, -0.0142, -0.0112,  0.0065, -0.0092]], dtype=torch.float64)
```
The gradients look valid and your loss does move. The all-zero rows in the embedding gradient are expected: only the embedding rows that were actually indexed in the batch receive a gradient.
I would suggest playing around with the training hyperparameters, e.g. lowering the learning rate, and observing whether the training loss decreases a bit more steadily.
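For example (a sketch: Adam and lr=1e-4 are just illustrative choices, and model(inputs) stands in for however you pass your continuous and categorical features):

```python
import torch.optim as optim

# a smaller learning rate than before; 1e-4 is only a starting point
optimizer = optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(10):
    for inputs, targets in loader:
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
    print(f"epoch {epoch}: last batch loss {loss.item():.4f}")
```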