AOT Autograd functionalizes views correctly only if at least one of the input tensors has `requires_grad=True`.
Relevant code in `aot_autograd.py` (pytorch/pytorch, master branch on GitHub):
if needs_autograd:
    compiler_fn = aot_dispatch_autograd  # functionalizes the views correctly
else:
    compiler_fn = aot_dispatch_base  # results in incorrect view handling
import torch
from functorch.compile import aot_function
def fn(a, b):
    """Minimal repro: mutate ``a`` both through an ``as_strided`` view and directly.

    ``b`` is never used in the body. In this report only its ``requires_grad``
    flag is toggled between the two traces, which (per the report) selects
    between ``aot_dispatch_autograd`` and ``aot_dispatch_base``.

    Returns:
        ``e[1:-1:2]``, a slice of the strided view of ``a``.
    """
    # Dead store: ``e`` is rebound on the next line. It is kept because it
    # still shows up in the traced graph (as a ``slice`` op) in the no-grad case.
    e = a[:]
    # Strided view of ``a``: size [2, 2], stride [4, 2], storage offset 0.
    # Assuming a contiguous (2, 4) input, this selects columns 0 and 2 of each row.
    # Alternatives tried: a[:, 0:-1:2], a.view([8]).
    e = a.as_strided([2, 2], [4, 2], 0)
    # In-place add through the view; this writes into ``a``'s storage.
    e.add_(1.0)
    # In-place mul on the base; ``e`` aliases ``a``, so it observes this too.
    a.mul_(2.0)
    # Dim 0 of ``e`` has size 2, so ``2:-1`` selects no rows and this writes
    # nothing in eager mode; it still appears in the traced graph (as ``fill``).
    e[2:-1] = 3.0
    # ``1:-1`` over a size-2 dim is also empty, so the result has shape (0, 2).
    # NOTE(review): with both slices empty, the returned value cannot expose a
    # wrong view update; non-empty slices would make the repro easier to check.
    return e[1:-1:2]
# Test that it works
# ``a`` is the tensor that ``fn`` mutates; it does not require grad.
a = torch.randn((2, 4), requires_grad=False)
# ``b`` is unused inside ``fn``; its requires_grad flag is the only thing
# toggled between the two traces shown in this report.
b = torch.randn((2, 4), requires_grad=True)
# compiler_fn is called after the forward and backward graphs are extracted.
# Here it just prints the generated code; it must return a callable.
def compiler_fn(fx_module: torch.fx.GraphModule, _):
    """Debug "compiler" for ``aot_function``: print the graph source, compile nothing.

    Args:
        fx_module: A traced graph handed over by ``aot_function``.
        _: Second positional argument supplied by the caller; ignored.

    Returns:
        ``fx_module`` itself, unchanged. A GraphModule is callable, so it can
        stand in for the compiled function. (Print ``fx_module.graph`` instead
        for a node-level view.)
    """
    generated_source = fx_module.code
    print(generated_source)
    return fx_module
# Pass compiler_fn to the aot_function API
# Wrap ``fn`` with AOT Autograd, using the same printing "compiler" for both
# the forward and the backward graph.
aot_print_fn = aot_function(
    fn,
    fw_compiler=compiler_fn,
    bw_compiler=compiler_fn,
)
# Run aot_print_fn once to trigger compilation and print the graphs
# The first call traces ``fn`` and hands the extracted graph(s) to compiler_fn,
# which prints them. The positional order (a, b) is what appears as
# primals_1/primals_2 (or arg0_1/arg1_1) in the printed forward graphs.
res = aot_print_fn(a, b)
Output (with `b` requiring grad):
def forward(self, primals_1, primals_2):
clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
as_strided = torch.ops.aten.as_strided.default(clone, [2, 2], [4, 2], 0)
add = torch.ops.aten.add.Tensor(as_strided, 1.0); as_strided = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2, 2], [4, 2], 0); clone = add = None
mul = torch.ops.aten.mul.Tensor(as_strided_scatter, 2.0); as_strided_scatter = None
_tensor_constant0 = self._tensor_constant0
lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None
as_strided_2 = torch.ops.aten.as_strided.default(mul, [2, 2], [4, 2], 0)
slice_3 = torch.ops.aten.slice.Tensor(as_strided_2, 0, 2, -1); as_strided_2 = None
fill = torch.ops.aten.fill.Tensor(slice_3, lift_fresh_copy); slice_3 = lift_fresh_copy = None
as_strided_3 = torch.ops.aten.as_strided.default(mul, [2, 2], [4, 2], 0)
slice_scatter = torch.ops.aten.slice_scatter.default(as_strided_3, fill, 0, 2, -1); as_strided_3 = fill = None
as_strided_scatter_1 = torch.ops.aten.as_strided_scatter.default(mul, slice_scatter, [2, 2], [4, 2], 0); mul = slice_scatter = None
return [as_strided_scatter_1, 0, 2, 8, 2, 4]
Now, if we change `b` to `torch.randn(2, 4, requires_grad=False)` (True → False),
the output becomes:
def forward(self, arg0_1, arg1_1):
slice_1 = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, 9223372036854775807)
as_strided = torch.ops.aten.as_strided.default(arg0_1, [2, 2], [4, 2], 0)
add_ = torch.ops.aten.add_.Tensor(as_strided, 1.0); as_strided = None
mul_ = torch.ops.aten.mul_.Tensor(arg0_1, 2.0); arg0_1 = None
tensor_constant0 = self.tensor_constant0
lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(tensor_constant0); tensor_constant0 = None
slice_2 = torch.ops.aten.slice.Tensor(add, 0, 2, -1)
fill = torch.ops.aten.fill.Tensor(slice_2, lift_fresh_copy); slice_2 = lift_fresh_copy = None
slice_3 = torch.ops.aten.slice.Tensor(add, 0, 1, -1, 2); add_ = None
return [slice_3]
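In the second case, the generated code is functionally incorrect. The result of `fill` (the traced form of `e[2:-1] = 3.0`) is never used, so that write is lost. Also, unlike the first case, the graph is not functionalized: it still applies `add_` and `mul_` in place to its input `arg0_1`.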
Question:
Why not make `aot_dispatch_autograd` the default (or only) option, since it seems to functionalize views correctly?