PyTorch Multi-dimensional Max

If I have a multi-dimensional tensor, what would be the easiest way to get the maximum over all dimensions and the multi-dimensional index for the maximum?


The easy way is to write a recursive function that takes the max along one dimension at a time and records the index for each dimension as it goes.
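A minimal sketch of the recursive approach follows, assuming a non-empty tensor. `recursive_argmax` is a hypothetical helper name, not a PyTorch function. It reduces dimension 0 with `Tensor.max(dim=0)`, recurses on the remaining values, and then looks up which row along dimension 0 produced the winner.

```python
import torch

def recursive_argmax(t):
    """Return (max_value, index_tuple) for a non-empty tensor t (hypothetical helper)."""
    if t.dim() == 0:
        # Nothing left to reduce: this scalar is the global max.
        return t.item(), ()
    # Max over dim 0: values/indices both have shape t.shape[1:].
    values, indices = t.max(dim=0)
    # Recurse on the reduced tensor to find where the max sits in t.shape[1:].
    best, rest = recursive_argmax(values)
    # indices[rest] says which position along dim 0 held that max.
    return best, (int(indices[rest]),) + rest

t = torch.tensor([[1., 5., 2.], [7., 3., 0.]])
print(recursive_argmax(t))
```

Each level removes one dimension, so the recursion depth equals `t.dim()`, and the index tuple is assembled on the way back out.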
A harder but faster way is to flatten the tensor to 1-D (with .view(-1), assuming it is contiguous) and take the max and its index. Then convert that 1-D index back into a position in the original tensor using the original tensor's sizes.
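The flatten-then-unravel approach can be sketched as below. `flat_argmax` and `unravel` are hypothetical helper names. The sketch uses `reshape(-1)` instead of `view(-1)` so it also works on non-contiguous tensors. `reshape` returns a view when it can and copies otherwise, and in both cases the elements come out in row-major (C) order, which is what the unravel arithmetic assumes.

```python
import torch

def unravel(flat_idx, shape):
    """Convert a row-major flat index into a tuple of per-dimension indices."""
    coords = []
    # Peel off the last (fastest-varying) dimension first.
    for size in reversed(shape):
        coords.append(flat_idx % size)
        flat_idx //= size
    return tuple(reversed(coords))

def flat_argmax(t):
    """Return (max_value, index_tuple) by flattening t (hypothetical helper)."""
    flat = t.reshape(-1)                # 1-D, row-major order
    value, flat_idx = flat.max(dim=0)   # max and its position in the 1-D tensor
    return value.item(), unravel(int(flat_idx), t.shape)

t = torch.tensor([[1., 5., 2.], [7., 3., 0.]])
print(flat_argmax(t))
```

This replaces the per-dimension reductions with a single reduction over all elements. Newer PyTorch releases also provide `torch.unravel_index`, which can replace the hand-written `unravel`. Check that your installed version has it before relying on it.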
