AoT compiler error: a leaf Variable that requires grad is being used in an in-place operation

I’m experimenting with the AoT compiler (functorch.compile (experimental) — functorch 0.2.1 documentation) and I’m trying to see whether it can capture the forward, backward, and optimizer graphs. (PyTorch v1.12 on CPU.)

My code is:

import torch
import torch.nn as nn
from functorch.compile import aot_function


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.ln1 = nn.Linear(32, 32)
        self.relu = nn.ReLU()
        self.loss = nn.MSELoss()
        
    def forward(self, x):
        x = self.relu(self.ln1(x))
        return x
    
model = MyModule().train()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
criterion = nn.MSELoss()

# This will print the full forward/backward graph!
def training_iter_fn(x, y):
    pred = model(x)
    loss = criterion(pred, y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss

def compiler_fn(fx_module: torch.fx.GraphModule, example_inputs):
    print(fx_module.code)
    print(example_inputs)
    return fx_module

# Pass on the compiler_fn to the aot_function API
aot_print_fn = aot_function(training_iter_fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn)

x = torch.randn(1, 32)
y = x**2 / 5.
loss = aot_print_fn(x, y)  # run the AoT-compiled function
print(loss)
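
For reference, only the full version with the optimizer calls is shown above; the forward/backward-only variant I compare against is essentially the same function minus the optimizer calls, roughly this (reconstructed here, names are just for illustration):

def fwd_bwd_fn(x, y):
    pred = model(x)
    loss = criterion(pred, y)
    loss.backward()
    return loss

aot_fwd_bwd_fn = aot_function(fwd_bwd_fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn)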

It works with just the forward and backward passes. However, when I include the optimizer step/zero_grad, it successfully prints the graph but then throws an exception:

File <eval_with_key>.111:29, in forward(self, arg0_1, arg1_1)
     27 empty = torch.ops.aten.empty.memory_format([280], dtype = torch.uint8, device = device(type='cpu'))
     28 _param_constant0_1 = self._param_constant0
---> 29 add__2 = torch.ops.aten.add_.Tensor(_param_constant0_1, add__1, alpha = -0.0001);  _param_constant0_1 = None
     30 _param_constant1_1 = self._param_constant1
     31 add__3 = torch.ops.aten.add_.Tensor(_param_constant1_1, add_, alpha = -0.0001);  _param_constant1_1 = None

File ~/local/anaconda3/envs/dynamo/lib/python3.8/site-packages/torch/_ops.py:257, in OpOverload.__call__(self, *args, **kwargs)
    256 def __call__(self, *args, **kwargs):
--> 257     return self._op(*args, **kwargs or {})

RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

The add__2 node looks like the weight update in the optimizer.
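
If it helps to narrow it down, the error itself is just the standard autograd check against in-place mutation of a leaf tensor that requires grad; a minimal sketch of the same failure, independent of functorch (names here are just for illustration):

import torch

p = torch.randn(3, requires_grad=True)  # leaf tensor, like a module parameter
g = torch.randn(3)                      # stand-in for a gradient

try:
    p.add_(g, alpha=-1e-4)              # in-place update on a leaf that requires grad
except RuntimeError as e:
    print(e)                            # "a leaf Variable that requires grad is being used in an in-place operation."

with torch.no_grad():
    p.add_(g, alpha=-1e-4)              # fine; the eager optimizer.step() runs its updates under no_grad

So my guess is that when the compiled module replays add_.Tensor on the captured parameter constants, grad mode is still enabled, whereas the eager optimizer does the same update under no_grad.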

Anyone have an idea why this is being thrown? Thanks. :slight_smile: