Hi! I am interested in the exact meaning of grad_fn=. Thank you!!!
The simple explanation is: during the forward pass PyTorch will track the operations if one of the involved tensors requires gradients (i.e. its .requires_grad attribute is set to True) and will create a computation graph from these operations. To be able to backpropagate through this computation graph and to calculate the gradients for all involved parameters, PyTorch will additionally store the corresponding "gradient functions" (or "backward functions") of the executed operations on the output tensor (as the .grad_fn attribute). Once the forward pass is done, you can then call .backward() on the output (or loss) tensor, which will backpropagate through the computation graph using the functions stored in .grad_fn.
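As a minimal sketch of this behavior (the tensor names are just illustrative), you can see that .grad_fn is only populated once one of the inputs requires gradients:
import torch

# no gradients requested: no graph is built and grad_fn stays None
a = torch.randn(3)
print((a * 2).grad_fn)
# None

# with requires_grad=True every operation records its backward function
b = torch.randn(3, requires_grad=True)
c = (b * 2).sum()
print(c.grad_fn)
# <SumBackward0 object at ...>
c.backward()  # uses the stored grad_fns to compute b.grad
print(b.grad)
# tensor([2., 2., 2.])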
In your case the output tensor was created by a torch.pow operation and will thus have the PowBackward0 function attached to its .grad_fn attribute:
import torch

x = torch.randn(2, requires_grad=True)
out = torch.pow(x, 2)  # records PowBackward0 as out.grad_fn
print(out)
# tensor([0.4651, 1.1575], grad_fn=<PowBackward0>)
out.sum().backward()   # backpropagates through the stored grad_fns
print(x.grad)
# tensor([-1.3640, -2.1517])
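If you are curious, you can also walk the recorded graph manually via the .next_functions attribute of each grad_fn (a rough, illustrative sketch; the exact object addresses will differ):
x = torch.randn(2, requires_grad=True)
out = torch.pow(x, 2)
loss = out.sum()
print(loss.grad_fn)
# <SumBackward0 object at ...>
print(loss.grad_fn.next_functions)
# ((<PowBackward0 object at ...>, 0),)
print(loss.grad_fn.next_functions[0][0].next_functions)
# ((<AccumulateGrad object at ...>, 0),)  <- accumulates the gradient into the leaf tensor x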
Thank you very much for your detailed answer!!!
Can you point me to where in the PyTorch codebase the grad_fns are implemented? Functions like SumBackward0, PowBackward0, etc.
You can search for the derivatives of these functions in the derivatives.yaml file and will see, e.g. for pow:
- name: pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor
self: pow_backward(grad, self, exponent)
result: auto_element_wise
Searching for pow_backward points to this function.
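For a scalar exponent, pow_backward effectively applies the usual power rule, i.e. grad * exponent * self ** (exponent - 1); a quick, illustrative check against what autograd produces:
import torch

x = torch.randn(2, requires_grad=True)
out = torch.pow(x, 3)
out.sum().backward()

# manual power rule: d/dx x**3 = 3 * x**2
manual = 3 * x.detach() ** 2
print(torch.allclose(x.grad, manual))
# True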