.argmax() fails on GPU, works on CPU

I’m having a strange problem with .argmax(). When my tensor is located on the GPU, argmax() doesn’t return the correct index. If I move the tensor to the CPU, I get the expected result.

While running on the GPU:

print(array)
print(array.max())
print(array.argmax())
print(array[array.argmax().item()])

gives the output:

tensor([ -11.5442,   -8.5333,  -18.0345,  -85.9370,  -40.4294, -100.8801,
         -81.5367, -158.3000,  -53.0409,  -79.1826,  -70.0517,  -49.0895,
         -82.3455, -146.0170,  -50.0242,   -8.0940, -129.8060,  -89.6011,
         -51.4888,  -61.9345,  -68.5567,  -82.3089,  -94.2280,  -68.8929,
         -58.0690,  -41.6044,  -73.2325,  -98.5677,  -71.2841,  -97.9249,
         -78.4366, -119.5655,  -38.1226,  -26.0892,  -68.2670,  -98.7564,
         -99.1900,  -43.8435,  -56.8168, -104.1809,  -80.6755, -104.1263,
        -110.3228,  -62.4219,  -86.5573,  -56.0344,  -54.7963,  -52.7687,
         -68.4890,  -42.4576,  -85.1173,  -70.2056,  -25.9964,  -49.6356,
         -69.7168,   -2.9612,  -82.8546,  -72.6364, -143.3515,  -87.6566,
         -28.4261,  -56.5061,  -42.5236,  -31.2142,  -48.2933,  -88.7260,
         -26.8897,  -36.2937,  -80.7760,  -58.9123,  -70.4392,  -31.6963,
         -23.6501,  -40.7379,  -43.2897,  -24.0806, -129.6917,  -62.9375,
         -61.3152,  -51.7130,   11.2619,  -28.1152,  -56.1424,  -62.4392,
         -31.8498,  -50.9248,  -26.5526,  -70.4671, -101.5490,  -53.1732,
        -114.1311, -120.1272,  -56.8597,  -47.0254,  -77.9264,  -95.7280,
         -52.8543,  -76.9921,  -44.5555,  -65.8190,  -80.3501,  -72.5908,
         -57.3906,  -83.8305,  -49.4801,  -56.9889, -111.5182,  -37.6801,
         -79.2351,  -23.4743,   -9.0946,  -41.2241,  -33.2034,  -67.3494,
         -10.5925,  -28.0212,  -54.5078,  -36.9357,  -83.2554,  -72.8076,
         -69.7290,  -24.7960,  -65.6029,  -22.1891,  -42.4741,  -37.2619,
         -77.5729,  -39.8775,  -67.8837,  -36.7911,  -35.0757,  -62.9799,
        -126.9124,  -79.0780,  -60.1444,  -86.6097,  -97.2016,  -55.6896,
        -105.0834,  -97.1455, -126.9090,  -92.9244,  -60.8984,  -23.8531,
         -45.8114,  -22.9533,  -97.4331,  -34.1982,  -34.8573,  -27.0410,
         -89.6775], device='cuda:0')
tensor(11.2619, device='cuda:0')
tensor(5, device='cuda:0')
tensor(-100.8801, device='cuda:0')

While running on the CPU:

array = array.cpu()
print(array.max())
print(array.argmax())
print(array[array.argmax().item()])

gives the correct result:

tensor(11.2619)
tensor(80)
tensor(11.2619)

Does anyone know what is happening here?


Hi,

It works fine for me:

import torch

t = torch.tensor([ -11.5442,   -8.5333,  -18.0345,  -85.9370,  -40.4294, -100.8801,
         -81.5367, -158.3000,  -53.0409,  -79.1826,  -70.0517,  -49.0895,
         -82.3455, -146.0170,  -50.0242,   -8.0940, -129.8060,  -89.6011,
         -51.4888,  -61.9345,  -68.5567,  -82.3089,  -94.2280,  -68.8929,
         -58.0690,  -41.6044,  -73.2325,  -98.5677,  -71.2841,  -97.9249,
         -78.4366, -119.5655,  -38.1226,  -26.0892,  -68.2670,  -98.7564,
         -99.1900,  -43.8435,  -56.8168, -104.1809,  -80.6755, -104.1263,
        -110.3228,  -62.4219,  -86.5573,  -56.0344,  -54.7963,  -52.7687,
         -68.4890,  -42.4576,  -85.1173,  -70.2056,  -25.9964,  -49.6356,
         -69.7168,   -2.9612,  -82.8546,  -72.6364, -143.3515,  -87.6566,
         -28.4261,  -56.5061,  -42.5236,  -31.2142,  -48.2933,  -88.7260,
         -26.8897,  -36.2937,  -80.7760,  -58.9123,  -70.4392,  -31.6963,
         -23.6501,  -40.7379,  -43.2897,  -24.0806, -129.6917,  -62.9375,
         -61.3152,  -51.7130,   11.2619,  -28.1152,  -56.1424,  -62.4392,
         -31.8498,  -50.9248,  -26.5526,  -70.4671, -101.5490,  -53.1732,
        -114.1311, -120.1272,  -56.8597,  -47.0254,  -77.9264,  -95.7280,
         -52.8543,  -76.9921,  -44.5555,  -65.8190,  -80.3501,  -72.5908,
         -57.3906,  -83.8305,  -49.4801,  -56.9889, -111.5182,  -37.6801,
         -79.2351,  -23.4743,   -9.0946,  -41.2241,  -33.2034,  -67.3494,
         -10.5925,  -28.0212,  -54.5078,  -36.9357,  -83.2554,  -72.8076,
         -69.7290,  -24.7960,  -65.6029,  -22.1891,  -42.4741,  -37.2619,
         -77.5729,  -39.8775,  -67.8837,  -36.7911,  -35.0757,  -62.9799,
        -126.9124,  -79.0780,  -60.1444,  -86.6097,  -97.2016,  -55.6896,
        -105.0834,  -97.1455, -126.9090,  -92.9244,  -60.8984,  -23.8531,
         -45.8114,  -22.9533,  -97.4331,  -34.1982,  -34.8573,  -27.0410,
         -89.6775])
t = t.cuda()
print(t.max())
print(t.argmax())

Does that script work on your side?
If so, can you give a small script that reproduces the issue?

Hi,
Your script works fine for me. I’m having trouble creating a self-contained script that reproduces the issue, because the array in my example is the output of a network; if I generate a tensor using torch.randn() and move it to my GPU, .argmax() works fine.
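For example, a sanity check along these lines (a minimal sketch of the randn() test I described, falling back to the CPU if no GPU is available) gives the correct index:

```python
import torch

# A random tensor of the same length as my network output (151),
# moved to the GPU when one is available.
t = torch.randn(151)
if torch.cuda.is_available():
    t = t.cuda()

idx = t.argmax()
# The value at the returned index matches .max(), as expected.
assert t[idx.item()] == t.max()
```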

As far as I can tell, this issue only happens when the true maximum value is in the second half of the tensor (index 75 and onwards of a 151 element tensor).

I understand this is very hard to help with when I can’t provide a self-contained example, but any suggestions of how I could investigate on my side are much appreciated!

Could you print the size and stride of the array that does not work and report it here?

The array in my example above is created by indexing my network’s output, array = output[0,:,i,j,k]:

>>> output.shape
torch.Size([1, 151, 137, 199, 178])
>>> output.stride()
(732774914, 4852814, 35422, 178, 1)

>>> array.shape
torch.Size([151])
>>> array.stride()
(1,)

Thanks for the code snippet. The error seems to be related to this issue.
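Until that’s resolved, one possible workaround sketch, based on the observation earlier in this thread that the CPU result is correct (`safe_argmax` is a hypothetical helper name, not part of PyTorch):

```python
import torch

def safe_argmax(t: torch.Tensor) -> torch.Tensor:
    """Workaround sketch: compute argmax on a contiguous CPU copy of the
    tensor, which gave the correct index in this thread."""
    return t.detach().cpu().contiguous().argmax()
```

This incurs a device-to-host copy, so it is only meant as a stopgap while the underlying bug is open.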


That’s really helpful, thanks!