How to do unravel_index in PyTorch, just like in NumPy

How can I convert a flat index, or an array of flat indices, into a tuple of coordinate arrays?
Here is an example in NumPy:

np.unravel_index([22, 41, 37], (7,6))
(array([3, 6, 6]), array([4, 5, 1]))

Can anyone help me? Thank you!

AFAIK, for a 2D shape, unravel_index basically converts a flat (1D) index into its corresponding (row, column) pair.

The general formula for this is:

x = index // ncols  # floor division; in current PyTorch, `/` on integer tensors is true division and returns floats
y = index % ncols

So you can get what you want using this code:

import torch

index = torch.tensor([22, 41, 37])
rows = index // 6
cols = index % 6

print(rows)
print(cols)
tensor([ 3,  6,  6])
tensor([ 4,  5,  1])
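
A note on the division, in case it helps: in recent PyTorch releases `/` on integer tensors is true division and returns floats, so the row index needs floor division. Below is a minimal sketch that spells the rounding out with torch.div and rounding_mode="floor". The name ncols is mine, and I'm assuming a PyTorch version recent enough to support rounding_mode (1.8 or later, as far as I know).

```python
import torch

index = torch.tensor([22, 41, 37])
ncols = 6  # size of the last dimension of the (7, 6) shape from the question

# Explicit floor division; with integer inputs the result stays an integer tensor.
rows = torch.div(index, ncols, rounding_mode="floor")
# The remainder after removing whole rows is the column.
cols = torch.remainder(index, ncols)

print(rows.tolist(), cols.tolist())
```

For non-negative indices this should give the same coordinates as the NumPy example in the question.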

Here is a generalized solution for any number of dimensions:

import torch


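def unravel_index(index, shape):
    # Peel off the last (fastest-varying) dimension first, as in row-major order.
    out = []
    for dim in reversed(shape):
        out.append(index % dim)
        index = index // dim
    # Coordinates were collected last-to-first, so flip them back.
    return tuple(reversed(out))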

x = torch.arange(30).view(10, 3)
for i in range(x.numel()):
    assert i == x[unravel_index(i, x.shape)]
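
Since the loop only uses `%` and `//`, it also works on plain Python ints, which makes the row-major logic easy to check without any tensors. Here is a rough sketch of such a check for a 3D shape. Note that ravel_index is a hypothetical helper I made up for the round trip, not part of PyTorch or NumPy.

```python
def unravel_index(index, shape):
    # Same loop as in the post: peel off the last dimension first.
    out = []
    for dim in reversed(shape):
        out.append(index % dim)
        index = index // dim
    return tuple(reversed(out))


def ravel_index(coords, shape):
    # Hypothetical inverse (row-major / C order), used only for checking.
    flat = 0
    for c, dim in zip(coords, shape):
        flat = flat * dim + c
    return flat


shape = (2, 3, 4)
for i in range(2 * 3 * 4):
    # Unraveling and then raveling should return the original flat index.
    assert ravel_index(unravel_index(i, shape), shape) == i

print(unravel_index(22, (7, 6)))  # (3, 4), the first coordinate pair in the NumPy example
```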

So, what is the correct way to get the PyTorch equivalent of np.unravel_index?

I’ve written a more general (and efficient) unravel_index function, which I have posted here:

I’m quite new to the PyTorch community so I don’t really know how I should (or even if I should) do a PR for that.

Since you are the king of this forum, @ptrblck, maybe you know?

Thanks for sharing your approach! Let’s continue the discussion in the feature request.

Haha, history wasn’t always nice to monarchs, so I hope I’m more the first layer of defense against nasty bugs. :wink:
