I want to understand a little bit about how autograd works on the backend, but I have no idea how the ctx variable gets passed to custom Functions. I understand that it's the backward object that corresponds to a Function, but where in the source code are you guys passing ctx to the forward() and backward() methods so that it's accessible to users?
Example:
```python
import torch
from torch.autograd import Function


class MulConstant(Function):
    @staticmethod
    def forward(tensor, constant):
        return tensor * constant

    @staticmethod
    def setup_context(ctx, inputs, output):
        # ctx is a context object that can be used to stash information
        # for backward computation
        tensor, constant = inputs
        print(ctx, type(ctx), ctx.__class__.__name__, sep="\n")
        ctx.constant = constant

    @staticmethod
    def backward(ctx, grad_output):
        # We return as many input gradients as there were arguments.
        # Gradients of non-Tensor arguments to forward must be None.
        return grad_output * ctx.constant, None


def mul_constant(tensor, c=1):
    return MulConstant.apply(tensor, c)


tensor = torch.ones((5, 1), requires_grad=True)
result = mul_constant(tensor, c=10)
```
Output:
```
<torch.autograd.function.MulConstantBackward object at 0x114199300>
<class 'torch.autograd.function.MulConstantBackward'>
MulConstantBackward
```
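One way to follow the printed object further is to keep a handle on it and compare it with result.grad_fn and with the ctx that backward() later receives. This is a minimal check, assuming PyTorch 2.x with the forward/setup_context split used above. The last_setup_ctx and last_backward_ctx class attributes are hypothetical, demo-only hooks (not something to do in real code), and _backward_cls is a private attribute that may change between versions.

```python
import torch
from torch.autograd import Function


class MulConstant(Function):
    # Demo-only handles (hypothetical names) so the ctx objects can be inspected later.
    last_setup_ctx = None
    last_backward_ctx = None

    @staticmethod
    def forward(tensor, constant):
        return tensor * constant

    @staticmethod
    def setup_context(ctx, inputs, output):
        tensor, constant = inputs
        ctx.constant = constant
        MulConstant.last_setup_ctx = ctx

    @staticmethod
    def backward(ctx, grad_output):
        MulConstant.last_backward_ctx = ctx
        return grad_output * ctx.constant, None


x = torch.ones((5, 1), requires_grad=True)
result = MulConstant.apply(x, 10)
setup_ctx = MulConstant.last_setup_ctx

# The ctx handed to setup_context is the graph node stored on the output.
print(setup_ctx is result.grad_fn)
# Its class is the per-Function "...Backward" class (private attribute).
print(type(setup_ctx) is MulConstant._backward_cls)
print(isinstance(setup_ctx, torch.autograd.function.BackwardCFunction))

result.sum().backward()
# backward() receives that very same object as its ctx.
print(MulConstant.last_backward_ctx is setup_ctx)
print(x.grad.flatten().tolist())
```

If this holds on your version, all four identity/type checks print True and every entry of x.grad is 10.0, which suggests ctx is not a separate helper object but the autograd node itself.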
Function.apply creates the ctx as an instance of the backward node class. This happens relatively deep in the C++ guts of the autograd engine; below is the C++ implementation of Function.apply.
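To make that flow concrete without the C++ layer, here is a toy, pure-Python model of it. All names (ToyNode, ToyFunctionMeta, ToyFunction, ToyMulConstant) are hypothetical. It is loosely inspired by the FunctionMeta metaclass in torch/autograd/function.py, which generates the per-Function "...Backward" class, but it is only a sketch: there are no tensors, no graph, and no grad-mode handling, and it returns ctx next to the output instead of attaching it as grad_fn.

```python
class ToyNode:
    """Plays the role of the generated <Name>Backward class: ctx *is* the node."""

    def apply(self, grad_output):
        # The node routes gradients to the user's backward, passing itself as ctx.
        return self._forward_cls.backward(self, grad_output)


class ToyFunctionMeta(type):
    def __init__(cls, name, bases, attrs):
        super().__init__(name, bases, attrs)
        # One node class per Function subclass, linked back to that subclass.
        cls._backward_cls = type(name + "Backward", (ToyNode,), {"_forward_cls": cls})


class ToyFunction(metaclass=ToyFunctionMeta):
    @classmethod
    def apply(cls, *args):
        ctx = cls._backward_cls()             # 1. create the node; this object is ctx
        output = cls.forward(*args)           # 2. run forward without ctx
        cls.setup_context(ctx, args, output)  # 3. let the user stash state on ctx
        # 4. real autograd would attach ctx to the output as grad_fn instead
        return output, ctx


class ToyMulConstant(ToyFunction):
    @staticmethod
    def forward(x, constant):
        return x * constant

    @staticmethod
    def setup_context(ctx, inputs, output):
        ctx.constant = inputs[1]

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ctx.constant, None


out, node = ToyMulConstant.apply(3.0, 10)
print(out)                  # 30.0
print(type(node).__name__)  # ToyMulConstantBackward
print(node.apply(1.0))      # (10.0, None)
```

The design point the toy tries to capture is that the user never constructs ctx: apply builds the node first, and the same object is then handed to setup_context and, later, to backward.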
I used to offer an "All about autograd" course, but sadly, I have not updated it to PT2 yet, so it is missing AOTAutograd and other things that came after 2021.
This was pretty helpful. It just piqued my curiosity because I recently started learning about autograd and was really confused about where exactly ctx originated. Honestly, I'd still like to check out some parts of your course since I want to learn about eager mode autograd. Do you have a link to it so I can take a look?