I want to avoid loops here, so I prepared an index tensor like this.
For context, the input is a batch of images (in).
In the index tensor (ind), the channel of the target image is stored in dim 0, and the location in that target image to sample from for the corresponding output (out) pixel is stored in dims 1 and 2.
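To make the setup concrete, here is a minimal sketch of how such an index tensor could be used without loops. The layout is an assumption on my part, not something fixed above: `inp` (renamed from in, which is a Python keyword) has shape `(C, H, W)`, and `ind` has shape `(3, H_out, W_out)`, with `ind[0]` holding the channel, `ind[1]` the source row, and `ind[2]` the source column. The sizes and values are made up for illustration.

```python
import torch

# Assumed layout (hypothetical, for illustration only):
#   inp: (C, H, W) stack of input images / channels
#   ind: (3, H_out, W_out) integer tensor, where
#        ind[0] = channel to read, ind[1] = source row, ind[2] = source column
C, H, W = 2, 3, 4
inp = torch.arange(C * H * W).reshape(C, H, W)

ind = torch.tensor([
    [[0, 1], [1, 0]],   # channel index per output pixel
    [[0, 2], [1, 2]],   # source row per output pixel
    [[3, 0], [1, 2]],   # source column per output pixel
])

# Advanced indexing: the three index tensors have the same shape, so
# out[i, j] = inp[ind[0, i, j], ind[1, i, j], ind[2, i, j]] with no Python loop.
out = inp[ind[0], ind[1], ind[2]]

# Loop-based reference, used only to check the vectorized result.
ref = torch.empty_like(out)
for i in range(out.shape[0]):
    for j in range(out.shape[1]):
        c, y, x = ind[:, i, j]
        ref[i, j] = inp[c, y, x]

print(out)
```

If the real tensors have an extra batch dimension, the same idea should still apply, but the index tensors would need an additional batch index that broadcasts against the other three.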