Get top k indices/values of all rows

You could compute the top-k of the flattened tensor and then use the following implementation to unravel the resulting flat indices back into multi-dimensional indices:

def unravel_index(index, shape):
    """Convert a flat index into per-dimension indices for ``shape``.

    The last dimension is treated as the fastest-varying one (row-major
    order), so the flat index is peeled apart from the trailing dimension
    towards the leading one.

    Args:
        index: Flat index (or any object supporting ``%`` and ``//`` with
            the dimension sizes, e.g. an integer or an index tensor).
        shape: Reversible sequence of dimension sizes.

    Returns:
        A tuple with one entry per dimension, in the same order as
        ``shape``. No range check is performed on ``index``.
    """
    coords = ()
    remainder = index
    for extent in reversed(shape):
        # The remainder modulo this extent is the coordinate along the
        # current dimension; prepend it so the result follows shape order.
        coords = (remainder % extent,) + coords
        # What is left indexes into the remaining leading dimensions.
        remainder = remainder // extent
    return coords


# Random 4-D tensor used for the demonstration.
x = torch.randn(2, 3, 4, 5)

# Take the top 3 over the flattened tensor so all elements are ranked together.
res = x.view(-1).topk(k=3)

# Turn the flat positions back into one index per dimension of x.
idx = unravel_index(res.indices, x.shape)

# Indexing x with the unravelled indices should give back the top-k values.
print(torch.eq(x[idx], res.values))
> tensor([True, True, True])