Dispatching of advanced indexing

Hello,

I have a custom backend and have registered some operations under the PrivateUse1 dispatch key. I am trying to dispatch the following:

def f(x, y, z):  # hypothetical wrapper; x, y, z are tensors on the PrivateUse1 backend
    x[0:1, :, :].copy_(y)
    return x + z

Three slice ops get dispatched to handle the indexing, and each one returns a new tensor. copy_ then copies y into the tensor returned by the last dispatched slice, but the add sums the original x and z.
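The behavior above is what you get whenever a slice returns a fresh copy instead of a view. As a loose, PyTorch-free analogy (an illustration only, not how a PrivateUse1 backend works), Python's own `bytearray` slicing copies, while slicing a `memoryview` shares the underlying buffer:

```python
# Analogy only: bytearray slicing copies, memoryview slicing aliases the buffer.
x = bytearray(4)   # stands in for the original tensor x
y = b"\x07\x07"    # stands in for y

# Slice returns a copy: writing into it leaves x untouched (the observed result).
copied = x[0:2]
copied[:] = y
print(bytes(x))    # b'\x00\x00\x00\x00'

# Slice returns a view: writing into it updates x (the expected result).
view = memoryview(x)[0:2]
view[:] = y
print(bytes(x))    # b'\x07\x07\x00\x00'
```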

What I expect is for the slices to be done in place, so that copy_ updates the original x. I have tried registering index, index_put_ and masked_fill_, but none of these ops get dispatched.
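For comparison, here is a minimal sketch of what the built-in CPU backend does with the same expression (the shapes and values are made up for illustration). Basic slicing such as `x[0:1, :, :]` returns a view that shares storage with `x`, so `copy_` writes through to `x`. Indexing with an index tensor is advanced indexing, which returns a copy; that is the form that goes through `index` (and `index_put_` on assignment), which may explain why registering those ops had no effect for a purely sliced expression.

```python
import torch

x = torch.zeros(2, 3, 4)
y = torch.ones(1, 3, 4)
z = torch.ones(2, 3, 4)

# Basic slicing returns a view: it starts at x's first element and shares storage.
s = x[0:1, :, :]
print(s.data_ptr() == x.data_ptr())  # True
s.copy_(y)                           # writes into x itself
out = x + z
print(out[0].eq(2).all().item(), out[1].eq(1).all().item())  # True True

# Advanced indexing (an index tensor) returns a copy, so copy_ does not reach x2.
x2 = torch.zeros(2, 3, 4)
adv = x2[torch.tensor([0])]
adv.copy_(y)
print(x2.eq(0).all().item())         # True

# Assigning through advanced indexing does write into x2 (the index_put_ path).
x2[torch.tensor([0])] = y
print(x2[0].eq(1).all().item())      # True
```

If a backend's slice kernel allocates a new tensor instead of producing a view, the first half of this sketch would behave like the advanced-indexing case instead.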

Is there some other operation I should be registering for dispatch?

@almeetb
https://pytorch.org/cppdocs/notes/tensor_indexing.html
This might help

Thanks! Based on this, it seems the getter operations for slices are correct, but I never get any setter ops dispatched.
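One way to see why no setter shows up: in Python, `x[...].copy_(y)` only calls `__getitem__` and then a method on the result, whereas `x[...] = y` calls `__setitem__`. The hypothetical `Probe` class below (not a PyTorch type) just records which hooks the interpreter invokes for each spelling:

```python
class Probe:
    """Hypothetical stand-in for a tensor that logs which indexing hooks run."""

    def __init__(self):
        self.calls = []
        self.last_key = None

    def __getitem__(self, key):
        self.calls.append("getitem")
        self.last_key = key  # a tuple of slice objects for x[0:1, :, :]
        return self          # pretend the slice is a view of ourselves

    def __setitem__(self, key, value):
        self.calls.append("setitem")
        self.last_key = key

    def copy_(self, other):
        self.calls.append("copy_")
        return self


p = Probe()
p[0:1, :, :].copy_("y")
print(p.calls)  # ['getitem', 'copy_']

q = Probe()
q[0:1, :, :] = "y"
print(q.calls)  # ['setitem']
```

So for `x[0:1, :, :].copy_(y)` only the getter path plus `copy_` is involved; the setter path is exercised by `x[0:1, :, :] = y`.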

@almeetb
You have your own dispatchKey. It seems to me that registering setters and getters is the same process.

Please create an issue on the PyTorch GitHub repository and provide the following information:

  1. The PyTorch version you are working on.
  2. Sample code showing how you add the dispatchKey.
  3. Sample code showing how you register the ops.
  4. Any errors or logs you have.

Assign the label ‘dispatch’ to the issue you create, and we will have people answer your question there. Thanks!


I think my actual question is why some operations get dispatched and others do not; I don't think there is a problem with how the registration itself is done. If possible, could you run the sample code above and check which ops get dispatched? I am using PyTorch 1.5.
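A minimal sketch for checking which ops actually get dispatched, using the autograd profiler on the CPU backend (the shapes are assumptions for illustration; exact op names, and how many slice calls appear, may differ between PyTorch versions):

```python
import torch
from torch.autograd import profiler

x = torch.zeros(2, 3, 4)
y = torch.ones(1, 3, 4)
z = torch.ones(2, 3, 4)

# Record every dispatched op executed inside the block.
with profiler.profile() as prof:
    x[0:1, :, :].copy_(y)
    out = x + z

names = [evt.name for evt in prof.function_events]
print(names)  # exact names and counts vary by version
```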

@almeetb
For your example above: if you translate it into the C++ API, it should call index(). If you run it from Python, however, it won't dispatch any of the indexing APIs you've tried, including index(). What it actually calls is:
https://github.com/pytorch/pytorch/blob/3d46e02ea1f8951bcef167c33bad0217ec9da980/torch/csrc/autograd/python_variable_indexing.cpp#L268

Thanks for your help! Tracing through that showed me what my problem was.