Hi, I got the following error from nn.CrossEntropyLoss():
Traceback (most recent call last):
  File "./food101/food101_nas.py", line 467, in <module>
    loss1 = criterion(output1, labels)
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/loss.py", line 862, in forward
    ignore_index=self.ignore_index, reduction=self.reduction)
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/functional.py", line 1550, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/functional.py", line 1407, in nll_loss
    return torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: Expected object of type torch.cuda.LongTensor but found type torch.cuda.FloatTensor for argument #2 'target'
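For context, the last frames of the traceback show that `cross_entropy` is computed as `nll_loss(log_softmax(input, 1), target, ...)`. With the default `reduction='mean'`, and ignoring class weights and `ignore_index`, the loss over a batch of N rows of logits x_n with C classes is roughly the expression below. Each target y_n is used as an index into its row, which is why the target is expected to hold integer class indices (a LongTensor) rather than floats. With the shapes printed further down, N = 4 and C = 101.

$$
\ell(x, y) = \frac{1}{N} \sum_{n=1}^{N} -\log \frac{\exp\left(x_{n, y_n}\right)}{\sum_{c=1}^{C} \exp\left(x_{n, c}\right)}, \qquad y_n \in \{0, 1, \dots, C-1\}
$$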
Here is the relevant part of the code:
criterion = nn.CrossEntropyLoss().cuda()
for epoch in range(50):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(data_loader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        labels = labels.type(torch.cuda.FloatTensor)
        # zero the parameter gradients
        optimizer2.zero_grad()
        # forward + backward + optimize
        output1, output2 = net2(inputs)
        loss1 = criterion(output1, labels)
The program stops at loss1 = criterion(output1, labels).
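A minimal sketch of the same loop with the labels cast to `torch.long` instead of `torch.cuda.FloatTensor`. `TwoHeadNet`, the random `TensorDataset`, the SGD optimizer and its learning rate are hypothetical stand-ins for `net2`, `data_loader` and `optimizer2`, which aren't shown in the post. The stand-in labels are deliberately stored as floats to mimic the situation here. How `output2` is used isn't shown either, so only `loss1` is backpropagated. It falls back to the CPU if no GPU is available.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in for net2: a model that returns two outputs.
class TwoHeadNet(nn.Module):
    def __init__(self, in_features=32, num_classes=101):
        super().__init__()
        self.body = nn.Linear(in_features, 64)
        self.head1 = nn.Linear(64, num_classes)
        self.head2 = nn.Linear(64, num_classes)

    def forward(self, x):
        h = torch.relu(self.body(x))
        return self.head1(h), self.head2(h)

# Hypothetical stand-in for data_loader: random features and class
# indices stored as floats, mimicking the labels in the post.
features = torch.randn(16, 32)
targets = torch.randint(0, 101, (16,)).float()
data_loader = DataLoader(TensorDataset(features, targets), batch_size=4)

net2 = TwoHeadNet().to(device)
optimizer2 = torch.optim.SGD(net2.parameters(), lr=0.01)  # assumed optimizer
criterion = nn.CrossEntropyLoss()

for epoch in range(2):  # the original loops for 50 epochs
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(data_loader, 0):
        inputs = inputs.to(device)
        # Class-index targets must be int64 (torch.long), not float.
        labels = labels.to(device, dtype=torch.long)
        optimizer2.zero_grad()
        output1, output2 = net2(inputs)
        loss1 = criterion(output1, labels)  # output2 is unused in this sketch
        loss1.backward()
        optimizer2.step()
        running_loss += loss1.item()
```

Using `.to(device, dtype=torch.long)` moves and casts in one call; `labels.cuda().long()` would do the same on a GPU-only setup.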
I also printed output1 and labels; here are the results:
tensor([[ 0.3472, 0.2209, -0.5995, 0.5815, -0.5395, -0.0155, 2.4635, 2.0652,
-0.2173, 2.0970, -0.9288, 1.7439, 2.1594, -0.3784, 0.3356, 0.4073,
0.3096, 1.0668, -0.8825, 2.8969, -1.3266, -1.5186, 1.8897, -1.0376,
0.4083, -3.1226, 1.1905, 3.9962, 2.8334, 2.7869, -0.5127, -4.0223,
-1.6923, 2.1538, -0.2082, -2.7689, -0.9950, -1.1264, 0.7648, -1.8015,
-1.1787, -0.6646, 1.1710, 2.5576, 1.2553, -1.6907, -1.9629, -2.4180,
-1.6066, -0.0622, -3.0201, 1.1205, -0.0081, 1.2506, 0.7384, -0.4972,
-0.2058, 0.5839, 0.6814, 3.6870, -1.0864, -1.5504, 0.2249, 0.2436,
0.4892, -1.1034, 4.5987, -1.6100, 1.3722, -0.0430, 0.4520, 0.5140,
1.2593, -1.5004, -1.1094, -0.2125, 2.0717, -2.4672, 1.7870, -0.2758,
-0.5859, -0.5422, -3.7564, -2.0177, 1.5797, -3.3137, 0.0896, -1.4499,
-2.0254, 2.7423, -4.7888, 0.0125, -1.6988, 1.9800, -0.4304, 0.0367,
1.4657, -0.0832, -1.1632, 0.6637, 0.2876],
[-1.0760, -2.3349, -0.2378, -0.0260, -1.9087, -0.5062, 2.7556, 2.8177,
0.6555, 3.7096, 1.3478, 0.7260, -0.1354, 0.8245, -0.0649, -1.3409,
-0.6787, -0.8940, -1.0808, 2.2057, -0.9994, -2.2736, 1.9989, 0.8456,
2.5386, -1.7708, -1.1134, 2.9402, 2.4160, 1.9480, -2.1800, -1.9614,
-1.7699, -2.3918, 1.1136, -0.7473, 2.8408, 0.5136, 0.1066, -0.1198,
-2.7388, 1.0143, -0.2919, -1.4613, -1.4885, -0.7984, 2.3917, -0.3652,
-0.9415, 0.4829, 3.9927, 0.8021, 2.0672, 2.1484, -1.5458, 2.6886,
1.7743, -1.4343, -1.9249, 1.4623, 1.9435, -0.5579, -0.5783, -1.1952,
0.1120, -1.6303, 4.4920, -1.4937, 1.7977, 1.6284, -3.0435, 0.0606,
1.1746, -1.1102, -2.1779, 1.1458, 1.6065, -1.6678, 2.6499, -2.2413,
1.3299, -1.9646, -3.0804, 1.4887, 2.5108, -1.3850, 0.2503, 0.6424,
0.0941, 4.5158, -0.7314, -0.6402, -0.4716, 2.0102, -1.2146, 0.9128,
-1.8152, -0.6944, -1.7823, -1.8926, 2.2456],
[ 0.1301, -3.9154, 0.0021, -3.5201, -1.2929, 0.5564, 1.8020, 2.0753,
1.8213, -0.0316, 0.6688, -0.8543, -0.5184, -1.0957, -1.2135, -0.1007,
0.3097, -1.0970, -1.4665, 3.7728, -0.8016, -0.8768, -0.3117, 0.5693,
2.2932, -1.0881, -0.9895, 2.3643, 4.3349, 3.4266, -0.0193, 0.8639,
-1.9764, -0.2480, 0.7688, 0.1922, 0.1700, 2.3806, 1.2501, -2.3265,
-0.7687, 0.3803, -0.1902, -1.1115, -0.3638, -0.7665, -3.4929, -0.9755,
0.0180, 1.2906, 0.8780, 0.8353, 4.5028, 1.3965, 0.2595, 1.9385,
-0.6293, -1.8133, -0.1595, 1.8847, 2.1214, 1.1154, 0.1183, -0.9148,
0.1752, 0.5887, 1.3333, 0.5452, -0.5026, 0.2785, -0.4085, 2.6408,
-1.1813, 0.3552, -4.1691, -1.4687, 2.3110, -1.3233, 2.4228, -0.4022,
1.1015, 1.4589, -1.6527, 0.6953, 0.5717, -2.6545, 0.2260, 4.5496,
-2.9178, 2.8515, -2.8933, -1.0091, -1.3126, 1.0695, -0.9941, 1.5877,
0.3424, 0.4332, -0.4486, -3.1223, 2.7039],
[-1.3958, -1.6683, -0.3244, -1.8912, 1.0913, -2.5487, 3.6205, 1.9874,
-2.7681, 3.0277, 1.5184, 2.5858, -1.9175, -0.3118, -0.1551, -3.0004,
1.2457, 2.3147, 0.4921, -0.6332, -2.1855, 2.8018, 7.3446, -2.2081,
3.4884, 1.3904, 4.0372, 6.9340, 9.7749, 7.5051, -9.1866, -0.4472,
-2.2257, 3.7294, 5.9419, -4.6607, 3.6848, -2.0865, -5.7268, 0.2872,
-5.6435, 3.3136, 1.3436, 4.2567, -2.8445, -1.1460, 1.8317, -6.8653,
1.1093, -1.0434, -4.2880, -2.9261, 4.1342, 4.2560, -1.7228, 8.0642,
-3.9592, 3.0952, -1.5354, -0.6765, 0.6810, -1.4194, -1.0910, -4.5485,
3.3561, -3.4030, 3.1959, 4.4669, 3.1454, 1.4789, -0.7964, 5.3485,
-0.0700, -1.6954, -5.4768, 0.3164, 2.7968, -0.5237, 0.3240, -2.6250,
-4.5888, -0.8958, 0.5304, -1.8871, -3.0368, -4.3204, 8.9819, 3.9074,
-5.0982, 1.2394, -2.4371, 0.1130, -5.1622, 3.8740, -2.7696, 0.7971,
3.4686, -0.9827, -1.6262, -2.0382, 0.9371]],
device='cuda:0', grad_fn=<ThAddmmBackward>)
tensor([21., 45., 43., 87.], device='cuda:0')
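The printed values fit the error message: output1 has shape (4, 101), i.e. 4 samples and 101 classes, while labels holds the class indices as floats. The sketch below reproduces this on the CPU using the printed labels and random logits of the same shape (the random logits are an assumption; the actual output1 values don't matter for the dtype check). It also checks that the float labels are whole numbers in [0, 101), so casting them with `.long()` loses nothing. The exact wording of the error depends on the PyTorch version.

```python
import torch
import torch.nn as nn

# Labels copied from the printed tensor; logits are random stand-ins
# with the same shape as output1 (batch of 4, 101 classes).
labels = torch.tensor([21., 45., 43., 87.])
logits = torch.randn(4, 101)
criterion = nn.CrossEntropyLoss()

# Float class indices are rejected because nll_loss needs an int64 target.
try:
    criterion(logits, labels)
    float_target_failed = False
except RuntimeError:
    float_target_failed = True

# Sanity checks before casting: whole numbers within [0, num_classes).
num_classes = logits.size(1)
assert torch.equal(labels, labels.round())
assert labels.min() >= 0 and labels.max() < num_classes

long_labels = labels.long()
loss = criterion(logits, long_labels)  # accepted: scalar loss
print(long_labels.dtype, loss.item())
```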
I tried a few of the solutions I found online, but none of them worked.