@ptrblck Yes, that's right: case 1 is the working case with high train and test accuracy, and cases 2 and 3 show accuracy drops.
I'll check out the code. Thanks!
Differences: the output is a dictionary containing a prediction and a reconstructed image.
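For context, here is a minimal sketch of a model with that output structure, i.e. a `forward` that returns a dict with a class prediction and a reconstructed image. The key names `Pred` and `image` follow the printed output below; the architecture itself is an assumption, not the original model:

```python
import torch
import torch.nn as nn

class ClassifyAndReconstruct(nn.Module):
    """Toy model returning a dict with a class prediction and a
    reconstructed image. Architecture details are placeholders."""
    def __init__(self, num_classes=10):
        super().__init__()
        # Shared encoder: conv features pooled to a fixed 8x8 grid
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(8), nn.Flatten(),
        )
        # Classification head producing the logits ("Pred")
        self.classifier = nn.Linear(16 * 8 * 8, num_classes)
        # Reconstruction head producing a 3x32x32 image ("image")
        self.decoder = nn.Sequential(
            nn.Linear(16 * 8 * 8, 3 * 32 * 32),
            nn.Unflatten(1, (3, 32, 32)),
        )

    def forward(self, x):
        feats = self.encoder(x)
        return {"Pred": self.classifier(feats),
                "image": self.decoder(feats)}

model = ClassifyAndReconstruct()
out = model(torch.randn(1, 3, 32, 32))
print(out["Pred"].shape)   # logits for 10 classes
print(out["image"].shape)  # reconstructed image batch
```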
Case 1:
Pred: tensor([[-0.7676, 0.0572, -0.2913, 0.2349, 0.5005, -1.1794, 0.4590, 0.2181,
-2.1164, -1.3106]], grad_fn=<AddmmBackward>)
image: tensor([[[[ 0.5258, 0.4485, 0.7366, ..., 0.5316, 0.3521, 0.2883],
[ 0.6817, 0.7242, 0.1581, ..., 0.6326, 0.4853, 0.2883],
[-0.0272, -0.1409, -0.1605, ..., 1.1503, 0.4435, 0.2883],
...,
[ 0.4976, 0.5694, 0.9019, ..., 0.6768, 0.6060, 0.2883],
[ 0.4517, 0.0776, 0.3625, ..., 0.4241, 0.5714, 0.2883],
[ 0.2883, 0.2883, 0.2883, ..., 0.2883, 0.2883, 0.2883]],
[[ 0.6070, 0.4169, 0.8111, ..., 0.7332, 0.5797, 0.2265],
[-0.0116, 0.6817, 0.0530, ..., 0.9291, 0.5413, 0.2265],
[ 1.1952, 1.0150, 0.8058, ..., 0.5320, 1.0475, 0.2265],
...,
[ 0.5384, 0.6242, 0.5577, ..., 0.6715, 0.8561, 0.2265],
[ 0.4094, 0.6437, 0.2078, ..., 0.8093, 0.5136, 0.2265],
[ 0.2265, 0.2265, 0.2265, ..., 0.2265, 0.2265, 0.2265]],
[[ 0.8264, 0.8105, 0.3731, ..., 0.6286, 0.7814, 0.2504],
[ 0.6968, 0.9409, 1.2026, ..., 0.8817, 0.0785, 0.2504],
[-0.1297, 0.3147, -0.0102, ..., 0.6539, -0.0164, 0.2504],
...,
[ 0.7424, 0.9211, 0.6916, ..., 0.6464, 0.1591, 0.2504],
[-0.0205, 0.3855, 0.0819, ..., 0.4896, 0.4849, 0.2504],
[ 0.2504, 0.2504, 0.2504, ..., 0.2504, 0.2504, 0.2504]]]])
Case 2:
Pred: tensor([[ 2.2933, -0.8836, 2.3208, -1.1150, 0.1603, -0.7463, -0.6391, -0.2642,
-3.1370, -0.2606]], grad_fn=<AddmmBackward>)
image: tensor([[[[ 0.2977, 0.4950, 0.5564, ..., 0.5029, 0.4113, 0.2883],
[ 0.3410, 0.5465, 0.5328, ..., 0.4874, 0.4242, 0.2883],
[ 0.3324, 0.5292, 0.3376, ..., 0.5825, 0.4509, 0.2883],
...,
[ 0.3655, 0.5581, 0.4323, ..., 0.5095, 0.5081, 0.2883],
[ 0.2985, 0.3943, 0.5151, ..., 0.3950, 0.4275, 0.2883],
[ 0.2883, 0.2883, 0.2883, ..., 0.2883, 0.2883, 0.2883]],
[[ 0.2727, 0.3019, 0.4901, ..., 0.3821, 0.4321, 0.2265],
[ 0.3167, 0.5233, 0.3820, ..., 0.5307, 0.4235, 0.2265],
[ 0.7326, 0.5277, 0.6648, ..., 0.6426, 0.5306, 0.2265],
...,
[ 0.3303, 0.5173, 0.4646, ..., 0.5538, 0.3874, 0.2265],
[ 0.5248, 0.4554, 0.6289, ..., 0.5178, 0.3928, 0.2265],
[ 0.2265, 0.2265, 0.2265, ..., 0.2265, 0.2265, 0.2265]],
[[ 0.5768, 0.4982, 0.4561, ..., 0.4663, 0.4978, 0.2504],
[ 0.5807, 0.6176, 0.7808, ..., 0.5769, 0.3198, 0.2504],
[ 0.2692, 0.6964, 0.5180, ..., 0.6125, 0.4428, 0.2504],
...,
[ 0.6028, 0.6727, 0.6887, ..., 0.6235, 0.3214, 0.2504],
[ 0.1296, 0.3230, -0.0013, ..., 0.3186, 0.2721, 0.2504],
[ 0.2504, 0.2504, 0.2504, ..., 0.2504, 0.2504, 0.2504]]]])