You could compute `torch.topk` on the flattened tensor and then use the following implementation to unravel the flat indices back into multi-dimensional indices:
def unravel_index(index, shape):
    """Convert flat (row-major) index/indices into per-dimension coordinates.

    This mirrors ``numpy.unravel_index`` for C-order layouts and works both
    for a plain Python ``int`` and for an integer tensor of flat indices.

    Args:
        index: A flat index (``int``) or an integer tensor of flat indices.
            Indices are assumed to be non-negative and within
            ``prod(shape)``; no bounds check is performed.
        shape: Sequence of dimension sizes, e.g. ``x.size()``.

    Returns:
        A tuple with one entry per dimension of ``shape``. Each entry has the
        same type (and, for tensors, the same shape) as ``index``, so the
        result can be used directly for advanced indexing:
        ``x[unravel_index(i, x.size())]``.
    """
    out = []
    # Walk the dimensions from the innermost (last) outwards: the remainder
    # is the coordinate along this dimension, and the quotient is the flat
    # index into the remaining leading dimensions.
    for dim in reversed(shape):
        out.append(index % dim)
        index = index // dim
    # Coordinates were collected innermost-first; restore outer-to-inner order.
    return tuple(reversed(out))
# `torch` was used below without being imported; import it so the snippet runs.
import torch

x = torch.randn(2, 3, 4, 5)
# Take the top-3 values over the whole tensor by flattening it to 1-D first;
# `res.indices` are therefore flat indices into `x.view(-1)`.
res = torch.topk(x.view(-1), k=3)
# Map the flat indices back to one index tensor per dimension of `x`.
idx = unravel_index(res.indices, x.size())
# Indexing `x` with the unravelled coordinates must recover the top-k values.
print(x[idx] == res.values)
> tensor([True, True, True])