Finding indices of a global maximum value in a variable

PyTorch has functions to find the global maximum value of a tensor, or the maximum values and their indices along a given dimension. How can I find the indices of the global max in an N-dimensional variable? (For example, if N=3, I want the indices corresponding to the max value, like (3, 7, 10), and then use them to index another tensor.)

When I tried

max_val1, idx1 = torch.max(my_tensor, 0)
max_val2, max_idx2 = torch.max(max_val1, 0)
max_idx1 = idx1[max_idx2]

I got this error:

indexing a tensor with an object of type LongTensor. The only supported types are integers, slices, numpy scalars and torch.LongTensor or torch.ByteTensor as the only argument.

Besides, this approach is not flexible when N changes. Is there a more direct way? How can I solve this?

Hi,

one way to solve it is to flatten the tensor, take the index of the maximum along dimension 0 and then unravel that index, either using numpy or your own logic (probably you will come up with something less clumsy :slight_smile: ):

rawmaxidx = mytensor.view(-1).max(0)[1]   # index of the max in the flattened tensor
idx = []
for adim in list(mytensor.size())[::-1]:  # walk the dimensions from last to first
    idx.append(rawmaxidx % adim)
    rawmaxidx = rawmaxidx / adim          # integer division on LongTensor (see note below)
idx = torch.cat(idx[::-1])                # reverse so the indices come out in dimension order

(Note that PyTorch's / on a LongTensor behaves like Python 2's / or Python 3's // for ints, i.e. integer division.)

Best regards

Thomas
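
For reference, here is what the numpy route mentioned above could look like, and how the resulting indices can be used to index another tensor. This is only a minimal sketch: my_tensor and other_tensor are made-up names, and the tuple indexing at the end assumes a reasonably recent PyTorch.

import torch
import numpy as np

my_tensor = torch.randn(4, 8, 12)      # hypothetical input
other_tensor = torch.randn(4, 8, 12)   # hypothetical tensor to index into

# index of the max in the flattened tensor, unravelled by numpy into
# one index per dimension (works for any N)
flat_idx = my_tensor.view(-1).max(0)[1]
multi_idx = np.unravel_index(int(flat_idx), my_tensor.size())

value = other_tensor[multi_idx]        # multi_idx is a tuple like (3, 7, 10)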


Based on the question above about indices, I have a question too.
I get the max value and its index from a 1xn vector containing LongTensor values. The index I obtain is also a LongTensor. I want to use it to look up a value in a dictionary.

dict = {1: 'hello', 2: 'world'}

How can I use the LongTensor index as an int key into the dict? dict[index] gives an error.
Please let me know.

Thanks

If you have a tensor t, indexing it with as many indices as t has dimensions (e.g. t[0] for a 1-dimensional t) gives a plain Python value. For a Variable v, use v.data[0].

Best regards

Thomas
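
For completeness, a small sketch of that for the dictionary case (lookup and scores are made-up names here; int() forces the LongTensor index into a plain Python key):

import torch

lookup = {1: 'hello', 2: 'world'}
scores = torch.LongTensor([3, 7])

max_val, max_idx = scores.max(0)   # max_idx is a LongTensor, not a Python int
key = int(max_idx) + 1             # +1 only because this toy dict is keyed from 1
print(lookup[key])                 # -> 'world'

(With the old Variable API you would go through .data first, i.e. v.data[0], as in the reply above.)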

Thank you, both of your suggestions worked for me; I modified your code to find the indices first.

But I have a question: why can't we just use v.data directly to get a single element instead? It is a very basic operation.

A Variable v is, say, a wrapper around a tensor t; then v.data equals t. Calling .data just reveals the tensor that the Variable wraps, so you still need to index it to get a single element.
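
A tiny sketch of the distinction (made-up values; Variable here is the old pre-0.4 wrapper this thread is about):

import torch
from torch.autograd import Variable

t = torch.Tensor([3.0, 7.0, 10.0])
v = Variable(t)

print(type(v.data))       # v.data is the wrapped tensor itself, not a single element
print(float(v.data[0]))   # 3.0 -- indexing picks out one concrete element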

Thanks, it works for me using v.data[0]