Autograd and dead-code elimination

Hi,
I have a question about dead-code elimination and autograd.

I am looking at the functional API, which has torch.func.grad and torch.func.grad_and_value.
Clearly torch.func.grad_and_value needs to compute both the forward and backward passes.
However, it seems that in some very simple cases torch.func.grad should be able to skip the forward pass entirely. For instance, when computing the gradient $dE/dz$ of a quadratic term $E = x^\top W z$, where $W = A^\top + A$ is symmetric, I would like torch to be smart enough to compute $dE/dz = Wx$ without also computing $E$ under the hood.
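To make the example concrete, here is a minimal NumPy sketch (all names are my own illustration, not a torch API) checking that the analytic gradient $dE/dz = Wx$ agrees with a finite-difference estimate, and that the forward value $E$ is never needed for the gradient itself:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 4
A = rng.standard_normal((n, n))
W = A.T + A                      # symmetric by construction: W = A^T + A
x = rng.standard_normal(n)
z = rng.standard_normal(n)

E = x @ W @ z                    # forward value E = x^T W z
grad_z = W @ x                   # analytic dE/dz = W^T x = W x (W symmetric)

# finite-difference check of the analytic gradient
eps = 1e-6
fd = np.array([(x @ W @ (z + eps * np.eye(n)[i]) - x @ W @ z) / eps
               for i in range(n)])
assert np.allclose(fd, grad_z, atol=1e-4)
```

Note that `grad_z` depends only on `W` and `x`, so in principle the multiply producing `E` is dead code for the gradient.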

My question is: will torch avoid the redundant computation of $E$ in such cases?
Or do I need to compile the code to get this kind of optimized behaviour?

Hey!

torch.func by itself is indeed an eager-only implementation, and thus will not do any DCE.

You will have to use torch.compile or a similar compiler stack to get that behavior.

Thanks,
Great to hear that the compiler can do this :slight_smile: