I’m experimenting with the AOT compiler `functorch.compile` (experimental; see the functorch 0.2.1 documentation), and I’m trying to see whether it can capture the forward, backward, and optimizer graphs (PyTorch v1.12 on CPU).
My code is:
import torch
import torch.nn as nn
from functorch.compile import aot_function
class MyModule(nn.Module):
    """Small test network: one 32 -> 32 linear layer followed by a ReLU."""

    def __init__(self):
        super().__init__()
        self.ln1 = nn.Linear(32, 32)
        self.relu = nn.ReLU()
        # Not used by forward(); the training loop uses a separate criterion.
        self.loss = nn.MSELoss()

    def forward(self, x):
        """Return ``relu(ln1(x))`` for the input tensor ``x``."""
        x = self.relu(self.ln1(x))
        return x
# Loss used by the training iteration below.
criterion = nn.MSELoss()
# Build the module and put it in training mode.
model = MyModule()
model.train()
# Plain SGD over all of the module's parameters.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
# This will print the full forward/backward graph!
def training_iter_fn(x, y):
    """Run one training iteration on ``(x, y)`` and return the loss.

    Uses the module-level ``model``, ``criterion`` and ``optimizer``:
    forward pass, loss computation, backward pass, optimizer step, and
    finally a gradient reset.
    """
    pred = model(x)
    loss = criterion(pred, y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss
def compiler_fn(fx_module: torch.fx.GraphModule, example_inputs):
    """Print the captured graph's code and example inputs, then return it.

    Acts as a pass-through "compiler": ``fx_module`` is returned unchanged.
    """
    print(fx_module.code)
    print(example_inputs)
    return fx_module
# Pass on the compiler_fn to the aot_function API
aot_print_fn = aot_function(training_iter_fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn)
x = torch.randn(1, 32)
y = x**2 / 5.
# Call the AOT-compiled wrapper rather than the eager function: only the
# wrapper invokes compiler_fn and runs the captured graph.
loss = aot_print_fn(x, y)
print(loss)
It works with just the forward and backward passes. However, when I include the optimizer step/zero_grad, it successfully prints the graph but then throws this exception:
File <eval_with_key>.111:29, in forward(self, arg0_1, arg1_1)
27 empty = torch.ops.aten.empty.memory_format([280], dtype = torch.uint8, device = device(type='cpu'))
28 _param_constant0_1 = self._param_constant0
---> 29 add__2 = torch.ops.aten.add_.Tensor(_param_constant0_1, add__1, alpha = -0.0001); _param_constant0_1 = None
30 _param_constant1_1 = self._param_constant1
31 add__3 = torch.ops.aten.add_.Tensor(_param_constant1_1, add_, alpha = -0.0001); _param_constant1_1 = None
File ~/local/anaconda3/envs/dynamo/lib/python3.8/site-packages/torch/_ops.py:257, in OpOverload.__call__(self, *args, **kwargs)
256 def __call__(self, *args, **kwargs):
--> 257 return self._op(*args, **kwargs or {})
RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
The `add__2` node looks like the weight update in the optimizer.
Does anyone have an idea why this is being thrown? Thanks.