Hi,
one way to solve it is to flatten the tensor, take the index of the maximum along dimension 0, and then unravel that flat index, either with numpy or with your own logic (you will probably come up with something less clumsy):
rawmaxidx = mytensor.view(-1).max(0)[1]  # argmax of the flattened tensor
idx = []
for adim in list(mytensor.size())[::-1]:
    idx.append(rawmaxidx % adim)   # index within the current (last remaining) dimension
    rawmaxidx = rawmaxidx / adim   # integer division, see the note below
idx = torch.cat(idx[::-1])         # reverse so the indices follow the original dimension order
(Note that pytorch's / on LongTensor behaves like python2's / or python3's // for ints, i.e. integer division.)
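For completeness, here is a small self-contained sketch of the same idea written against current PyTorch (the tensor x is just an example; torch.argmax and // stand in for the old max(0)[1] and LongTensor division):

import torch

x = torch.randn(3, 4, 5)              # example tensor, any shape works

flat_idx = x.view(-1).argmax()        # argmax of the flattened tensor

idx = []
for adim in list(x.size())[::-1]:
    idx.append((flat_idx % adim).item())
    flat_idx = flat_idx // adim       # explicit floor division on integer tensors
idx = idx[::-1]                       # indices in the original dimension order

print(idx)
print(x[tuple(idx)] == x.max())       # True: the recovered index hits the maximum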
Best regards
Thomas