How to perform tf.gather_nd in PyTorch?


Is there a way to compute the gradient with respect to the index?

I am trying to do something like a "spatial transformer network".

However, I need to do the sampling on one branch of the input, so I want to implement it myself.

Something like:

# build the index from range
x,y,u,v = meshgrid(...)

# select the pixels at these indices
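For reference, the forward part of `tf.gather_nd` (without `batch_dims`) can be reproduced in PyTorch with advanced indexing, by splitting the last axis of the index tensor into one index tensor per indexed dimension. The helper name `gather_nd` below is hypothetical, not a PyTorch API. Gradients reach `params`, but integer indices are piecewise constant, so there is no useful gradient for them.

```python
import torch


def gather_nd(params: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Rough PyTorch analogue of tf.gather_nd without batch_dims (hypothetical helper).

    The last axis of `indices` (size K) indexes the first K dims of `params`;
    the result has shape indices.shape[:-1] + params.shape[K:].
    """
    # One index tensor per indexed dimension, then advanced indexing.
    return params[tuple(indices.unbind(dim=-1))]


params = torch.arange(12.0).reshape(3, 4).requires_grad_()
indices = torch.tensor([[0, 1], [2, 3]])  # picks params[0, 1] and params[2, 3]
out = gather_nd(params, indices)
out.sum().backward()
print(out)          # values 1. and 11.
print(params.grad)  # 1. at the two gathered positions, 0. elsewhere
```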


In TensorFlow, this would be done as:


However, in PyTorch I get an error that gather cannot be differentiated with respect to the index.
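The usual way around this, and the idea behind spatial transformer networks, is not to differentiate through the integer index at all: sample at fractional coordinates with bilinear interpolation, so the gradient with respect to the coordinates comes from the interpolation weights. Below is a minimal 2-D sketch, assuming a single-channel `(H, W)` image with `H, W >= 2` and clamping at the border; `bilinear_sample` is a hypothetical helper, not a PyTorch function.

```python
import torch


def bilinear_sample(img: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Sample an (H, W) image at float coordinates x (column) and y (row).

    The integer corner indices carry no gradient; the gradient with respect
    to x and y flows through the fractional interpolation weights instead.
    """
    H, W = img.shape
    # Clamp to the image border (border handling is an assumption of this sketch).
    x = x.clamp(0, W - 1)
    y = y.clamp(0, H - 1)
    # Top-left corner of the enclosing cell, kept <= size - 2 so +1 stays in bounds.
    x0 = x.detach().floor().long().clamp(0, W - 2)
    y0 = y.detach().floor().long().clamp(0, H - 2)
    x1, y1 = x0 + 1, y0 + 1
    # Fractional offsets inside the cell: these are differentiable in x and y.
    wx = x - x0.to(x.dtype)
    wy = y - y0.to(y.dtype)
    # Plain integer indexing, i.e. the "gather" step.
    v00, v01 = img[y0, x0], img[y0, x1]
    v10, v11 = img[y1, x0], img[y1, x1]
    top = v00 * (1 - wx) + v01 * wx
    bottom = v10 * (1 - wx) + v11 * wx
    return top * (1 - wy) + bottom * wy


img = torch.arange(12.0).reshape(3, 4)  # img[r, c] = 4 * r + c
x = torch.tensor([1.5], requires_grad=True)
y = torch.tensor([0.25], requires_grad=True)
out = bilinear_sample(img, x, y)
out.sum().backward()
print(out)     # value 2.5 (= 4 * 0.25 + 1.5)
print(x.grad)  # tensor([1.])
print(y.grad)  # tensor([4.])
```

Because this test image is linear in the row and column, interpolation reproduces it exactly, and the coordinate gradients equal its slopes (1 along columns, 4 along rows).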

Any ideas about this problem?

Thanks in advance.

Thanks for your reply.

I need to do the transformation in higher dimensions, so I want to implement my own grid_sample.
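PyTorch's built-in `torch.nn.functional.grid_sample` only supports 4-D and 5-D inputs (2-D and 3-D spatial sampling), so a higher-dimensional version has to interpolate over all `2**d` corners of each grid cell. The sketch below is one way to do that, assuming unnormalized coordinates, clamping at the border, every dimension of size at least 2, and a `vol` tensor with exactly `d` dimensions (no batch or channel axes); `multilinear_sample` is a hypothetical name. With `d = 4` it covers an `x, y, u, v` grid like the one in the question.

```python
import itertools

import torch


def multilinear_sample(vol: torch.Tensor, coords: torch.Tensor) -> torch.Tensor:
    """Sample a d-dimensional tensor `vol` at float points `coords` (shape (N, d)).

    coords[:, k] is an unnormalized position along dimension k of vol.
    Each output is a weighted sum over the 2**d surrounding grid points, so
    gradients reach both vol and coords (through the interpolation weights).
    """
    d = coords.shape[1]
    assert vol.dim() == d, "vol must have exactly d dimensions in this sketch"
    # Clamp each coordinate into [0, size - 1] (border handling is an assumption).
    coords = torch.stack(
        [coords[:, k].clamp(0, vol.shape[k] - 1) for k in range(d)], dim=1
    )
    # Lower corner of the enclosing cell, kept <= size - 2 so lower + 1 is valid.
    lower = torch.stack(
        [coords[:, k].detach().floor().long().clamp(0, vol.shape[k] - 2) for k in range(d)],
        dim=1,
    )
    frac = coords - lower.to(coords.dtype)  # (N, d), differentiable in coords
    out = torch.zeros(coords.shape[0], dtype=vol.dtype, device=vol.device)
    for corner in itertools.product((0, 1), repeat=d):
        offset = torch.tensor(corner, device=coords.device)
        idx = lower + offset  # (N, d) integer indices of this corner
        # Weight: product over dims of frac (offset 1) or 1 - frac (offset 0).
        w = torch.where(offset.bool(), frac, 1 - frac).prod(dim=1)
        out = out + w * vol[tuple(idx.unbind(dim=1))]
    return out


# Usage: a 4-D grid that is linear in the index, vol[i, j, k, l] = i + 2j + 3k + 4l,
# so multilinear interpolation reproduces it exactly.
vol = (
    torch.arange(3.0).view(3, 1, 1, 1)
    + 2 * torch.arange(4.0).view(1, 4, 1, 1)
    + 3 * torch.arange(5.0).view(1, 1, 5, 1)
    + 4 * torch.arange(6.0).view(1, 1, 1, 6)
)
coords = torch.tensor([[0.5, 1.25, 2.0, 4.75], [2.0, 3.0, 4.0, 5.0]], requires_grad=True)
out = multilinear_sample(vol, coords)
out.sum().backward()
print(out)          # values 28. and 40.
print(coords.grad)  # first row: the slopes 1, 2, 3, 4
```

The loop over corners costs `2**d` gathers per call, which is fine for `d = 4` (16 corners) but grows quickly with more dimensions.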