After spending some more time with the topic, I think @Sau_Tak was pretty close to what I needed. The only thing I had to add was some reshaping before and after to make it work for my use-case in arbitrary dimensions.
Note that the for loop might be a little slow, which I think has been addressed by other posts, but for me it works fine for now.
import torch, skimage
def gather_nd(params, indices):
    """Gather elements or slices of ``params`` addressed by N-dimensional indices.

    The last axis of ``indices`` holds one coordinate per leading dimension of
    ``params``; all other axes of ``indices`` are batch axes that are carried
    over unchanged to the result.

    4D example
        params:  tensor shaped [n_1, n_2, n_3, n_4] --> 4 dimensional
        indices: tensor shaped [m_1, m_2, m_3, m_4, 4] --> multidimensional list of 4D indices
        returns: tensor shaped [m_1, m_2, m_3, m_4]

    ND example
        params:  tensor shaped [n_1, ..., n_p] --> p-dimensional tensor
        indices: tensor shaped [m_1, ..., m_i, d] --> multidimensional list of
                 d-dimensional indices, with 1 <= d <= p
        returns: tensor shaped [m_1, ..., m_i, n_{d+1}, ..., n_p]
                 (exactly [m_1, ..., m_i] when d == p)

    Args:
        params: Tensor to read from.
        indices: Tensor of coordinates; it is converted to ``long`` before use.
            A 0-dim tensor is treated as a single one-dimensional index.

    Returns:
        A new tensor with the dtype of ``params`` holding the gathered values.

    Raises:
        IndexError: If the last axis of ``indices`` is empty or longer than
            ``params.dim()``, or if a coordinate is out of range for its
            dimension.
    """
    if indices.dim() == 0:
        # A bare scalar carries no coordinate axis; promote it to a length-1
        # coordinate vector so it addresses the first dimension of params.
        indices = indices.reshape(1)

    depth = indices.shape[-1]
    if not 1 <= depth <= params.dim():
        raise IndexError(
            "indices.shape[-1] must be between 1 and params.dim() ({}), got {}".format(
                params.dim(), depth
            )
        )

    # Split the coordinate axis into one integer index tensor per addressed
    # dimension. Indexing with that tuple lets torch do the offset arithmetic
    # and bounds checking, and leaves any un-addressed trailing dimensions of
    # params intact as slices.
    coords = indices.long().unbind(-1)
    return params[coords]
## Example
# Test image from scikit-image; the view below relies on it holding 512 * 512 * 3 values.
image_t = torch.from_numpy(skimage.data.astronaut())
# Batch of stack of images: tile the single image 4 x 16 times along the two new leading axes.
# Tensor.repeat stands in for the `ptcompat.torch_tile_nd` helper, which this snippet never defines
# (assumes that helper tiled by the given multiples -- TODO confirm).
params = image_t.view((1, 1, 512, 512, 3)).repeat(4, 16, 1, 1, 1)
# Build every (batch, stack, row, col, channel) coordinate of a 128 x 128 slice from each item.
indices = torch.stack(torch.meshgrid(torch.arange(4), torch.arange(16), torch.arange(128), torch.arange(128), torch.arange(3)), dim=-1)
out = gather_nd(params, indices)
print(out.shape)  # >> torch.Size([4, 16, 128, 128, 3])
# NOTE(review): `plt` is presumably matplotlib.pyplot, which this snippet never imports -- confirm and import it before running.
plt.imshow(out[0, 0, ...])