Hello, I am trying to implement an in-place custom op, and I want to use torch.compile to generate an FX graph of a module that calls it.
I followed another guide, how to teach functionalization about a custom, mutable operator · GitHub. Currently, when the in-place custom op is invoked in a module, I can correctly generate an FX graph containing the out-of-place custom op, and the return value of the full module is correct. However, it seems that the result of the out-of-place custom op is never written back to the original input tensor of the in-place custom op.
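For context, my op registration roughly follows that guide. A simplified sketch is below; here foo_/foo are modeled as an index_copy along dim purely for illustration, and the functionalization kernel from the guide is omitted.
import torch

# Hypothetical, simplified registration in the "custom" namespace
# (matching torch.ops.custom in the graphs below); the real kernels differ.
lib = torch.library.Library("custom", "DEF")

# In-place variant: mutates dst.
lib.define("foo_(Tensor(a!) dst, Tensor index, Tensor src, int dim) -> Tensor(a!)")
# Out-of-place (functional) variant that functionalization should map foo_ onto.
lib.define("foo(Tensor dst, Tensor index, Tensor src, int dim) -> Tensor")

def foo__impl(dst, index, src, dim):
    # Illustration only: treat foo_ as an index_copy_ along dim.
    return dst.index_copy_(dim, index, src)

def foo_impl(dst, index, src, dim):
    # Functional counterpart: the same computation on a copy of dst.
    return dst.clone().index_copy_(dim, index, src)

lib.impl("foo_", foo__impl, "CompositeExplicitAutograd")
lib.impl("foo", foo_impl, "CompositeExplicitAutograd")
# Following the guide, a functionalization kernel for foo_ that redirects to foo
# is also registered; it is omitted here for brevity.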
The model looks like this.
class TestModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, dst, index, src, dim):
        torch.custom.foo_(dst, index, src, dim)
        res = dst + 100
        return res
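I compile and run the module roughly like this (the shapes and the backend are just for illustration):
model = TestModel()
dst = torch.zeros(4, 4)
index = torch.tensor([0, 2])
src = torch.ones(2, 4)
compiled = torch.compile(model, backend="aot_eager")
res = compiled(dst, index, src, -2)
print(dst)  # I expect dst to reflect the in-place update, but it stays unchanged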
The generated FX graph looks like this.
graph: graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %foo : [num_users=1] = call_function[target=torch.ops.custom.foo.default](args = (%arg0_1, %arg1_1, %arg2_1, -2), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%foo, 100), kwargs = {})
    return (add,)
This FX graph does not match my expectations. In my opinion, to write the output of the out-of-place op back to the input of the in-place op, the graph needs to at least return the output of the out-of-place op, so that Dynamo can then perform the final copy_.
Maybe the FX graph should look like this instead?
graph: graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %foo : [num_users=1] = call_function[target=torch.ops.custom.foo.default](args = (%arg0_1, %arg1_1, %arg2_1, -2), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%foo, 100), kwargs = {})
    return (foo, add,)
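In other words, the compiled program I expect is roughly equivalent to the following, with the final write-back to dst performed outside the graph:
def expected(dst, index, src):
    foo = torch.ops.custom.foo.default(dst, index, src, -2)  # out-of-place variant
    add = torch.add(foo, 100)
    dst.copy_(foo)  # the copy_ that refreshes the original input of foo_
    return add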
What additional steps do I need to take so that the in-place op is handled correctly by Dynamo? I would appreciate any help.