Hello,
I tried using the (beta) forward-mode automatic differentiation API, but I ran into an issue when trying to compile my forward pass with torch.compile. I am not sure whether this is an error on my part or an actual bug in PyTorch. Here is a minimal example that reproduces the error:
```python
import torch
import torch.nn as nn
import torch.autograd.forward_ad as fwAD

device = "cuda:0"

@torch.compile()
def step(model, x, device):
    x = x.to(device)
    tangent = torch.zeros_like(x, device=device)
    with fwAD.dual_level():
        dual_input = fwAD.make_dual(x, tangent)
        dual_output = model(dual_input)
    return None

lamb = 0.001  # Regularization parameter (unused in this minimal example)
x = torch.randn(2, 3)
model = nn.Sequential(
    nn.Linear(3, 3),
    nn.BatchNorm1d(3),
    nn.Linear(3, 1),
).to(device)
model.train()
step(model, x, device)
```
I get the following error:
```
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <function batch_norm at 0x788edb8ac9a0>(*(FakeTensor(..., device='cuda:0', size=(2, 3), grad_fn=<AddmmBackward0>,
tangent=FakeTensor(..., device='cuda:0', size=(2, 3), grad_fn=<AddBackward0>)), FakeTensor(..., device='cuda:0', size=(3,)), FakeTensor(..., device='cuda:0', size=(3,)), FakeTensor(..., device='cuda:0', size=(3,), requires_grad=True), FakeTensor(..., device='cuda:0', size=(3,), requires_grad=True), True, 0.1, 1e-05), **{}):
InferenceMode::is_enabled() && self.is_inference() INTERNAL ASSERT FAILED at "../aten/src/ATen/native/VariableMethodStubs.cpp":66, please report a bug to PyTorch. Expected this method to only be reached in inference mode and when all the inputs are inference tensors. You should NOT call this method directly as native::_fw_primal. Please use the dispatcher, i.e., at::_fw_primal. Please file an issue if you come across this error otherwise.

from user code:
  File "/home/enzo/Documents/git/LieEquiv/minimal_example.py", line 13, in step
    dual_output = model(dual_input)
  File "/home/enzo/mambaforge/envs/up_to_date/lib/python3.12/site-packages/torch/nn/modules/container.py", line 250, in forward
    input = module(input)
  File "/home/enzo/mambaforge/envs/up_to_date/lib/python3.12/site-packages/torch/nn/modules/batchnorm.py", line 193, in forward
    return F.batch_norm(
```
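For comparison, here is the same dual-level forward pass written without the @torch.compile decorator. This is a minimal eager-mode sketch; the step_eager name, the nonzero tangent, and the fwAD.unpack_dual call are my additions for illustration. If this version runs cleanly while the decorated one hits the internal assert, that would point at the torch.compile / forward-AD interaction (the traceback lands in F.batch_norm) rather than at the model itself.

```python
import torch
import torch.nn as nn
import torch.autograd.forward_ad as fwAD

device = "cuda:0"

def step_eager(model, x, device):
    # Identical forward pass, but without the @torch.compile decorator.
    x = x.to(device)
    tangent = torch.ones_like(x)  # nonzero tangent so the JVP is visible
    with fwAD.dual_level():
        dual_input = fwAD.make_dual(x, tangent)
        dual_output = model(dual_input)
        # unpack_dual splits the dual tensor back into the primal output
        # and its directional derivative (JVP) along `tangent`; it must be
        # called inside the dual_level context.
        primal, jvp = fwAD.unpack_dual(dual_output)
    return primal, jvp

model = nn.Sequential(nn.Linear(3, 3), nn.BatchNorm1d(3), nn.Linear(3, 1)).to(device)
model.train()
primal, jvp = step_eager(model, torch.randn(2, 3), device)
print(primal.shape, jvp.shape)
```

The only differences from my compiled step are the missing decorator and the unpacking of the JVP, so any divergence in behavior between the two should isolate the compilation path.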