torch.gather handles the following case for a 3-D input:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
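To check the dim == 0 rule above, the sketch below compares torch.gather against an explicit triple loop that implements the formula literally. The helper name gather_dim0_loop and the example shapes are illustrative assumptions, not part of the PyTorch API.

```python
import torch

def gather_dim0_loop(inp, index):
    # Hypothetical helper: out[i][j][k] = inp[index[i][j][k]][j][k]
    out = torch.empty(index.shape, dtype=inp.dtype)
    for i in range(index.size(0)):
        for j in range(index.size(1)):
            for k in range(index.size(2)):
                out[i, j, k] = inp[index[i, j, k], j, k]
    return out

torch.manual_seed(0)
inp = torch.arange(24).reshape(2, 3, 4)      # example 3-D input
index = torch.randint(0, 2, (2, 3, 4))       # values must be < inp.size(0)

# torch.gather along dim 0 should match the explicit loop
assert torch.equal(torch.gather(inp, 0, index), gather_dim0_loop(inp, index))
```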
However, what should we do when there are multiple index tensors, as in the following loop?
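import torch

dim1, dim2, dim3, dim4 = 2, 3, 4, 5  # example sizes
source = torch.randn(dim1, dim2, dim3, dim4)
# Index tensors must be integer (long) tensors with values in range
index1 = torch.randint(0, dim1, (dim3, dim4))
index2 = torch.randint(0, dim2, (dim3, dim4))
ret = torch.empty(dim3, dim4)
for i in range(dim3):
    for j in range(dim4):
        ret[i, j] = source[index1[i, j], index2[i, j], i, j]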
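One way to remove the loop is a minimal sketch like the one below, assuming the example sizes and random integer index tensors used here. Option 1 uses PyTorch advanced indexing: the arange tensors ii (shape dim3 x 1) and jj (shape 1 x dim4) broadcast with index1 and index2 to shape (dim3, dim4), so each output element picks source[index1[i, j], index2[i, j], i, j]. Option 2 merges the first two dimensions so a single torch.gather along dim 0 suffices. Both are checked against the original loop.

```python
import torch

torch.manual_seed(0)
dim1, dim2, dim3, dim4 = 2, 3, 4, 5  # example sizes (assumed)
source = torch.randn(dim1, dim2, dim3, dim4)
index1 = torch.randint(0, dim1, (dim3, dim4))
index2 = torch.randint(0, dim2, (dim3, dim4))

# Reference: the original double loop
ret_loop = torch.empty(dim3, dim4)
for i in range(dim3):
    for j in range(dim4):
        ret_loop[i, j] = source[index1[i, j], index2[i, j], i, j]

# Option 1: advanced indexing; ii and jj broadcast against index1/index2
ii = torch.arange(dim3).unsqueeze(1)  # shape (dim3, 1)
jj = torch.arange(dim4).unsqueeze(0)  # shape (1, dim4)
ret_index = source[index1, index2, ii, jj]  # shape (dim3, dim4)

# Option 2: merge dims 0 and 1, then one torch.gather along the merged dim
flat = source.reshape(dim1 * dim2, dim3, dim4)
flat_index = index1 * dim2 + index2  # row-major position in the merged dim
ret_gather = torch.gather(flat, 0, flat_index.unsqueeze(0)).squeeze(0)

assert torch.equal(ret_loop, ret_index)
assert torch.equal(ret_loop, ret_gather)
```

Advanced indexing is usually the more readable choice; the reshape-plus-gather variant shows how the problem reduces to the single-index case that torch.gather covers.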