tf.batch_gather in PyTorch

I am reimplementing some TensorFlow code in PyTorch, but I found that there is no corresponding function for tf.batch_gather, which is used like this:

new_neighbors_indices = tf.batch_gather(neighbors_indices, inds)

The shape of neighbors_indices is [a, b] and the shape of inds is [a, c]. Is there any way to implement this operation in PyTorch?

Why not use the gather function?
If it does not match what you want, please give a full example of the input and output you expect!

Hi, thanks for your reply! I want to port the following two lines to PyTorch:

new_neighbors_indices = tf.batch_gather(neighbors_indices, new_neighb_inds)
new_sq_distances = tf.batch_gather(sq_distances, new_neighb_inds)

Here neighbors_indices has shape [a, b] and new_neighb_inds has shape [a, c]. I found that for the first line I can use torch.gather:

new_neighbors_indices = torch.gather(neighbors_indices, dim=1, index=new_neighb_inds)
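
For reference, here is a minimal sketch of the 2D case with small made-up tensors (the values and the sizes a=2, b=4, c=3 are assumptions chosen only for illustration). It also shows that, for this case, plain advanced indexing should give the same result as torch.gather:

```python
import torch

# Hypothetical example data: a=2 rows, b=4 candidates per row, c=3 picks per row.
neighbors_indices = torch.tensor([[10, 11, 12, 13],
                                  [20, 21, 22, 23]])
new_neighb_inds = torch.tensor([[3, 0, 1],
                                [2, 2, 0]])

# For each row i, pick neighbors_indices[i, new_neighb_inds[i, j]].
new_neighbors_indices = torch.gather(neighbors_indices, dim=1, index=new_neighb_inds)

# Same selection written with advanced indexing: a [a, 1] row index
# broadcasts against the [a, c] column index.
rows = torch.arange(neighbors_indices.size(0)).unsqueeze(1)
same = neighbors_indices[rows, new_neighb_inds]

print(new_neighbors_indices)  # shape [a, c]
```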

But in the second line, sq_distances has shape [a, b, k] and new_neighb_inds has shape [a, c], and the expected result has shape [a, c, k]. How can I use torch.gather here? I tried torch.gather(sq_distances, dim=1, index=new_neighb_inds) but got this error:

*** RuntimeError: invalid argument 4: Index tensor must have same dimensions as input tensor at /opt/conda/conda-bld/pytorch_1556653099582/work/aten/src/THC/generic/

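You want to make sure the index has the same number of dimensions as the values and already has the output shape [a, c, k]: new_neighb_inds.unsqueeze(-1).expand(-1, -1, sq_distances.size(-1)). Note that .expand_as(sq_distances) only works when c == b, because expand cannot change a dimension whose size is not 1.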
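
A minimal sketch of the 3D case, assuming small made-up sizes (a=2, b=5, c=3, k=4) and dummy values. The index is expanded along the last dimension to shape [a, c, k], which is what torch.gather needs when c differs from b. The advanced-indexing line is an alternative added for comparison, not something from the thread:

```python
import torch

# Hypothetical sizes, chosen only for illustration.
a, b, c, k = 2, 5, 3, 4
sq_distances = torch.arange(a * b * k, dtype=torch.float32).reshape(a, b, k)
new_neighb_inds = torch.tensor([[4, 0, 2],
                                [1, 1, 3]])  # shape [a, c], values in [0, b)

# [a, c] -> [a, c, 1] -> [a, c, k]; every entry along k repeats the same index.
index = new_neighb_inds.unsqueeze(-1).expand(-1, -1, k)
new_sq_distances = torch.gather(sq_distances, dim=1, index=index)

# Alternative with advanced indexing: selects whole length-k vectors at once.
rows = torch.arange(a).unsqueeze(1)              # shape [a, 1]
reference = sq_distances[rows, new_neighb_inds]  # shape [a, c, k]

print(new_sq_distances.shape)
```

Both forms select, for each batch row i and each pick j, the full vector sq_distances[i, new_neighb_inds[i, j], :].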

I have tried it and it works! Thanks a lot.