I am wondering how to find the index of the maximum along two different dimensions.
I followed the suggestions in the thread: Get indices of the max of a 2D Tensor
But I get a non-deterministic result:
xx = torch.rand((1, 8, 512, 512, 9, 9))
pos_x = (xx == torch.amax(xx, (-1, -2), keepdim=True)).nonzero(as_tuple=True)
print(len(pos_x[5]))
2097155
print(xx[:,:,:,:,0,0].flatten().shape)
torch.Size([2097152])
In particular, the first number changes every time I run this with a new random tensor.
What am I doing wrong?
The code works for me and outputs the same shape for each iteration:
xx = torch.rand((1, 8, 512, 512, 9, 9))
for _ in range(10):
    pos_x = (xx == torch.amax(xx, (-1, -2), keepdim=True)).nonzero(as_tuple=True)
    print(len(pos_x[5]))
    print(xx[:,:,:,:,0,0].flatten().shape)
# 2097154
# torch.Size([2097152])
# 2097154
# torch.Size([2097152])
# 2097154
# torch.Size([2097152])
# 2097154
# torch.Size([2097152])
# 2097154
# torch.Size([2097152])
# 2097154
# torch.Size([2097152])
# 2097154
# torch.Size([2097152])
# 2097154
# torch.Size([2097152])
# 2097154
# torch.Size([2097152])
# 2097154
# torch.Size([2097152])
Maybe I explained it badly, sorry about that.
I would like to extract the indices of the maximum value within each 9x9 kernel, i.e. over the last two dimensions.
I expected len(pos_x[5]) to equal xx[:,:,:,:,0,0].flatten().shape[0], since the max is computed over the last two dimensions. I think the number changes when the maximum value is repeated inside a kernel: the equality mask then matches more than one position for that kernel, so the list of indices ends up longer than the flattened tensor.
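To illustrate the repeated-maximum effect, here is a minimal sketch on a tiny hand-made tensor (the values and the names x, mask and idx are made up for this example, they are not from the original code). The first 2x2 kernel contains its maximum twice, so the equality mask returns one index more than there are kernels. With torch.rand the values are float32 numbers from a finite set, so ties like this can occasionally occur when there are millions of kernels.

```python
import torch

# Hypothetical toy input: two 2x2 kernels, the first one has a tied maximum (3.0 twice).
x = torch.tensor([[[1.0, 3.0],
                   [3.0, 0.0]],
                  [[0.0, 1.0],
                   [2.0, 5.0]]])

# Same approach as in the question: compare every element against the kernel maximum.
mask = x == torch.amax(x, dim=(-1, -2), keepdim=True)
idx = mask.nonzero(as_tuple=True)

num_kernels = x[..., 0, 0].numel()            # one entry per kernel
num_matches = len(idx[-1])                    # one entry per matching position
matches_per_kernel = mask.sum(dim=(-1, -2))   # how many positions hit the max in each kernel

print(num_kernels, num_matches, matches_per_kernel)
```

Any kernel whose entry in matches_per_kernel is greater than 1 contributes extra indices, which would explain why the length differs between random tensors.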
Never mind, I have found a workaround:
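# rho is the input tensor (it plays the role of xx above)
# column index of the maximum in every row of each kernel
c = torch.argmax(rho, dim=-1, keepdim=True)
# row index of the largest of those row maxima
r = torch.argmax(torch.take_along_dim(rho, c, dim=-1), dim=-2, keepdim=True)
c = torch.take_along_dim(c, r, dim=-2)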
It seems to work, and it returns exactly one (r, c) pair per kernel even when the maximum is repeated.
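For comparison, a possible alternative (a sketch, not something tested on the full-size tensor) is to flatten the last two dimensions, take a single argmax, and split the flat index back into a row and a column. The smaller shape and the names flat_idx, row, col and picked are assumptions made for this example.

```python
import torch

torch.manual_seed(0)
# Smaller stand-in for the (1, 8, 512, 512, 9, 9) tensor from the question.
rho = torch.rand(2, 3, 4, 4, 9, 9)

kh, kw = rho.shape[-2:]
flat = rho.flatten(start_dim=-2)          # merge the 9x9 kernel into one axis of length 81
flat_idx = flat.argmax(dim=-1)            # exactly one index per kernel, ties included

# Convert the flat index back into (row, column) inside the kernel.
row = torch.div(flat_idx, kw, rounding_mode="floor")
col = flat_idx % kw

# Sanity check: the values at the selected positions are the kernel maxima.
picked = flat.gather(-1, flat_idx.unsqueeze(-1)).squeeze(-1)
print(row.shape, col.shape, torch.equal(picked, rho.amax(dim=(-1, -2))))
```

Because argmax returns a single index per kernel, row and col always have the shape of the leading dimensions. If a maximum is repeated, which of the tied positions is returned may differ from the two-step version above.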