How to do an unravel_index in PyTorch, just like in NumPy

Here is a generalized solution for any number of dimensions:

import torch


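def unravel_index(index, shape):
    # Convert a flat (row-major) index into one coordinate per dimension.
    out = []
    # Peel off the fastest-varying (last) dimension first.
    for dim in reversed(shape):
        out.append(index % dim)   # coordinate along this dimension
        index = index // dim      # what is left for the remaining dimensions
    return tuple(reversed(out))

# Check: x holds 0..29 in row-major order, so x[coords of i] must equal i.
x = torch.arange(30).view(10, 3)
for i in range(x.numel()):
    assert i == x[unravel_index(i, x.shape)]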
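The same modulo / floor-division idea should also work when the flat indices arrive as a tensor (for example, several indices at once). Below is a minimal sketch under that assumption; the `torch.is_tensor` branch and the example values are additions for illustration, not part of the original answer. Newer PyTorch releases also provide a built-in torch.unravel_index, which may be the simpler choice if your version has it.

```python
import torch


def unravel_index(index, shape):
    # Same approach as above; `index` may be a Python int or an integer tensor.
    out = []
    for dim in reversed(shape):
        out.append(index % dim)
        if torch.is_tensor(index):
            # Explicit floor division for tensors (assumes non-negative indices).
            index = torch.div(index, dim, rounding_mode="floor")
        else:
            index = index // dim
    return tuple(reversed(out))


# x holds 0..119 in row-major order, so x[coords] should give back the flat indices.
x = torch.arange(4 * 5 * 6).view(4, 5, 6)
flat = torch.tensor([0, 7, 119])
coords = unravel_index(flat, x.shape)  # tuple of three 1-D tensors
assert torch.equal(x[coords], flat)
```

Each element of `coords` is a tensor with one entry per flat index, so the tuple can be used directly for indexing, as in the last line.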