I am glad it helped.
Check out this explanation of gather_nd from Stack Overflow first to understand how gather_nd works.
Basically, gather_nd needs a set of indices to use (children, in our case).
In our case, children has shape [15, 720, 19, 2],
so you can think of it as 15*720*19 index tuples with two elements each. Say one of those tuples is (19, 22): it corresponds to lookup[19, 22, :], which is a single row of 30 values (shape [30]).
But as I said, there are 15*720*19 of those tuples; when you combine them all, you get a result of shape [15, 720, 19, 30].
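To make the shape argument concrete, here is a loop-based sketch of what gather_nd computes for index pairs like these. `gather_nd_loop` is a hypothetical helper written only for illustration (it is not a TensorFlow or PyTorch API), and the sizes of `lookup` and `children` are made up and kept small so the Python loop stays fast; only the last dimension of 30 comes from the shapes above.

```python
import torch

def gather_nd_loop(lookup, children):
    """Hypothetical reference helper: gather_nd semantics spelled out with a loop.

    children has shape [..., 2]; each trailing pair (i, j) selects lookup[i, j, :].
    """
    batch_shape = children.shape[:-1]
    flat_idx = children.reshape(-1, 2)  # one row per index tuple
    rows = [lookup[i, j, :] for i, j in flat_idx.tolist()]  # each row has shape [30]
    # Stack the selected rows and restore the index grid: [..., 30]
    return torch.stack(rows).reshape(*batch_shape, lookup.shape[-1])

# Made-up sizes: lookup is assumed to be [40, 40, 30]; children is a small
# stand-in for the real [15, 720, 19, 2] index tensor.
lookup = torch.randn(40, 40, 30)
children = torch.randint(0, 40, (3, 5, 4, 2))

print(lookup[19, 22, :].shape)  # one tuple -> torch.Size([30])
out = gather_nd_loop(lookup, children)
print(out.shape)  # torch.Size([3, 5, 4, 30])
```

The output keeps every leading dimension of children and replaces the trailing 2 (the index pair) with the 30 values each pair selects.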
When you translate that to PyTorch, all those index tuples correspond to:
children[:,:,:,0],children[:,:,:,1]
You need to use them to index lookup, but you do not have an index for the last dimension, so you take all of it:
lookup[children[:,:,:,0],children[:,:,:,1],:]
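A minimal runnable sketch of that indexing, assuming a hypothetical lookup of shape [40, 40, 30] (the first two sizes are made up; children uses the [15, 720, 19, 2] shape from this thread). The two index tensors each have shape [15, 720, 19], and the unindexed last dimension is appended to the result.

```python
import torch

# Assumed sizes: lookup's first two dims (40, 40) are hypothetical, last dim is 30.
lookup = torch.randn(40, 40, 30)
children = torch.randint(0, 40, (15, 720, 19, 2))

# Advanced indexing: the two [15, 720, 19] index tensors pick (i, j) pairs,
# and ":" keeps the whole last dimension of lookup.
out = lookup[children[:, :, :, 0], children[:, :, :, 1], :]
print(out.shape)  # torch.Size([15, 720, 19, 30])

# Spot-check one position against a direct lookup with plain ints.
i, j = children[3, 100, 7].tolist()
assert torch.equal(out[3, 100, 7], lookup[i, j, :])
```

Equivalently, `lookup[children.unbind(-1)]` passes the same two index tensors as a tuple and gives the same result.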
I hope that makes it clearer.