How to use tf.gather_nd in PyTorch

I have two tensors with these shapes:

(15, 720, 30)
(15, 720, 19, 2)

I want to gather them into one tensor of shape (15, 720, 19, 30). How can I do this?
From the other topic, "How to do the tf.gather_nd in pytorch?", I see that there is a way, but it’s hard for me to understand how to use it.


What should your output matrix contain? I fail to see what operation you want to do here.


Thanks for your reply.

Take a look at this line in the repo, which performs a bunch of complicated matrix operations on the tree structure.

These lines do the job:

vector_lookup = tf.concat([zero_vecs, nodes[:, 1:, :]], axis=1)

children = tf.concat([batch_indices, children], axis=3)
return tf.gather_nd(vector_lookup, children, name='children')

The shape of vector_lookup will be something like (15, 720, 30),
and the shape of children will be something like (15, 720, 19, 2).

My goal is to do something similar to the return line.
I’m trying to port this to PyTorch, but some parts seem quite tricky.
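A minimal PyTorch sketch of the quoted TensorFlow lines, under assumptions the thread does not state: `nodes` is a float tensor of shape (batch_size, num_nodes, feature_size), `children` (before the `tf.concat` on axis 3) is a long tensor of shape (batch_size, num_nodes, num_children) holding node indices, and `zero_vecs` is one zero row per batch that replaces node 0. The function name `children_tensor` is hypothetical. Instead of concatenating batch indices onto `children`, it lets advanced indexing broadcast a batch index tensor against `children`.

```python
import torch


def children_tensor(nodes: torch.Tensor, children: torch.Tensor) -> torch.Tensor:
    """Hypothetical port: gather each child's row from a per-batch lookup table.

    Assumes nodes is (batch, num_nodes, feature) and children is
    (batch, num_nodes, num_children) with integer node indices.
    """
    batch_size = nodes.shape[0]
    # Assumed meaning of zero_vecs: one zero row per batch replacing node 0,
    # so looking up index 0 yields a zero vector (mirrors the tf.concat line).
    zero_vecs = torch.zeros_like(nodes[:, :1, :])
    vector_lookup = torch.cat([zero_vecs, nodes[:, 1:, :]], dim=1)
    # Batch index of shape (batch, 1, 1); it broadcasts against children and
    # plays the role of batch_indices in the TensorFlow version.
    batch_idx = torch.arange(batch_size, device=nodes.device).view(-1, 1, 1)
    # Result shape: (batch, num_nodes, num_children, feature)
    return vector_lookup[batch_idx, children]


nodes = torch.randn(15, 720, 30)
children = torch.randint(0, 720, (15, 720, 19))
out = children_tensor(nodes, children)
print(out.shape)  # torch.Size([15, 720, 19, 30])
```

Broadcasting the batch index avoids building and tiling an explicit batch_indices tensor; the concatenated (batch, node) pairs of the TensorFlow version would give the same result.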

I do not know TensorFlow very well, so I am not sure I understood what you are asking correctly. Can you check if this is what you are looking for:

>>> lookup=torch.ones((15, 720, 30))
>>> children=torch.randint(0,15,(15, 720, 19, 2),dtype=torch.long)
>>> lookup[children[:,:,:,0],children[:,:,:,1],:].shape
torch.Size([15, 720, 19, 30])
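One way to check that the indexing line behaves like `tf.gather_nd` is to compare it with an explicit loop over every index tuple, on small made-up shapes (the sizes below are arbitrary). Note that `children[..., 0]` must stay below the batch size and `children[..., 1]` below the number of nodes, so the sketch draws the two columns separately.

```python
import torch

B, N, C, F = 3, 5, 4, 2  # arbitrary small sizes: batch, nodes, children, features
lookup = torch.randn(B, N, F)
# Last dimension holds (batch index, node index) pairs, each drawn in range.
children = torch.stack(
    [torch.randint(0, B, (B, N, C)), torch.randint(0, N, (B, N, C))], dim=-1
)

# Vectorised gather: index the first two dims, keep the feature dim whole.
fast = lookup[children[:, :, :, 0], children[:, :, :, 1], :]

# Reference loop: out[b, n, c] = lookup[i, j] where (i, j) = children[b, n, c].
slow = torch.empty(B, N, C, F)
for b in range(B):
    for n in range(N):
        for c in range(C):
            i, j = children[b, n, c].tolist()
            slow[b, n, c] = lookup[i, j]

print(fast.shape)  # torch.Size([3, 5, 4, 2])
print(torch.equal(fast, slow))  # True
```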

You’re a lifesaver! Thanks a lot.

But I don’t understand how you came up with the solution. Would you mind explaining how the code works?

The scenario can be described like this:
lookup=torch.ones((15, 720, 30))

means: batch_size x num_nodes x feature_size
where num_nodes is the number of nodes in a tree and each node is represented by an embedding of size feature_size.

children= (15, 720, 19, 2)
means: batch_size x num_nodes x num_children x 2. I don’t really understand the meaning of the 2 here.

Since every node in a tree is represented by an embedding, what the TensorFlow code does is “merge” the two tensors into one: the fourth dimension (size 30) means that each child in the third dimension holds its corresponding embedding, which makes the later steps more convenient.
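A tiny, made-up example may make the shapes concrete. It assumes one tree (batch size 1) with 4 nodes, 2-dimensional embeddings and at most 2 children per node, and it fills empty child slots with (0, 0) purely for illustration. Node k's embedding is [2k, 2k+1], so every gathered row shows which node it came from: out[b, n, c] is simply the lookup row addressed by the tuple children[b, n, c].

```python
import torch

# Hypothetical tree: node 0 has children 1 and 2, node 1 has child 3,
# nodes 2 and 3 are leaves. Empty slots point at (batch 0, node 0).
lookup = torch.arange(8.0).view(1, 4, 2)  # node k -> [2k, 2k+1]
children = torch.tensor([[
    [[0, 1], [0, 2]],  # node 0
    [[0, 3], [0, 0]],  # node 1
    [[0, 0], [0, 0]],  # node 2
    [[0, 0], [0, 0]],  # node 3
]])  # shape (1, 4, 2, 2)

out = lookup[children[..., 0], children[..., 1]]
print(out.shape)  # torch.Size([1, 4, 2, 2])
print(out[0, 0])  # rows for nodes 1 and 2: [[2., 3.], [4., 5.]]
```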

I don’t really get how your code works, but it does the job…

I am glad it helped.
Check out this explanation of gather_nd on Stack Overflow first to understand how gather_nd works.

So basically, gather_nd needs a set of indices (children in our case) to use.
In our case children has shape [15, 720, 19, 2], so you can think of it as 15*720*19 index tuples with two elements each. Let’s say one of those tuples is (9, 22): it corresponds to lookup[9, 22, :], which is a vector of 30 elements (you can picture it as a [1, 1, 30] slice). Since there are 15*720*19 of those tuples, when you combine them all you get [15, 720, 19, 30].
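The rule above can be written out in plain Python, independent of either framework. This is a simplified, hypothetical `gather_nd_ref` for nested lists, not TensorFlow's implementation: each innermost index tuple selects params[i0][i1]... and keeps the remaining dimensions whole, while the layout of the tuples themselves gives the leading shape of the result.

```python
from typing import Any, List


def gather_nd_ref(params: List[Any], indices: List[Any]) -> Any:
    """Hypothetical reference gather_nd for nested Python lists.

    If `indices` is a flat list of ints, it is one index tuple: walk into
    `params` one level per int and return what is found there (a scalar
    or a sub-list, so trailing dimensions are kept whole).
    Otherwise recurse, so the output mirrors the layout of `indices`.
    """
    if all(isinstance(i, int) for i in indices):
        out = params
        for i in indices:
            out = out[i]
        return out
    return [gather_nd_ref(params, sub) for sub in indices]


# params has shape [2, 3, 2]; indices has shape [2, 2] (two (batch, node) tuples)
params = [
    [[0, 1], [2, 3], [4, 5]],
    [[6, 7], [8, 9], [10, 11]],
]
indices = [[0, 2], [1, 0]]
print(gather_nd_ref(params, indices))  # [[4, 5], [6, 7]]
```

With children of shape [15, 720, 19, 2], the same recursion would produce a nested list of shape [15, 720, 19, 30].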

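When you translate that to PyTorch, all those index tuples correspond to children[:,:,:,0] (the first element of every tuple) and children[:,:,:,1] (the second element).
You use them as indexing elements into lookup, but you do not have an index for the last dimension, so you take all of it: lookup[children[:,:,:,0], children[:,:,:,1], :]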
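The same idea generalises to any index depth k: splitting the last dimension of the index tensor into k separate index tensors and passing them as a tuple indexes the first k dimensions and keeps the rest. The helper name `gather_nd_torch` is hypothetical, and this sketch covers only the plain case where each index tuple addresses the leading dimensions of `params`.

```python
import torch


def gather_nd_torch(params: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper mimicking plain tf.gather_nd with advanced indexing.

    indices[..., :] holds k-tuples into the first k dims of params; the
    result has shape indices.shape[:-1] + params.shape[k:].
    """
    k = indices.shape[-1]
    if k > params.dim():
        raise ValueError("index tuples are longer than params has dimensions")
    # unbind(-1) yields k tensors of shape indices.shape[:-1], one per indexed dim
    return params[tuple(indices.unbind(dim=-1))]


lookup = torch.randn(15, 720, 30)
children = torch.stack(
    [torch.randint(0, 15, (15, 720, 19)), torch.randint(0, 720, (15, 720, 19))],
    dim=-1,
)
out = gather_nd_torch(lookup, children)
print(out.shape)  # torch.Size([15, 720, 19, 30])
print(torch.equal(out, lookup[children[:, :, :, 0], children[:, :, :, 1], :]))  # True
```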

I hope that makes it clearer.



It’s clearer to me now 🙂

That’s smart! Thanks, your answer really helps.