Torch tensor == torch.amax(tensor, (-2,-1)) error

I am trying to find the index of the maximum along two different dimensions at once.
I followed the suggestions in the thread: Get indices of the max of a 2D Tensor

But I get a non-deterministic result:

import torch

xx = torch.rand((1, 8, 512, 512, 9, 9))
pos_x = (xx == torch.amax(xx, (-1, -2), keepdim=True)).nonzero(as_tuple=True)

print(len(pos_x[5]))
# 2097155

print(xx[:,:,:,:,0,0].flatten().shape)
# torch.Size([2097152])

Moreover, this length changes from one run to the next.

Where am I wrong?

The code works for me and prints the same length on every iteration:

import torch

xx = torch.rand((1, 8, 512, 512, 9, 9))

for _ in range(10):
    pos_x = (xx == torch.amax(xx, (-1, -2), keepdim=True)).nonzero(as_tuple=True)
    print(len(pos_x[5]))
    print(xx[:,:,:,:,0,0].flatten().shape)

# 2097154
# torch.Size([2097152])
# 2097154
# torch.Size([2097152])
# 2097154
# torch.Size([2097152])
# 2097154
# torch.Size([2097152])
# 2097154
# torch.Size([2097152])
# 2097154
# torch.Size([2097152])
# 2097154
# torch.Size([2097152])
# 2097154
# torch.Size([2097152])
# 2097154
# torch.Size([2097152])
# 2097154
# torch.Size([2097152])

Maybe I explained it badly, sorry about that.
I want to extract the indices of the maximum within each 9x9 kernel, i.e. over the last two dimensions.
I expected len(pos_x[5]) to equal xx[:,:,:,:,0,0].flatten().shape[0], since the max is computed over the last two dimensions and each kernel should contribute exactly one index. I think this number can change when a maximum is repeated within a kernel; in those cases, the length of the index list is greater than the length of the flattened tensor.
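A small sketch to test the tie hypothesis, using a hypothetical toy tensor `t` of shape (1, 2, 2, 2) instead of the real data (two 2x2 kernels instead of 9x9 ones). When a kernel contains its maximum twice, `==` against `torch.amax` marks both positions, so `nonzero` returns one extra index for that kernel. Counting matches per kernel shows where the surplus comes from. With float32 values from `torch.rand`, exact ties are rare, but over roughly two million kernels a few are plausible.

```python
import torch

# Hypothetical toy input: 2 kernels of size 2x2 (the last two dims).
# Kernel 0 has a unique maximum (0.5); kernel 1 has its maximum (0.9) twice.
t = torch.tensor([[[[0.1, 0.5],
                    [0.2, 0.3]],
                   [[0.9, 0.4],
                    [0.9, 0.0]]]])
print(t.shape)  # torch.Size([1, 2, 2, 2])

# Same construction as in the question: mark every element equal to its kernel max.
is_max = t == torch.amax(t, dim=(-1, -2), keepdim=True)
pos = is_max.nonzero(as_tuple=True)
print(len(pos[-1]))  # 3 matches for only 2 kernels

# Matches per kernel: any count above 1 is a tie for the maximum.
per_kernel = is_max.sum(dim=(-1, -2))
print(per_kernel)  # tensor([[1, 2]])
print(int(per_kernel.sum()) - per_kernel.numel())  # 1 surplus index
```

Applied to the real `xx`, `is_max.sum(dim=(-1, -2))` would show which kernels produce the extra entries.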

Never mind, I have found a workaround:

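import torch

rho = torch.rand((1, 8, 512, 512, 9, 9))  # stand-in for the real data, same shape as xx in the question
c = torch.argmax(rho, dim=-1, keepdim=True)  # column of the max in each kernel row
r = torch.argmax(torch.take_along_dim(rho, c, dim=-1), dim=-2, keepdim=True)  # row holding the kernel max
c = torch.take_along_dim(c, r, dim=-2)  # column of the max within that row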

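It seems to work, since torch.argmax returns only one index even when the maximum value is repeated, so there is exactly one (row, column) pair per kernel.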
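A common alternative is a single `argmax` over the flattened kernel, with the flat index split back into row and column. The sketch below assumes the last two dimensions are the kernel and uses random stand-in data with a hypothetical, smaller shape so it runs quickly. It checks that both this method and the three-line workaround point at a maximal element of every kernel; it compares values rather than positions.

```python
import torch

# Stand-in data: a hypothetical, smaller shape than the question's
# (1, 8, 512, 512, 9, 9); the last two dims are the 9x9 kernel.
rho = torch.rand((1, 2, 16, 16, 9, 9))
k_cols = rho.shape[-1]
flat = rho.flatten(start_dim=-2)  # shape (1, 2, 16, 16, 81)

# One argmax over the flattened kernel -> exactly one flat index per kernel,
# even if the maximum value occurs more than once.
flat_idx = flat.argmax(dim=-1)  # shape (1, 2, 16, 16)
row = flat_idx // k_cols        # row inside the kernel
col = flat_idx % k_cols         # column inside the kernel

# The three-line workaround, for comparison.
c = torch.argmax(rho, dim=-1, keepdim=True)
r = torch.argmax(torch.take_along_dim(rho, c, dim=-1), dim=-2, keepdim=True)
c = torch.take_along_dim(c, r, dim=-2)  # r and c have shape (1, 2, 16, 16, 1, 1)

# Look up the values at the returned positions and compare with the kernel max.
kernel_max = torch.amax(rho, dim=(-1, -2))
flat_vals = flat.gather(-1, (row * k_cols + col).unsqueeze(-1)).squeeze(-1)
work_vals = flat.gather(-1, (r * k_cols + c).flatten(start_dim=-2)).squeeze(-1)
print(flat_idx.numel() == kernel_max.numel())  # True: one index per kernel
print(torch.equal(flat_vals, kernel_max), torch.equal(work_vals, kernel_max))
```

This needs only one reduction, but it relies on knowing the kernel width from the shape; recent PyTorch versions also provide `torch.unravel_index`, which could replace the manual `//` and `%`.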