Tensor.requires_grad in torch.autograd.Function

I attempted to implement a custom function as outlined in the tutorial:

import torch
from torch.autograd import Function

# Inherit from Function
class LinearFunction(Function):

    # Note that forward, setup_context, and backward are @staticmethods
    @staticmethod
    def forward(input, weight, bias):
        output = input.mm(weight.t())
        print('inputs.requires_grad:', input.requires_grad)
        print('weight.requires_grad:', weight.requires_grad)
        print('output.requires_grad:', output.requires_grad)
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    # inputs is a Tuple of all of the inputs passed to forward.
    # output is the output of the forward().
    @staticmethod
    def setup_context(ctx, inputs, output):
        input, weight, bias = inputs
        ctx.save_for_backward(input, weight, bias)

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        # These needs_input_grad checks are optional and there only to
        # improve efficiency. If you want to make your code simpler, you can
        # skip them. Returning gradients for inputs that don't require it is
        # not an error.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias

x = torch.rand(2, 128, requires_grad=True)
weight = torch.rand(256, 128, requires_grad=True)
bias = torch.rand(256, requires_grad=True)
out = LinearFunction.apply(x, weight, bias)
print('out.requires_grad:', out.requires_grad)

Running the code above prints:

inputs.requires_grad: True
weight.requires_grad: True
output.requires_grad: False
out.requires_grad: True

What puzzles me is that output = input.mm(weight.t()) produces a tensor with requires_grad=False, even though both input and weight have requires_grad=True.

I suspect the output does not require grad because forward() is invoked inside a no-grad context, roughly like this (some_func is just a placeholder name):

def some_func(*args):
    with torch.no_grad():
        return LinearFunction.forward(*args)
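One way to test this assumption without reading the C++ source is to ask autograd directly. The sketch below uses a hypothetical Probe Function (assuming PyTorch 2.x, where forward/setup_context may be split) that records torch.is_grad_enabled() from inside forward(), and contrasts it with the same matmul run normally and under an explicit torch.no_grad().

```python
import torch
from torch.autograd import Function


class Probe(Function):
    """Hypothetical Function that records the grad mode seen inside forward()."""

    grad_mode_in_forward = None  # filled in when forward() runs

    @staticmethod
    def forward(input, weight):
        # Record the global grad mode as it looks from inside forward().
        Probe.grad_mode_in_forward = torch.is_grad_enabled()
        return input.mm(weight.t())

    @staticmethod
    def setup_context(ctx, inputs, output):
        ctx.save_for_backward(*inputs)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        return grad_output.mm(weight), grad_output.t().mm(input)


x = torch.rand(2, 128, requires_grad=True)
weight = torch.rand(256, 128, requires_grad=True)

plain = x.mm(weight.t())             # grad mode on: the result is tracked
with torch.no_grad():
    untracked = x.mm(weight.t())     # grad mode off: the result is not tracked
out = Probe.apply(x, weight)         # forward() runs, then apply() wraps the output

print(plain.requires_grad, untracked.requires_grad)   # True False
print(Probe.grad_mode_in_forward, out.requires_grad)  # False True
```

If grad mode reads False inside forward() while out still requires grad, the no-grad part of the guess is right, and something after forward() must be attaching the autograd history.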

If my assumption holds, I have four additional questions:

  1. Where is the equivalent of the some_func() above actually defined?

  2. Why is output.requires_grad False inside forward(), but out.requires_grad True outside? What additional operation is applied after forward() returns?

  3. What is the purpose of this design? Is it intended to conserve GPU memory?

  4. As far as I know, there is a technique called checkpointing, provided by torch.utils.checkpoint, that reduces memory usage. It omits gradients during the forward pass and recomputes them during the backward pass. I am curious about the differences and connections between torch.autograd.Function and torch.utils.checkpoint, particularly in how they treat gradients during the forward pass.

Additionally, I have reviewed the implementation of torch.utils.checkpoint:

class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args): ...  # body elided
    @staticmethod
    def backward(ctx, *args): ...  # body elided

This class, CheckpointFunction, is a subclass of torch.autograd.Function. I have two additional questions regarding checkpointing:

  1. What techniques does the CheckpointFunction class employ to conserve GPU memory? Is it primarily achieved by omitting gradients in the forward process (though the original torch.autograd.Function also seems to run in a no_grad() environment)?

  2. In the forward function of the CheckpointFunction class, there is the following code:

    with torch.no_grad():
        outputs = run_function(*args)
    return outputs

Is it redundant to open a new no_grad() context here? Going back to my first question, the forward() of a torch.autograd.Function already appears to run under torch.no_grad().
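For comparison, the public entry point is torch.utils.checkpoint.checkpoint. Its use_reentrant flag selects the CheckpointFunction-based implementation (True) or, in recent PyTorch 2.x releases, a newer implementation built on saved-tensor hooks rather than an autograd.Function (False). This sketch (block and grads are hypothetical helpers) only checks that both variants give the same gradients as plain autograd; it does not measure memory. It calls .backward() because the reentrant variant does not support torch.autograd.grad.

```python
import torch
from torch.utils.checkpoint import checkpoint


def block(x, w):
    # Hypothetical sub-network whose intermediate activations we would rather not keep.
    return torch.tanh(x @ w).sum()


x = torch.rand(4, 8, requires_grad=True)
w = torch.rand(8, 8, requires_grad=True)


def grads(make_loss):
    # Run forward + backward and return copies of the leaf gradients.
    x.grad = None
    w.grad = None
    make_loss().backward()
    return x.grad.clone(), w.grad.clone()


# Plain autograd: every intermediate needed by backward is kept alive.
ref = grads(lambda: block(x, w))
# Reentrant variant: implemented via CheckpointFunction (an autograd.Function).
reentrant = grads(lambda: checkpoint(block, x, w, use_reentrant=True))
# Non-reentrant variant: implemented via saved-tensor hooks instead.
non_reentrant = grads(lambda: checkpoint(block, x, w, use_reentrant=False))

for name, (gx, gw) in [("reentrant", reentrant), ("non-reentrant", non_reentrant)]:
    print(name, torch.allclose(gx, ref[0]), torch.allclose(gw, ref[1]))
```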


  1. The implementation is quite tricky and split between Python and C++, I'm afraid.
  2. This is expected behavior: you are providing the backward formula for the ops happening here (via the backward() method) so we don’t need to track autograd information during the forward. When the output is being returned, we will attach the appropriate autograd metadata to it so that it requires_grad and it will call your custom backward().
  3. This is done mainly for performance: there is no need to build the autograd graph during the forward since it will never be used. There are other, more subtle implications, such as not saving things for backward twice, that are also important.
  4. checkpoint omits intermediate buffers (what save_for_backward would keep), not gradients, and the two mechanisms are very much independent. As you can see in the utils.checkpoint code, we are moving away from using autograd.Function to implement AC (activation checkpointing).
  5. By saving only the inputs for backward; everything else is recomputed during the backward pass.
  6. It is indeed redundant.
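Answers 5 and 6 can be illustrated with a stripped-down, hypothetical MiniCheckpoint. It is not PyTorch's CheckpointFunction: it assumes one tensor input, one tensor output, a run_function with no parameters of its own, and it ignores RNG state. forward() saves only the input and relies on grad mode already being off; backward() recomputes the forward under torch.enable_grad() and differentiates that local graph.

```python
import torch
from torch.autograd import Function


class MiniCheckpoint(Function):
    """Hypothetical minimal checkpoint: save the input only, recompute in backward."""

    @staticmethod
    def forward(ctx, run_function, x):
        ctx.run_function = run_function
        # Only the input is saved; run_function's intermediates are not kept.
        ctx.save_for_backward(x)
        # Grad mode is already off here, so no explicit torch.no_grad() is needed.
        return run_function(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # Rebuild a small graph by re-running the forward with grad enabled.
        x = x.detach().requires_grad_(True)
        with torch.enable_grad():
            out = ctx.run_function(x)
        (grad_x,) = torch.autograd.grad(out, x, grad_output)
        # run_function is a Python callable, so it gets no gradient.
        return None, grad_x


def f(t):
    return torch.sin(t).exp()


x = torch.rand(5, requires_grad=True)
MiniCheckpoint.apply(f, x).sum().backward()
ckpt_grad = x.grad.clone()

# Reference gradient from ordinary autograd.
x.grad = None
f(x).sum().backward()
print(torch.allclose(ckpt_grad, x.grad))  # True
```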

Hi @albanD, thanks for your detailed response. Most of my questions have been resolved.

I have one remaining question regarding your second answer: "we will attach the appropriate autograd metadata to it so that it requires_grad".

My question is: is the function that attaches this metadata also implemented in C++, like the aforementioned some_func()?

Looking forward to your reply :slight_smile:

Yes, it is in C++.
It is here exactly if you’re curious: https://github.com/pytorch/pytorch/blob/d9c0e37bab9462c18508a594659cd34a66abfe1e/torch/csrc/autograd/custom_function.cpp#L355-L362 :slight_smile:
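The effect of that C++ code can also be observed from Python by inspecting the output's grad_fn. The sketch below uses a hypothetical Scale Function (assuming PyTorch 2.x). In current releases the attached node is an instance of a class generated from the Function's name (here ScaleBackward); that naming is an implementation detail rather than a documented API.

```python
import torch
from torch.autograd import Function


class Scale(Function):
    """Hypothetical example Function: y = 3 * x, with a hand-written backward."""

    backward_calls = 0  # counts how often the custom backward runs

    @staticmethod
    def forward(x):
        return x * 3

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass  # nothing needs to be saved for this backward

    @staticmethod
    def backward(ctx, grad_output):
        Scale.backward_calls += 1
        return grad_output * 3


x = torch.rand(4, requires_grad=True)
y = Scale.apply(x)

# The autograd metadata attached after forward(): a grad_fn node tied to Scale.
print(type(y.grad_fn).__name__)  # ScaleBackward

y.sum().backward()
# Backpropagating through y calls Scale.backward exactly once.
print(Scale.backward_calls, torch.allclose(x.grad, torch.full_like(x, 3.0)))  # 1 True
```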