I’m having a strange problem with .argmax(). When my tensor is on the GPU, argmax() returns an index whose element is not the maximum. If I move the tensor to the CPU, I get the expected result.
While running on the GPU:
print(array)
print(array.max())
print(array.argmax())
print(array[array.argmax().item()])
gives the output:
tensor([ -11.5442, -8.5333, -18.0345, -85.9370, -40.4294, -100.8801,
-81.5367, -158.3000, -53.0409, -79.1826, -70.0517, -49.0895,
-82.3455, -146.0170, -50.0242, -8.0940, -129.8060, -89.6011,
-51.4888, -61.9345, -68.5567, -82.3089, -94.2280, -68.8929,
-58.0690, -41.6044, -73.2325, -98.5677, -71.2841, -97.9249,
-78.4366, -119.5655, -38.1226, -26.0892, -68.2670, -98.7564,
-99.1900, -43.8435, -56.8168, -104.1809, -80.6755, -104.1263,
-110.3228, -62.4219, -86.5573, -56.0344, -54.7963, -52.7687,
-68.4890, -42.4576, -85.1173, -70.2056, -25.9964, -49.6356,
-69.7168, -2.9612, -82.8546, -72.6364, -143.3515, -87.6566,
-28.4261, -56.5061, -42.5236, -31.2142, -48.2933, -88.7260,
-26.8897, -36.2937, -80.7760, -58.9123, -70.4392, -31.6963,
-23.6501, -40.7379, -43.2897, -24.0806, -129.6917, -62.9375,
-61.3152, -51.7130, 11.2619, -28.1152, -56.1424, -62.4392,
-31.8498, -50.9248, -26.5526, -70.4671, -101.5490, -53.1732,
-114.1311, -120.1272, -56.8597, -47.0254, -77.9264, -95.7280,
-52.8543, -76.9921, -44.5555, -65.8190, -80.3501, -72.5908,
-57.3906, -83.8305, -49.4801, -56.9889, -111.5182, -37.6801,
-79.2351, -23.4743, -9.0946, -41.2241, -33.2034, -67.3494,
-10.5925, -28.0212, -54.5078, -36.9357, -83.2554, -72.8076,
-69.7290, -24.7960, -65.6029, -22.1891, -42.4741, -37.2619,
-77.5729, -39.8775, -67.8837, -36.7911, -35.0757, -62.9799,
-126.9124, -79.0780, -60.1444, -86.6097, -97.2016, -55.6896,
-105.0834, -97.1455, -126.9090, -92.9244, -60.8984, -23.8531,
-45.8114, -22.9533, -97.4331, -34.1982, -34.8573, -27.0410,
-89.6775], device='cuda:0')
tensor(11.2619, device='cuda:0')
tensor(5, device='cuda:0')
tensor(-100.8801, device='cuda:0')
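On the GPU these three outputs contradict each other. max() reports 11.2619, but element 5 of the tensor is -100.8801. Each printed row holds six values, so 11.2619 is actually at index 80. You can get a ground-truth index that does not depend on any PyTorch kernel by computing it in plain Python from array.tolist().

The sketch below assumes that ties resolve to the first maximum, which is what the torch.argmax documentation describes. The helper names reference_argmax and index_holds_max are hypothetical. The short list is a stand-in built from values printed above, not the full tensor.

```python
from typing import Sequence


def reference_argmax(values: Sequence[float]) -> int:
    """Index of the first maximum, computed in plain Python (no PyTorch kernels involved)."""
    # max() returns the first of several equal keys, so ties resolve to the earliest index.
    return max(range(len(values)), key=values.__getitem__)


def index_holds_max(values: Sequence[float], index: int) -> bool:
    """True if the element at `index` equals the maximum of `values`."""
    return values[index] == max(values)


# Hypothetical stand-in: the first six values printed above, followed by the maximum 11.2619.
sample = [-11.5442, -8.5333, -18.0345, -85.9370, -40.4294, -100.8801, 11.2619]
print(reference_argmax(sample))    # → 6
print(index_holds_max(sample, 5))  # → False (index 5 holds -100.8801)
```

With the real tensor, reference_argmax(array.tolist()) should give 80, which matches the CPU result. Tensor.tolist() copies a CUDA tensor to host memory before converting it.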
While running on the CPU:
array = array.cpu()
print(array.max())
print(array.argmax())
print(array[array.argmax().item()])
gives the correct result:
tensor(11.2619)
tensor(80)
tensor(11.2619)
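The following script puts the GPU/CPU comparison in one place. It also adds a third index from max(dim=0), which returns both the maximum values and their indices.

It is only a sketch. compare_argmax is a hypothetical helper name, and the seven-element tensor stands in for the 151-element tensor above. The script falls back to the CPU when CUDA is unavailable, and in that case it cannot reproduce the problem.

```python
import torch


def compare_argmax(values: torch.Tensor) -> dict:
    """Compare argmax on the tensor's own device with a CPU copy and with max(dim=0)."""
    flat = values.detach().flatten()
    cpu_flat = flat.cpu()

    device_idx = flat.argmax().item()             # argmax on the original device
    max_dim_idx = flat.max(dim=0).indices.item()  # index from max(dim=0), same device
    cpu_idx = cpu_flat.argmax().item()            # argmax on a CPU copy

    true_max = cpu_flat.max()
    return {
        "device": str(values.device),
        "device_argmax": device_idx,
        "max_dim0_index": max_dim_idx,
        "cpu_argmax": cpu_idx,
        # Each index should point at an element equal to the maximum.
        "device_argmax_ok": bool(cpu_flat[device_idx] == true_max),
        "cpu_argmax_ok": bool(cpu_flat[cpu_idx] == true_max),
    }


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Stand-in for the tensor above: its first six values plus the maximum.
    array = torch.tensor(
        [-11.5442, -8.5333, -18.0345, -85.9370, -40.4294, -100.8801, 11.2619],
        device=device,
    )
    print(compare_argmax(array))
```

If the full tensor on cuda:0 behaves as shown above, device_argmax would be 5 and device_argmax_ok would be False, while cpu_argmax would be 80.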
Does anyone know what is happening here?