Why does this function break the gradient tree? 🤔

Hi! Thanks for getting back to me!

I can give that a try with my use case and see what happens! It might not work, since in fact I have (1, 120, 64, 64) cube inputs/outputs, but it's worth a shot 🙂

The difficulty here is that the network is specifically designed to learn the indices as a 2D tensor before blowing them up to the full cube. I'm not sure how to get around this without eventually incurring an indexing operation, because the step from the learned parameters to the index tensor is an analytical one rather than a learned one, if that makes sense.
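
Here's roughly what I mean, as a minimal repro (the names `values` and `idx_float` are just illustrative, not my actual model). The cast to an integer index is non-differentiable, so gradients reach the values being indexed but never the parameters that produced the index:

```python
import torch

# A float "index" produced analytically from learned parameters
values = torch.randn(10, requires_grad=True)
idx_float = torch.tensor([2.7], requires_grad=True)

# The integer cast silently detaches the index from the autograd graph
idx = idx_float.long()

out = values[idx]
out.sum().backward()

print(values.grad)     # gradient flows to the indexed values...
print(idx_float.grad)  # ...but stays None for the index parameters
```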

Perhaps there is a way of combining scatter with grid_sample like in this post? (Although I see in the docs that it only supports spatial indexing, so 2D indexing.)
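
If I understand grid_sample correctly, something like the sketch below might cover the 2D lookup part, assuming my learned coordinates can be normalized into grid_sample's [-1, 1] range. The shapes match my (1, 120, 64, 64) cube; `coords` is a hypothetical stand-in for the learned 2D index tensor. Because the lookup is interpolated rather than a hard gather, gradients do flow back into the coordinates:

```python
import torch
import torch.nn.functional as F

cube = torch.randn(1, 120, 64, 64)                     # (N, C, H, W) cube input
coords = torch.rand(1, 64, 64, 2, requires_grad=True)  # learned coords in [0, 1] (assumption)

grid = coords * 2 - 1                                  # rescale to [-1, 1] for grid_sample

# Differentiable spatial lookup: bilinear interpolation instead of hard indexing
sampled = F.grid_sample(cube, grid, align_corners=True)  # (1, 120, 64, 64)

sampled.sum().backward()
print(coords.grad.shape)  # gradient reaches the coordinate tensor
```

That still leaves the scatter half of the question open, but at least the gather direction wouldn't break the graph.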