Declaring a torch.compile-compatible in-place operator

I developed a C++ operator that replaces index_add, with a few modifications. It modifies its first argument in-place and has the following signature:

  m.def("index_add(Tensor tensor1, int dim, Tensor indices, Tensor values) -> Tensor");

Here, tensor1 is modified in-place and returned for convenience.
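
To make the intended semantics concrete, here is a minimal standalone sketch. It is not the real operator: it assumes a 1-D case, uses plain std::vector instead of tensors, drops the dim argument, and the name index_add_sketch is hypothetical. It only illustrates that the returned object is the very same storage as the first argument.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for the real operator (assumption: 1-D only,
// plain vectors instead of tensors, no `dim` argument).
// Adds values[i] to target[indices[i]] and returns a reference to the
// same object that was passed in.
std::vector<float>& index_add_sketch(std::vector<float>& target,
                                     const std::vector<std::size_t>& indices,
                                     const std::vector<float>& values) {
    if (indices.size() != values.size()) {
        throw std::invalid_argument("indices and values must have the same length");
    }
    for (std::size_t i = 0; i < indices.size(); ++i) {
        target.at(indices[i]) += values[i];  // at() bounds-checks the index
    }
    return target;  // same storage as the input, not a copy
}
```

So the output aliases the first input, which is exactly the relationship the plain "-> Tensor" signature does not express.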

When I use this operator, torch.compile complains that the in-place modified argument is not returned as an output. In reality it is, but torch.compile can’t know that from the signature.

RuntimeError: Some elements marked as dirty during the forward method were not returned as output. The inputs that are modified inplace must all be outputs of the Function.

How do I declare that the modified tensor is indeed an output as well?

I blindly tried something I saw in the PyTorch source, adding it to the op signature declaration:

  m.def("index_add(Tensor tensor1(a!), int dim, Tensor indices, Tensor values) -> Tensor(a!)");

but it doesn’t seem to fix the problem.