Torch.compile() causes RuntimeError: element 0 of tensors does not require grad

Hi,

In the second torch.autograd.grad() call, which computes the second derivative (d2y_dx2 in the example below), I get RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn.
I don't understand what the error means. The error happens only when torch.compile() is applied to ScalarNet. Any explanation would be appreciated!

import torch
import traceback

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

input_scalar = torch.randn(1, device=device, requires_grad=True)


# Define the module
class ScalarNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.seq = torch.nn.Sequential(
            torch.nn.Linear(1, 5), torch.nn.Tanh(), torch.nn.Linear(5, 1)
        )

    def forward(self, x):
        # Ensure input has feature dim
        if x.ndim == 1:
            x = x.unsqueeze(-1)
        return self.seq(x).squeeze(-1)  # Squeeze output back to scalar per batch item


net_scalar = ScalarNet().to(device)

# Compile the module
try:
    print("Compiling scalar model...")
    compiled_net_scalar = torch.compile(net_scalar)
    print("Compilation successful.")
except Exception as e:
    print(f"Scalar model compilation failed: {e}")
    compiled_net_scalar = net_scalar  # Fallback

# Compute derivatives using the compiled module
try:
    # First derivative
    y = compiled_net_scalar(input_scalar)
    # Ensure grad_outputs matches y shape (should be (1,) if input_scalar is (1,))
    grad_output_y = torch.ones_like(y)
    dy_dx = torch.autograd.grad(
        y, input_scalar, grad_outputs=grad_output_y, create_graph=True
    )[0]
    print("First derivative computed.")

    # Second derivative
    # Ensure grad_outputs matches dy_dx shape (should be (1,))
    grad_output_dydx = torch.ones_like(dy_dx)
    # Use retain_graph=False if this is the last grad call needing this graph segment
    d2y_dx2 = torch.autograd.grad(
        dy_dx, input_scalar, grad_outputs=grad_output_dydx, create_graph=False
    )[0]
    print("Second derivative computed successfully!")
    print(f"d2y/dx2 = {d2y_dx2.item()}")

except Exception as e:
    print(f"\n--- ERROR during second derivative test ---")
    print(e)
    traceback.print_exc()

Never mind. I was able to solve the problem by passing fullgraph=True to torch.compile().
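
For anyone who hits the same error, a minimal sketch of the fix: the only change from the code above is the fullgraph=True argument, which makes torch.compile raise an error on graph breaks instead of silently splitting the graph.

# Compile without allowing graph breaks
compiled_net_scalar = torch.compile(net_scalar, fullgraph=True)

# The double-backward now runs without the RuntimeError
y = compiled_net_scalar(input_scalar)
dy_dx = torch.autograd.grad(
    y, input_scalar, grad_outputs=torch.ones_like(y), create_graph=True
)[0]
d2y_dx2 = torch.autograd.grad(
    dy_dx, input_scalar, grad_outputs=torch.ones_like(dy_dx)
)[0]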