I have a simple beginner question about autograd.
Thanks for the nice library.
I really like the functionality, and I understand that it simplifies a lot of code by packaging the derivative with respect to a variable directly in the variable itself.
What worries me is that the .backward() => .grad mechanic relies on the user keeping in mind that there is mutable internal state, .grad, whose value depends on which code was executed before, and that tends to hurt readability.
Let’s assume I write this:
>>> w = torch.tensor([1., 1., 1.], requires_grad=True)
>>> func = (w ** 2 + 2 * w + w).sum()  # the graph of the loss function is built
>>> func.backward()  # the derivative of func w.r.t. each tensor declared with requires_grad=True is computed and stored in the respective .grad internal state
>>> w.grad  # I can get the derivatives, handy
tensor([5., 5., 5.])
Now this means I should not write too much code between the .backward() call and the .grad read, because it becomes unclear which function the gradient is the derivative of.
When I see a .grad in a codebase, I do not immediately know which function was used to produce it.
>>> w = torch.tensor([1., 1., 1.], requires_grad=True)
>>> func = (w ** 2 + 2 * w + w).sum()
>>> func.backward()
>>> oups_had_forgotten = func + (3 * w).sum()
>>> ... some code ...
>>> w.grad
tensor([5., 5., 5.])
If I have not been careful about which backward I called, is w.grad here the derivative of func or of oups_had_forgotten?
Also, if I want to compute another derivative, I first need to .zero_() the .grad tensor (otherwise the new gradients are accumulated into it), which is another hard-to-read step.
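To make that concrete, here is a minimal sketch of what I mean, as far as I understand the current accumulation behaviour (the printed values are what I would expect, not output I captured):
>>> import torch
>>> w = torch.tensor([1., 1., 1.], requires_grad=True)
>>> (w ** 2 + 2 * w + w).sum().backward()
>>> w.grad
tensor([5., 5., 5.])
>>> (3 * w).sum().backward()  # second backward: gradients are accumulated, not replaced
>>> w.grad
tensor([8., 8., 8.])
>>> w.grad.zero_()  # manual reset of the user-exposed state
tensor([0., 0., 0.])
>>> (3 * w).sum().backward()
>>> w.grad
tensor([3., 3., 3.])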
I feel a more functional approach (functional in appearance; the backend could still behave as it does now) could be beneficial and easier for users to learn.
Consider something like:
>>> w.grad(func)
tensor([5., 5., 5.])
>>> w.grad(oups_had_forgotten)
tensor([8., 8., 8.])
This functional .grad would have the advantage of being explicit, and it would not store the result in the tensor in place.
So there would be no need for .backward() and no need for .zero_() to clean up the user-exposed state.
Is there a function exposing such a behavior today? Would there be big drawbacks in terms of usability or performance?
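For reference, the closest thing I have found so far is torch.autograd.grad, which already seems to behave functionally (it returns the gradients as a tuple instead of writing them into .grad), but I am not sure whether it is the recommended way or whether it covers what I describe above. A rough sketch of how I imagine using it (again, the printed values are what I would expect):
>>> import torch
>>> w = torch.tensor([1., 1., 1.], requires_grad=True)
>>> func = (w ** 2 + 2 * w + w).sum()
>>> oups_had_forgotten = func + (3 * w).sum()
>>> torch.autograd.grad(func, w, retain_graph=True)  # keep the graph alive for the second call
(tensor([5., 5., 5.]),)
>>> torch.autograd.grad(oups_had_forgotten, w)
(tensor([8., 8., 8.]),)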