Xuyang_Bai
(Xuyang Bai)
December 22, 2019, 3:07am
1
I am reimplementing some TensorFlow code in PyTorch, but I could not find an equivalent of tf.batch_gather, which is used like this:
new_neighbors_indices = tf.batch_gather(neighbors_indices, inds)
The shape of neighbors_indices is [a, b] and the shape of inds is [a, c]. Is there any way to implement this operation in PyTorch?
albanD
(Alban D)
December 22, 2019, 1:26pm
2
Why not use the gather function?
If it does not match what you want, please give a full example of the input and the output you expect!
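For reference, here is a minimal sketch of what torch.gather does along dim=1 on 2-D tensors, which appears to correspond to the 2-D tf.batch_gather call in the question. The tensor values are made up for illustration; only the shapes [a, b] and [a, c] come from the thread.

```python
import torch

# Hypothetical data: a = 2 rows, b = 4 candidates per row.
neighbors_indices = torch.tensor([[10, 11, 12, 13],
                                  [20, 21, 22, 23]])  # shape [a, b]

# For each row, pick c = 3 entries by position (must be an int64 tensor).
inds = torch.tensor([[3, 0, 0],
                     [1, 2, 1]])                      # shape [a, c]

# out[i, j] == neighbors_indices[i, inds[i, j]]
out = torch.gather(neighbors_indices, dim=1, index=inds)
print(out.shape)  # torch.Size([2, 3])
print(out)        # tensor([[13, 10, 10], [21, 22, 21]])
```

Note that the index tensor must have the same number of dimensions as the input, which is why the 2-D case works directly.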
Xuyang_Bai
(Xuyang Bai)
December 23, 2019, 3:00am
3
Hi, thanks for your reply! I want to port the following two lines to PyTorch:
new_neighbors_indices = tf.batch_gather(neighbors_indices, new_neighb_inds)
new_sq_distances = tf.batch_gather(sq_distances, new_neighb_inds)
Here neighbors_indices has shape [a, b] and new_neighb_inds has shape [a, c]. For the first line, I found that I can use torch.gather:
new_neighbors_indices = torch.gather(neighbors_indices, dim=1, index=new_neighb_inds)
In the second line, however, sq_distances has shape [a, b, k] and new_neighb_inds has shape [a, c], and the expected result has shape [a, c, k]. How can I use torch.gather here? I tried torch.gather(sq_distances, dim=1, index=new_neighb_inds) but got the following error:
*** RuntimeError: invalid argument 4: Index tensor must have same dimensions as input tensor at /opt/conda/conda-bld/pytorch_1556653099582/work/aten/src/THC/generic/THCTensorScatterGather.cu:16
albanD
(Alban D)
December 23, 2019, 9:55am
4
You want to make sure the indices are the same size as the values: new_neighb_inds.unsqueeze(-1).expand_as(sq_distances).
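To make the suggestion concrete, here is a minimal sketch of the 3-D case with made-up sizes and values. One caveat, stated as an observation rather than as part of the original reply: expand_as(sq_distances) targets shape [a, b, k], so it yields the expected [a, c, k] result only when c equals b. The sketch below instead expands the index explicitly to [a, c, k], which should cover the general case, and checks the result against plain advanced indexing.

```python
import torch

# Hypothetical sizes; only the shape pattern comes from the thread.
a, b, c, k = 2, 5, 3, 4
sq_distances = torch.arange(a * b * k, dtype=torch.float32).reshape(a, b, k)
new_neighb_inds = torch.tensor([[0, 2, 4],
                                [1, 1, 3]])  # shape [a, c], int64

# Add a trailing axis and repeat it k times: [a, c] -> [a, c, 1] -> [a, c, k].
index = new_neighb_inds.unsqueeze(-1).expand(-1, -1, sq_distances.size(-1))

# new_sq_distances[i, j, :] == sq_distances[i, new_neighb_inds[i, j], :]
new_sq_distances = torch.gather(sq_distances, dim=1, index=index)
print(new_sq_distances.shape)  # torch.Size([2, 3, 4])

# Cross-check with advanced indexing: row ids [a, 1] broadcast against [a, c].
rows = torch.arange(a).unsqueeze(-1)
reference = sq_distances[rows, new_neighb_inds]
assert torch.equal(new_sq_distances, reference)
```

The advanced-indexing form on the last lines is an alternative way to get the same result without building the expanded index.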
Xuyang_Bai
(Xuyang Bai)
December 23, 2019, 10:19am
5
I tried it and it works! Thanks a lot.