PyTorch has functions to find the global maximum value, or the maximum values and their indices along a given dimension. How can I find the indices of the max in an N-dimensional variable? (For example, if N=3, the indices corresponding to the max value, like (3,7,10), so that I can use them to index another tensor.)
indexing a tensor with an object of type LongTensor. The only supported types are integers, slices, numpy scalars and torch.LongTensor or torch.ByteTensor as the only argument.
is the error I get. My approach is also not flexible when N changes. Is there a more direct way? How can I solve it?
One way to solve it is to flatten the tensor, take the index of the maximum along dimension 0, and then unravel that index, either using numpy or your own logic (you will probably come up with something less clumsy):
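import torch

rawmaxidx = mytensor.view(-1).max(0)[1]  # flat index of the global maximum
idx = []
for adim in list(mytensor.size())[::-1]:  # walk the dimensions from last to first
    idx.append(rawmaxidx % adim)  # coordinate along this dimension
    rawmaxidx = rawmaxidx // adim  # integer (floor) division
idx = torch.stack(idx[::-1])  # reverse so the order is (dim0, dim1, ...)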
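(Note that in older PyTorch versions, / on a LongTensor behaved like Python 2 / or Python 3 // for ints. Recent versions perform true division with /, so use // when you need integer division.)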
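For the numpy route mentioned above, a minimal sketch could look like the following. The helper name argmax_nd and the example tensors are made up for illustration. It assumes torch.argmax called without a dim argument (which returns the index into the flattened tensor) and uses numpy.unravel_index to turn that flat index into per-dimension coordinates.

```python
import numpy as np
import torch


def argmax_nd(t):
    """Return the N-dimensional index of the maximum of t as a tuple of ints.

    argmax_nd is a hypothetical helper name, not part of PyTorch.
    """
    # Without a dim argument, torch.argmax returns the index into the flattened tensor.
    flat_idx = int(torch.argmax(t))
    # numpy.unravel_index converts the flat index into one coordinate per dimension.
    coords = np.unravel_index(flat_idx, tuple(t.shape))
    return tuple(int(c) for c in coords)


# Example data (assumed for illustration): put the maximum at (3, 7, 10).
x = torch.zeros(4, 8, 11)
x[3, 7, 10] = 5.0
idx = argmax_nd(x)

# A tuple of ints can index another tensor of the same shape directly.
other = torch.arange(4 * 8 * 11).reshape(4, 8, 11)
value = other[idx]
```

Because the result is a plain tuple of Python ints, it works for any N without changing the code.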
I have a related question about indices.
I get the max and the index of the max from a 1xn vector containing LongTensor values. The index I obtain is also a LongTensor. I want to use it to look up a value in a dictionary.
dict = {1: 'hello', 2: 'world'}
How can I use the LongTensor index as an int key for the dict?
dict[index] gives an error.
Please let me know.
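One possible approach, as a sketch: convert the one-element index tensor to a plain Python int with .item() before using it as a key. The names labels and vec and the example values are assumptions for illustration. The dictionary is renamed so that it does not shadow the built-in dict.

```python
import torch

# Hypothetical example data; "labels" replaces "dict" to avoid shadowing the built-in.
labels = {1: "hello", 2: "world"}
vec = torch.tensor([[0, 2, 1]])  # 1 x n tensor of int64 (LongTensor) values

# max along dimension 1 returns the max values and their indices, each with one element.
max_val, index = vec.max(1)

# .item() turns a one-element tensor into a plain Python int, which can be a dict key.
key = index.item()
word = labels[key]
```

Here the tensor itself cannot serve as the key, while the int returned by .item() can.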