How to define an in-place custom op in dynamo?

Hello, I am trying to implement an in-place custom op (shown below), and I want to use torch.compile to generate an FX graph of this custom op.

I followed another guide, "how to teach functionalization about a custom, mutable operator" (GitHub gist). Currently, when the in-place custom op is invoked inside a module, I can correctly generate an FX graph containing its out-of-place variant, and the return value of the full module is correct. However, it seems that the result of the out-of-place custom op is not written back to the original input tensor of the in-place custom op.

The model looks like this.

class TestModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, dst, index, src, dim):
        torch.ops.custom.foo_(dst, index, src, dim)
        res = dst + 100 
        return res
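
For reference, this is roughly how I capture the graph (a minimal sketch: the aot_eager backend, the TORCH_LOGS setting, and the tensor shapes here are placeholders rather than my exact setup):

import torch

model = TestModel()
dst = torch.randn(3, 4)
index = torch.tensor([0, 1])
src = torch.randn(3, 4)

# "aot_eager" runs Dynamo + AOTAutograd (including functionalization) without
# Inductor codegen; running the script with TORCH_LOGS="aot_graphs" prints the
# captured graph.
compiled = torch.compile(model, backend="aot_eager")
res = compiled(dst, index, src, -2)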

The generated FX graph looks like this.

graph: graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %foo : [num_users=1] = call_function[target=torch.ops.custom.foo.default](args = (%arg0_1, %arg1_1, %arg2_1, -2), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%foo, 100), kwargs = {})
    return (add,)

This FX graph does not match my expectations. In my opinion, to propagate the output of the out-of-place op back to the input of the in-place op, the graph needs to at least return the output of the out-of-place op as an additional output, so that dynamo can then perform the final copy_.

Maybe the FX graph should look like this?

graph: graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %foo : [num_users=1] = call_function[target=torch.ops.custom.foo.default](args = (%arg0_1, %arg1_1, %arg2_1, -2), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%foo, 100), kwargs = {})
    return (foo, add,)

What additional steps do I need to take so that the in-place op is computed correctly under dynamo? I would appreciate any help.

This should work, as long as you’re correctly registering your custom op and the corresponding functionalization rule with the dispatcher. Can you link your repro code, including the code that registers your custom operator?

When using torch.compile, functionalization will run, and (if implemented properly for your op) it will convert foo_() nodes in the graph into foo() (the out-of-place variant).

You’re correct that (during training), we’d expect any graph inputs that were mutated by foo_() to be returned as additional outputs in the graph.
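
As a point of comparison, here is a minimal sketch using a built-in mutating op (aten.add_) instead of your custom op; it shows the behavior we'd expect once functionalization handles the mutation correctly (whether the copy-back lives in the graph or in a runtime epilogue depends on the PyTorch version and config):

import torch

def f(x):
    x.add_(1)          # in-place mutation of a graph input
    return x + 100

compiled = torch.compile(f, backend="aot_eager")
x = torch.randn(4)
x_ref = x.clone()
out = compiled(x)

# Functionalization rewrites add_ into add, and AOTAutograd makes sure the
# mutation is still reflected on the original input tensor.
print(torch.allclose(x, x_ref + 1))            # True
print(torch.allclose(out, x_ref + 1 + 100))    # True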

I’m very grateful for your answer, @bdhirsh.

When I run custom.foo_() in eager mode, I get the correct result and the input is modified in place. Based on your guidance, perhaps the problem is mainly in the functionalization? Here are my modifications.

## register the custom ops in the custom functions yaml
  - func: foo(Tensor self, Tensor index, Tensor src, int dim) -> Tensor
  - func: foo_(Tensor(a!) self, Tensor index, Tensor src, int dim) -> Tensor(a!)


## add implementations of the custom ops
// Simplified kernels for this repro: they ignore index/src/dim and just
// multiply self by 2 (in place for foo_).
at::Tensor foo(const at::Tensor &self, const at::Tensor &index, const at::Tensor &src, int64_t dim) {
  return self.mul(2);
}

at::Tensor &foo_(at::Tensor &self, const at::Tensor &index, const at::Tensor &src, int64_t dim) {
  return self.mul_(2);
}

## add the functionalization kernel
at::Tensor &foo__functionalization_glue(at::Tensor &self, const at::Tensor &index, const at::Tensor &src,
                                        int64_t dim) {
  // This kernel should only ever see functional tensor wrappers.
  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self));
  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(index));
  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(src));
  // Sync any pending updates, then unwrap to the underlying tensors.
  at::functionalization::impl::sync(self);
  at::functionalization::impl::sync(index);
  at::functionalization::impl::sync(src);
  auto x_0 = at::functionalization::impl::from_functional_tensor(self);
  auto x_1 = at::functionalization::impl::from_functional_tensor(index);
  auto x_2 = at::functionalization::impl::from_functional_tensor(src);

  static auto op_handle = c10::Dispatcher::singleton()
      .findSchemaOrThrow("custom::foo", "")
      .typed<at::Tensor(const at::Tensor&, const at::Tensor&, const at::Tensor&, int64_t)>();

  // Redispatch to the out-of-place variant, bypassing functionalization.
  at::Tensor tmp_output;
  {
    at::AutoDispatchSkipFunctionalize guard;
    tmp_output = op_handle.call(x_0, x_1, x_2, dim);
  }
  // Record that `self` now holds the value of tmp_output.
  at::functionalization::impl::replace_(self, tmp_output);
  at::functionalization::impl::commit_update(self);
  at::functionalization::impl::sync(self);
  return self;
}

TORCH_LIBRARY_IMPL(custom, Functionalize, m) {
  m.impl("foo_", foo__functionalization_glue);
}
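
With the registrations above in place, this is roughly the check I run against eager mode (a sketch: the shapes, the dim value, and the aot_eager backend are just for illustration, and it assumes the extension library is already loaded):

import torch

def fn(dst, index, src, dim):
    torch.ops.custom.foo_(dst, index, src, dim)
    return dst + 100

dst_eager = torch.full((3, 4), 2.0)
dst_compiled = dst_eager.clone()
index = torch.tensor([0, 1])
src = torch.randn(3, 4)

# Eager reference: foo_ mutates dst in place.
res_eager = fn(dst_eager, index, src, -2)

res_compiled = torch.compile(fn, backend="aot_eager")(dst_compiled, index, src, -2)

# If functionalization is wired up correctly, both the return value and the
# in-place update on dst should match eager mode.
print(torch.allclose(res_eager, res_compiled))
print(torch.allclose(dst_eager, dst_compiled))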