I am using a custom dataset with model as:
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(35, 512)
self.fc2 = nn.Linear(512, 512)
self.fc3 = nn.Linear(512, 6)
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.relu(self.fc3(x))
return x
With the training function as:
self.criterion = nn.CrossEntropyLoss()
self.model = torch.nn.DataParallel(self.model, device_ids=range(torch.cuda.device_count()))
def train(self):
valid_running_loss = 0.0
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
torch.backends.cudnn.benchmark = True
self.model.to(device)
optimizer = torch.optim.SGD(self.model.parameters(), lr=0.000_000_000_001)
criterion = nn.CrossEntropyLoss()
running_loss = 0
loss_values = []
for epoch in range(self.epochs):
self.model.train()
for batch_idx, (target, dat) in enumerate(self.train_loader):
target, data = Variable(
target.cuda()), Variable(dat.cuda())
optimizer.zero_grad()
output = self.model(dat)
loss = criterion(output, target.flatten().to(device).long())
loss.backward()
optimizer.step()
loss_values.append(running_loss/20)
running_loss += loss.item()
if batch_idx % 20 == 19:
print('Training [%d, %5d] loss: %.3f' %
(epoch + 1, batch_idx + 1, running_loss / 20))
running_loss = 0.0
torch.save(self.model.state_dict(), 'model.pt')
plt.plot(loss_values)
plt.xlabel("Batches")
plt.ylabel("Loss")
plt.show()
Output loss graph is as follows:
And output from each layer mostly follows the same pattern:
Layer #1
tensor([[34.0111, 12.7092, 16.9817, 14.5254, 0.0000, 2.9979, 34.6398, 28.0957,
24.6492, 0.0000, 0.0000, 20.6735, 55.1613, 0.0000, 43.1130, 38.0405,
8.2961, 0.0000, 0.0000, 0.0000, 24.0659, 13.8045, 4.9426, 0.1992,
34.9194, 0.0000, 0.0000, 19.0288, 0.0000, 22.8025, 0.0000, 0.0000,
0.7127, 0.0000, 48.1853, 33.9753, 0.0000, 0.0000, 13.2362, 0.0000,
44.4921, 9.9233, 38.6005, 35.1910, 19.6216, 37.5149, 1.4502, 0.0000,
12.4940, 72.6231, 0.0000, 0.0000, 27.0568, 0.0000, 25.1401, 0.0000,
11.9196, 19.1825, 0.0000, 21.7884, 0.0000, 27.3251, 12.9470, 16.1076,
20.6509, 0.0000, 19.3813, 13.8918, 14.7036, 0.0000, 43.9978, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
35.3095, 0.0000, 0.0000, 12.3495, 0.0000, 17.0042, 0.0000, 50.3415,
0.7187, 0.0000, 0.0000, 0.0000, 19.5300, 0.0000, 12.6171, 0.0000,
0.0000, 40.4328, 63.1309, 25.9717, 9.9090, 0.0000, 0.0000, 7.4069,
0.0000, 0.0000, 0.0000, 48.7994, 22.2934, 8.2384, 0.0000, 0.0000,
16.6730, 0.0000, 21.5223, 36.6605, 37.9476, 0.0000, 0.0000, 30.9383,
0.0000, 5.8143, 30.3140, 31.1911, 45.5149, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 57.5292, 21.0106, 0.0000, 0.0000, 11.2866,
0.0000, 8.8533, 0.0000, 0.0000, 16.1030, 0.0000, 31.2619, 5.7975,
12.3731, 14.3904, 0.0000, 0.0000, 0.0000, 0.0000, 1.4914, 0.0000,
0.0000, 16.2442, 0.0000, 39.2010, 43.2472, 0.0000, 0.0000, 0.0000,
0.0000, 8.3958, 0.0000, 13.6056, 0.0000, 0.0000, 86.4618, 31.2490,
0.0000, 0.0000, 2.6972, 0.0000, 0.0000, 26.5139, 0.0000, 23.3579,
0.0000, 0.0000, 0.0000, 10.0080, 0.0000, 0.0000, 0.0000, 15.1532,
3.9325, 35.7198, 0.0000, 0.0000, 0.0000, 21.8514, 0.8783, 0.0000,
0.0000, 11.6154, 0.0000, 32.9982, 4.7520, 28.7346, 0.0000, 0.0000,
31.4094, 0.0000, 3.6026, 32.6338, 0.3227, 0.0000, 0.3136, 0.0000,
0.0000, 9.4382, 0.0000, 0.0000, 17.7246, 0.0000, 23.2691, 0.0000,
27.7171, 14.8556, 58.0410, 0.0000, 7.1684, 4.9152, 0.0000, 35.3398,
26.2738, 0.0000, 25.8247, 0.0000, 0.0000, 0.0000, 0.0000, 25.2728,
35.7325, 7.9791, 42.1267, 38.2015, 13.0649, 7.1808, 16.6197, 0.0000,
25.6002, 0.2276, 0.0000, 28.3883, 11.9394, 41.1464, 11.5944, 0.0000,
0.0000, 0.0000, 23.2719, 16.6102, 38.2222, 32.7788, 15.7401, 58.2293,
1.6106, 0.0000, 0.0000, 35.0814, 0.0000, 0.0000, 3.4356, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 45.6004, 14.8991, 6.0531, 51.6026,
16.8593, 0.0000, 0.0000, 0.0000, 55.9559, 0.0000, 0.0000, 36.4741,
21.1376, 6.3189, 19.6905, 0.0000, 4.5537, 18.8644, 37.8007, 2.9587,
0.0000, 0.0000, 66.9925, 2.1472, 0.0000, 49.2316, 0.0000, 41.7871,
44.0987, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 36.4081, 0.0000,
0.0000, 40.1042, 0.0000, 0.0000, 0.0000, 6.7286, 18.3722, 27.6182,
0.0000, 67.7467, 9.7763, 18.0995, 8.0758, 7.9573, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 7.1732, 0.0000, 10.2361, 6.5490, 27.3207,
95.2971, 17.5390, 43.5235, 0.0000, 0.0000, 0.0000, 27.7248, 0.0000,
11.7532, 0.0000, 24.5198, 62.1982, 4.9184, 0.0000, 0.0000, 12.8589,
0.0000, 0.0000, 0.0000, 3.4953, 50.0316, 22.7615, 0.0000, 0.7946,
5.9959, 27.8512, 0.0000, 17.3078, 11.5306, 0.0000, 10.6378, 10.7233,
0.0000, 4.2215, 20.2768, 0.0000, 0.0000, 0.0000, 0.0000, 30.5890,
22.3280, 0.0000, 41.7865, 9.4994, 0.0000, 0.0000, 0.0000, 21.8536,
0.0000, 12.1628, 26.2739, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 15.3123, 32.7893, 0.0000, 0.0000, 13.0238, 0.0000,
0.0000, 73.2508, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
17.9185, 0.0000, 0.0000, 0.0000, 0.0000, 21.6723, 0.0000, 56.8059,
21.3461, 0.0000, 3.5318, 42.6378, 0.0000, 0.0000, 8.7609, 9.1071,
19.7198, 29.4656, 12.3245, 0.0000, 0.0000, 6.2201, 0.0000, 0.0000,
0.0000, 17.2576, 0.0000, 5.4993, 0.0000, 22.0809, 42.4508, 7.6554,
0.0000, 14.8032, 8.5307, 17.6682, 0.0000, 4.4538, 3.2548, 31.2332,
0.0000, 0.0000, 0.0000, 53.7162, 17.6550, 13.2346, 0.0000, 10.2985,
11.6230, 0.0000, 45.5657, 7.3497, 0.0000, 1.8219, 0.0000, 70.1144,
0.0000, 0.0000, 35.7366, 30.5250, 0.0000, 51.1208, 26.5028, 16.8218,
28.1218, 0.0000, 36.6832, 12.0090, 0.7716, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 8.5589, 0.0000, 33.9515, 27.2925, 0.7363, 64.0930,
50.3558, 54.7302, 62.5942, 0.0000, 0.0000, 1.9293, 0.0000, 25.3646,
4.8337, 0.0000, 20.5091, 0.0000, 56.1850, 0.9221, 19.4342, 0.0000,
66.2355, 0.0000, 0.0000, 0.0000, 0.0000, 7.8311, 0.0000, 0.0000,
70.5462, 0.0000, 0.0000, 0.0000, 15.6784, 31.4369, 0.0000, 14.5324]],
device='cuda:0', grad_fn=<ReluBackward0>)
Layer #2
tensor([[17.1249, 0.0000, 0.0000, 9.0818, 6.6402, 0.0000, 10.8230, 0.0000,
0.0000, 0.0000, 19.1912, 0.0000, 5.1139, 22.4080, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 11.7231, 8.7982, 13.2230, 6.9078, 0.0000,
0.0000, 0.0000, 3.2706, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 15.7249, 7.4991, 0.0000, 0.0000, 9.1082, 7.6445, 9.3315,
1.8748, 0.0000, 7.9360, 0.0000, 0.0000, 0.0000, 0.0000, 17.4604,
11.8764, 6.1336, 0.0000, 9.5193, 0.0000, 0.0000, 0.0000, 0.0000,
22.4725, 0.0000, 0.0000, 0.0000, 17.5471, 0.0000, 8.5708, 11.9288,
7.7524, 0.0000, 0.0000, 4.8687, 6.4660, 13.4811, 6.5080, 5.4127,
0.0000, 0.0000, 0.0000, 0.0000, 19.6467, 0.0000, 0.0000, 26.5991,
0.0000, 0.0000, 0.0000, 1.4286, 22.5212, 0.0000, 2.9779, 12.6172,
17.1694, 0.0000, 0.0000, 0.0000, 0.0000, 3.8671, 8.2908, 0.0000,
0.0000, 5.3878, 0.0000, 0.0000, 2.6435, 0.0000, 0.0000, 0.0000,
31.1878, 16.2891, 3.2600, 0.6124, 0.0000, 15.0056, 0.0000, 0.0000,
0.0000, 19.2099, 4.1495, 4.1315, 0.0514, 13.5338, 1.0504, 0.0000,
0.0000, 0.0000, 0.0000, 11.8101, 0.0000, 0.0000, 0.0000, 13.4565,
17.4728, 0.0000, 10.1245, 11.5265, 18.2403, 2.5224, 25.3509, 0.0000,
11.8381, 0.0000, 0.0000, 2.9400, 0.0000, 23.7288, 5.2541, 0.0000,
0.0000, 12.0983, 12.3099, 15.6219, 0.0000, 4.6333, 2.1624, 2.1363,
2.8176, 3.7855, 0.0000, 10.6023, 28.2926, 7.8620, 0.0000, 0.0000,
0.0000, 15.7391, 0.0000, 10.9450, 0.0000, 11.1348, 16.8085, 20.6935,
0.0000, 0.0000, 0.0000, 11.9673, 0.0000, 10.4149, 0.0000, 0.0000,
0.0000, 0.0000, 0.5347, 0.0000, 13.8853, 0.0000, 0.0000, 23.8881,
5.6834, 0.0000, 14.2632, 15.1108, 0.0000, 0.0000, 0.0000, 12.5634,
12.2963, 0.9804, 0.0000, 18.2825, 10.9668, 8.7040, 0.1205, 8.8042,
11.9092, 5.7311, 12.1467, 0.0000, 0.0000, 9.8295, 5.2199, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 4.1317, 6.3019, 27.4451, 0.0000,
19.1557, 0.0000, 11.3723, 13.3361, 6.6892, 0.0000, 0.0000, 1.9528,
0.0000, 4.0795, 0.0000, 0.0000, 0.0000, 3.5099, 0.0000, 0.0000,
0.0000, 16.9386, 0.0000, 0.0000, 7.6447, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 4.5518, 0.0000, 1.5412,
0.0000, 0.0000, 0.0000, 1.8520, 0.0000, 0.0000, 0.0000, 1.8567,
0.0000, 12.5432, 17.2627, 12.1782, 0.0000, 0.0000, 6.6105, 2.9548,
5.1203, 0.0000, 0.6381, 8.4258, 0.0000, 0.0000, 16.4389, 19.8055,
3.1554, 0.0000, 0.0000, 0.0000, 13.3941, 0.0000, 4.2483, 1.6484,
0.0000, 0.0000, 0.0000, 1.0757, 0.0000, 12.4581, 16.7086, 14.6670,
11.1585, 6.4158, 0.0000, 16.0432, 1.8949, 12.8711, 0.0000, 0.0000,
0.0000, 1.9205, 0.0000, 16.2584, 8.1967, 4.1390, 5.6682, 0.0000,
1.7621, 0.0000, 0.0000, 8.2641, 0.0000, 0.0000, 0.0000, 5.4671,
9.1355, 0.0000, 23.0471, 0.0000, 14.0210, 0.0000, 3.1476, 0.0000,
0.0000, 0.0000, 0.0000, 7.4607, 23.0116, 15.9541, 0.0000, 0.0000,
2.8278, 8.5759, 10.3721, 0.0000, 17.0605, 31.8535, 6.9635, 0.0000,
8.3336, 10.9779, 2.8000, 0.0000, 0.0000, 0.0000, 0.0000, 26.5629,
0.0000, 2.4045, 0.0000, 0.0000, 0.5652, 0.0000, 13.7610, 9.7107,
5.7010, 0.0000, 0.0000, 11.1692, 0.0000, 0.0000, 4.8460, 13.3004,
0.0542, 4.6617, 0.4143, 0.0000, 0.0000, 15.8988, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 14.1985, 0.0000, 0.0000, 12.7816, 21.5568,
4.8282, 3.6445, 6.9795, 7.4458, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 12.9888, 16.1650, 10.7322, 12.3018, 0.0000, 0.0000, 7.6508,
0.0000, 11.0397, 0.0000, 0.0000, 9.9304, 0.0000, 14.4211, 13.6603,
7.4515, 0.0000, 0.0000, 8.8561, 0.0000, 10.5816, 0.0000, 13.6793,
8.5912, 19.0544, 29.1780, 0.0000, 0.0000, 10.9369, 0.0000, 0.0000,
6.0852, 0.3852, 1.5841, 0.0000, 0.0000, 12.2816, 13.9922, 9.5033,
0.0000, 0.0000, 1.7762, 21.8631, 13.5751, 3.1538, 5.9521, 9.8035,
2.2409, 0.0000, 7.9797, 0.0000, 11.7742, 0.0000, 0.0000, 0.0000,
14.5848, 17.5528, 8.7062, 0.0000, 0.0000, 0.0000, 0.0000, 2.5330,
0.0000, 3.4727, 1.3243, 1.6350, 0.0000, 2.3310, 16.4998, 9.0668,
0.0000, 2.4043, 0.0000, 13.6607, 16.5940, 0.0000, 7.8046, 0.0000,
20.6234, 10.3480, 0.0000, 16.7120, 0.0000, 11.0978, 4.4286, 0.0000,
4.2564, 5.7708, 0.0000, 0.0000, 0.0000, 12.9946, 24.6138, 5.6800,
34.5168, 1.7719, 17.0013, 13.6552, 0.0000, 0.0000, 5.6398, 10.8171,
8.4392, 0.0000, 7.2894, 3.4434, 2.5727, 16.9701, 16.1560, 0.0000,
0.0000, 13.8336, 0.0000, 0.0000, 0.0000, 18.6943, 0.1918, 1.5877,
0.0000, 6.1855, 5.8500, 0.8094, 10.7234, 0.0000, 4.7110, 4.7390]],
device='cuda:0', grad_fn=<ReluBackward0>)
Layer #3
tensor([[1.6433, 1.1880, 1.4009, 0.0000, 4.4047, 4.4773]], device='cuda:0',
grad_fn=<ReluBackward0>)
There is very little difference between the outputs of Layer #3 from successive batches.
I have also played around with LR and number of weights in the FC layers. The loss magnitude usually increases, with the shape mostly being the same (but never decreasing :().
Let me know if any other information is needed…Any help is appreciated:)!