Hi,
I have a question about dead-code elimination and autograd.
I am looking at the functional API, which has torch.func.grad and torch.func.grad_and_value.
Clearly torch.func.grad_and_value needs to compute both the forward and the backward pass.
However, it seems like in some very simple cases torch.func.grad should be able to avoid computing the forward pass. For instance, when computing the gradient $dE/dz$ of a bilinear term like $E = x^\top W z$, where $W = A^\top + A$ is a symmetric matrix, I would like torch to be smart enough to compute $dE/dz = Wx$ without also evaluating $E$ itself under the hood.
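To make the setup concrete, here is a minimal sketch of the case I mean (the names `E`, `A`, `W`, `x`, `z` are just illustrative choices for this example):

```python
import torch

torch.manual_seed(0)
n = 4
A = torch.randn(n, n)
W = A.T + A            # symmetric by construction
x = torch.randn(n)
z = torch.randn(n)

def E(z):
    # scalar energy E = x^T W z
    return x @ (W @ z)

# autograd gradient w.r.t. z via the functional API
grad_E = torch.func.grad(E)(z)

# analytic gradient: dE/dz = W^T x = W x, since W is symmetric
analytic = W @ x

print(torch.allclose(grad_E, analytic))
```

Mathematically, the gradient here depends only on $W$ and $x$, so the forward evaluation of $E$ is dead code as far as `grad_E` is concerned.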
My question is whether torch will avoid the redundant computation of $E$ in such cases.
Do I need to compile the code (e.g. with torch.compile) to get this kind of optimized behaviour?