Hi all,
Is there any equivalent of tf.gather (https://www.tensorflow.org/api_docs/python/tf/gather) in PyTorch?
The PyTorch function with the same name (torch.gather()) seems to do something else. I searched the documentation but couldn't find anything similar.
Thanks a lot!
If you want to index along a single dimension, you can use torch.index_select(). For more dimensions, you actually want torch.gather(), but it is trickier to use.
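A minimal sketch contrasting the two, on a made-up 3x4 tensor. torch.index_select(x, dim, index) takes whole slices along one dimension, which matches tf.gather(x, idx, axis=dim) when idx is 1-D (index_select only accepts a 1-D int64 index). torch.gather(x, dim, index) instead picks one element per position: index must have the same number of dimensions as x, and for dim=1 the result is out[i][j] = x[i][index[i][j]].

```python
import torch

# A 3x4 example tensor:
# [[ 0,  1,  2,  3],
#  [ 4,  5,  6,  7],
#  [ 8,  9, 10, 11]]
x = torch.arange(12).reshape(3, 4)

# Like tf.gather(x, [2, 0], axis=0): select whole rows 2 and 0.
rows = torch.index_select(x, 0, torch.tensor([2, 0]))
# [[ 8,  9, 10, 11],
#  [ 0,  1,  2,  3]]

# Like tf.gather(x, [3, 1], axis=1): select columns 3 and 1.
cols = torch.index_select(x, 1, torch.tensor([3, 1]))
# [[ 3,  1],
#  [ 7,  5],
#  [11,  9]]

# torch.gather is element-wise: each row gets its own column indices.
per_row = torch.tensor([[0, 3], [1, 1], [2, 0]])
picked = torch.gather(x, 1, per_row)
# [[ 0,  3],
#  [ 5,  5],
#  [10,  8]]
```

So index_select is the closer match for the usual tf.gather call, while torch.gather is closer to per-element lookups.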
Yes, torch.gather() is different, but I think you can just use PyTorch's direct indexing.
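For example, a sketch assuming the default tf.gather behaviour (batch_dims=0), where the result shape is params.shape[:axis] + indices.shape + params.shape[axis+1:]. Indexing with an integer tensor gives the same shape in PyTorch, and unlike index_select it also accepts a multi-dimensional index. The tensors here are only illustrative.

```python
import torch

params = torch.arange(12).reshape(3, 4)

# A 2-D index tensor; tf.gather accepts indices of any rank.
indices = torch.tensor([[2, 0], [1, 1]])

# Roughly tf.gather(params, indices, axis=0):
# shape is indices.shape + params.shape[1:] -> (2, 2, 4)
out0 = params[indices]

# Roughly tf.gather(params, indices, axis=1):
# shape is params.shape[:1] + indices.shape -> (3, 2, 2)
out1 = params[:, indices]
```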