PyTorch has functions to find the global maximum value, or the maximum values and their indices along a given dimension. How can I find the indices of the maximum in an N-dimensional variable? For example, if N=3, I want the index tuple of the maximum value, like (3, 7, 10), so that I can use it to index another tensor.
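Here is how the two existing forms of `torch.max` behave on a small 2-D tensor (a minimal sketch; the values are arbitrary example data). Neither form directly returns the position of the global maximum as one index per dimension.

```python
import torch

t = torch.tensor([[1.0, 9.0, 3.0],
                  [7.0, 2.0, 8.0]])

# Global maximum: a single 0-d tensor, with no index returned.
global_max = torch.max(t)
print(global_max)  # tensor(9.)

# Maximum along a given dimension: the max values and their indices along that dim.
vals, idxs = torch.max(t, 1)
print(vals)  # tensor([9., 8.])
print(idxs)  # tensor([1, 2])
```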

When I tried the following:

    max_val1, idx1 = torch.max(my_tensor, 0)
    max_val2, max_idx2 = torch.max(max_val1, 0)
    max_idx1 = idx1[max_idx2]

I got this error:

    indexing a tensor with an object of type LongTensor. The only supported types are integers, slices, numpy scalars and torch.LongTensor or torch.ByteTensor as the only argument.

This approach also does not generalize when N changes. Is there a more direct way to do this?
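One way that works for any N is sketched below. It assumes a PyTorch version where `torch.argmax` called without `dim` returns the position of the maximum in the flattened tensor. The idea is to take that flat position and convert it back into one index per dimension (row-major order). `argmax_nd` is a hypothetical helper name, not a PyTorch function, and the 4×8×11 tensors are made-up example data. If the maximum occurs more than once, `argmax` picks a single position. The `nonzero()` variant at the end lists every position instead.

```python
import torch


def argmax_nd(t: torch.Tensor) -> tuple:
    """Hypothetical helper: index tuple of the global max of an N-dim tensor.

    Takes the argmax of the flattened tensor, then converts that flat
    (row-major) position back into one index per dimension.
    """
    flat = int(torch.argmax(t))  # position of the max in t.flatten()
    idx = []
    for size in reversed(t.shape):  # last (fastest-varying) dim first
        flat, rem = divmod(flat, size)
        idx.append(rem)
    return tuple(reversed(idx))


# Example data: a single maximum placed at (3, 7, 10).
my_tensor = torch.zeros(4, 8, 11)
my_tensor[3, 7, 10] = 5.0
other = torch.arange(4 * 8 * 11).reshape(4, 8, 11)

idx = argmax_nd(my_tensor)
print(idx)         # (3, 7, 10)
print(other[idx])  # tensor(351)

# Alternative: every position equal to the maximum, one row per occurrence.
all_idx = (my_tensor == my_tensor.max()).nonzero()
print(all_idx.tolist())  # [[3, 7, 10]]
```

`idx` is a plain tuple of Python ints, so `other[idx]` is ordinary integer indexing. The same code therefore works unchanged for any rank. Recent PyTorch releases may also provide `torch.unravel_index`, which could replace the manual `divmod` loop if your version has it.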