Torch compile with forward-mode automatic differentiation

Hello,
I tried using the beta forward-mode automatic differentiation API, but I ran into an issue when trying to compile my forward pass. I wonder whether this is an error on my part or an actual bug in PyTorch. Here is a minimal example that reproduces the error.

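import torch
import torch.nn as nn
import torch.autograd.forward_ad as fwAD

device = "cuda:0"

@torch.compile()
def step(model, x, device):
    x = x.to(device)
    # Tangent (JVP direction); zeros are enough to trigger the error
    tangent = torch.zeros_like(x, device=device)
    with fwAD.dual_level():
        # Pair the primal input with its tangent and run the forward pass
        dual_input = fwAD.make_dual(x, tangent)
        dual_output = model(dual_input)  # the error is raised here under torch.compile
    return None

lamb = 0.001  # Regularization parameter (unused in this minimal example)

x = torch.randn(2, 3)
model = nn.Sequential(nn.Linear(3, 3),
                      nn.BatchNorm1d(3),
                      nn.Linear(3, 1)).to(device)
model.train()  # BatchNorm uses batch statistics in training mode

step(model, x, device)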
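For reference, here is a minimal sketch of the same `dual_level` / `make_dual` pattern in eager mode (no `torch.compile`), on CPU, with `unpack_dual` used to read the resulting tangent. It assumes a single `nn.Linear` layer so the expected result is known in closed form: for `f(x) = x @ W.T + b`, the JVP along a tangent `t` is `t @ W.T`. The layer sizes and random inputs are arbitrary illustration choices, not part of the original repro.

```python
import torch
import torch.nn as nn
import torch.autograd.forward_ad as fwAD

torch.manual_seed(0)

# Illustrative single linear layer: f(x) = x @ W.T + b, so JVP(t) = t @ W.T
layer = nn.Linear(3, 2)
x = torch.randn(4, 3)
tangent = torch.randn(4, 3)

with fwAD.dual_level():
    dual_output = layer(fwAD.make_dual(x, tangent))
    # unpack_dual returns a (primal, tangent) named tuple
    primal_out, jvp_out = fwAD.unpack_dual(dual_output)
    # The bias has no tangent, so only W contributes to the JVP
    expected = tangent @ layer.weight.T
    matches = torch.allclose(jvp_out, expected, atol=1e-5)
    primal_ok = torch.allclose(primal_out, layer(x))

print(matches, primal_ok)
```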

I get the following error:

torch._dynamo.exc.TorchRuntimeError: Failed running call_function <function batch_norm at 0x788edb8ac9a0>(*(FakeTensor(..., device='cuda:0', size=(2, 3), grad_fn=<AddmmBackward0>,
           tangent=FakeTensor(..., device='cuda:0', size=(2, 3), grad_fn=<AddBackward0>)), FakeTensor(..., device='cuda:0', size=(3,)), FakeTensor(..., device='cuda:0', size=(3,)), FakeTensor(..., device='cuda:0', size=(3,), requires_grad=True), FakeTensor(..., device='cuda:0', size=(3,), requires_grad=True), True, 0.1, 1e-05), **{}):
InferenceMode::is_enabled() && self.is_inference() INTERNAL ASSERT FAILED at "../aten/src/ATen/native/VariableMethodStubs.cpp":66, please report a bug to PyTorch. Expected this method to only be reached in inference mode and when all the inputs are inference tensors. You should NOT call this method directly as native::_fw_primal. Please use the dispatcher, i.e., at::_fw_primal. Please file an issue if you come across this error otherwise.

from user code:
   File "/home/enzo/Documents/git/LieEquiv/minimal_example.py", line 13, in step
    dual_output = model(dual_input)
  File "/home/enzo/mambaforge/envs/up_to_date/lib/python3.12/site-packages/torch/nn/modules/container.py", line 250, in forward
    input = module(input)
  File "/home/enzo/mambaforge/envs/up_to_date/lib/python3.12/site-packages/torch/nn/modules/batchnorm.py", line 193, in forward
    return F.batch_norm(
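To check whether the failure is specific to `torch.compile`, one could run the same architecture eagerly and compare the forward-mode tangent with a central finite difference along the same direction. This is a sketch under a few assumptions not in the original post: it runs on CPU in float64 (to keep the finite-difference error small), uses a batch of 8 instead of 2, and uses a random rather than zero tangent so the comparison is not trivially zero.

```python
import torch
import torch.nn as nn
import torch.autograd.forward_ad as fwAD

torch.manual_seed(0)

# Same architecture as the repro, but on CPU, in float64, and not compiled
model = nn.Sequential(nn.Linear(3, 3), nn.BatchNorm1d(3), nn.Linear(3, 1)).double()
model.train()  # BatchNorm normalizes with the current batch statistics

x = torch.randn(8, 3, dtype=torch.float64)
tangent = torch.randn(8, 3, dtype=torch.float64)  # assumed random direction

with fwAD.dual_level():
    dual_output = model(fwAD.make_dual(x, tangent))
    jvp_fwd = fwAD.unpack_dual(dual_output).tangent.detach().clone()

# Central finite difference along the same direction. In training mode the
# output depends only on the current batch, so the running-stat updates made
# by each call do not change the result.
eps = 1e-6
with torch.no_grad():
    jvp_fd = (model(x + eps * tangent) - model(x - eps * tangent)) / (2 * eps)

max_err = (jvp_fwd - jvp_fd).abs().max().item()
print(f"max |forward AD - finite difference| = {max_err:.2e}")
```

If this eager version runs while the decorated `step` fails, that points at the compiled path rather than at the forward-AD usage itself.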

Would you mind creating an issue on GitHub if the code still fails with the latest nightly?

Dear ptrblck,
Thank you for your answer and your contributions to this community.
The code does indeed fail with the latest nightly, so I have created an issue on GitHub.