Ctx.needs_input_grad behaviour

Hi!

I am following a tutorial on https://pytorch.org/docs/stable/notes/extending.html#extending-torch-autograd.

I'm quite confused about how ctx.needs_input_grad works.

Example:

class FC(torch.nn.Module):
    def __init__(self):
        super(FC, self).__init__()
        self.fc1 = Linear(3, 64, bias=True)
        self.fc2 = Linear(64, 64, bias=True)
    
    def forward(self, x):        
        x = self.fc1(x)
        x = self.fc2(x)
        return x
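A driver along these lines reproduces the prints (this is my assumption about the invocation; Linear is the custom module from the tutorial that wraps LinearFunction, and any scalar loss will do):

model = FC()
out = model(torch.ones(100, 3))   # plain input data, requires_grad is False
out.mean().backward()             # the backward pass triggers the prints below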

Output:

ctx.needs_input_grad (True, True, True)
ctx.needs_input_grad (False, True, True)

This looks correct, because the first True corresponds to the derivative of wx + b w.r.t. x, which takes part in the chain rule. x is reflected in the graph but is not reflected in model.parameters().
How does PyTorch figure out which variable takes part in the chain? Is it simply the first position?

I am taking the liberty of pinging @albanD.

Thank you very much for your attention!

Andrei


Hi,

Just to be sure: you modified the autograd.Function's backward method in Linear to print this information?

A good way to see this is to rewrite the whole forward as just plain Functions:

# input that you gave the nn.Module
inp = torch.rand(10, 3) # Note that this one does not require gradients
# Parameters for the first nn.Linear
fc1_weight = torch.rand(64, 3, requires_grad=True)
fc1_bias = torch.rand(64, requires_grad=True)
# Parameters for the second nn.Linear
fc2_weight = torch.rand(64, 64, requires_grad=True)
fc2_bias = torch.rand(64, requires_grad=True)

# What happens when you do FC()(input)
# I will write Function_Linear for the function that corresponds to the Linear layer
x = Function_Linear(inp, fc1_weight, fc1_bias)
# Note that x.requires_grad=True now
x = Function_Linear(x, fc2_weight, fc2_bias)

As you can see, for the first call, of the tensors that are given, the first one does not require grad and the other two do. For the second call, all three require gradients.
So knowing which inputs need gradients is as simple as checking the requires_grad field of the input Tensors during the forward pass.
Now if you print during the backward, the backward of the second one is called first, then the backward of the first one.
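If you want something you can actually run, here is a sketch that uses the tutorial's LinearFunction directly (Function_Linear above is just a placeholder for LinearFunction.apply):

inp = torch.rand(10, 3)                              # requires_grad=False
fc1_weight = torch.rand(64, 3, requires_grad=True)
fc1_bias = torch.rand(64, requires_grad=True)
fc2_weight = torch.rand(64, 64, requires_grad=True)
fc2_bias = torch.rand(64, requires_grad=True)

x = LinearFunction.apply(inp, fc1_weight, fc1_bias)  # its backward sees needs_input_grad == (False, True, True)
x = LinearFunction.apply(x, fc2_weight, fc2_bias)    # its backward sees needs_input_grad == (True, True, True)
x.sum().backward()                                   # the second Function's backward runs first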

Does that help?

Yes, I modified the autograd.Function's backward.

Yes, if the inputs are torch.nn.Parameters, it is clear that they will be True in ctx.needs_input_grad.

But what about the first True, the one for wx + b w.r.t. x? It seems to be figuring out two things:

  1. There is more work underway: there are more layers in the backward pass.
  2. It selects the correct variable, which is x, and takes the derivative w.r.t. it so it can take part in the chain rule.
ctx.needs_input_grad (True, True, True) - more work to do; the first True corresponds to grad_input, in other words x.
ctx.needs_input_grad (False, True, True) - no more work to do; the first entry is False, so no more backprop.

Yet x.requires_grad is False by default.

Yes, if the inputs are torch.nn.Parameters, it is clear that they will be True in ctx.needs_input_grad.

Note that nn.Parameter is only a tool in nn. From the autograd point of view, it’s only a Tensor with requires_grad=True.

The first linear's output requires gradients. You can check this with print(x.requires_grad).
So the second linear knows that it needs to compute gradients for its input x (because x.requires_grad == True).

The first linear knows that it does not need to compute gradients for its first input because inp.requires_grad == False.

If your question is how backpropagation is implemented: the implementation detail here is that during the forward pass, a graph is built recording everything that needs to be done during the backward pass.
You can see this very nicely using the torchviz package.
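For example, something like this (a small sketch, assuming torchviz is installed):

from torchviz import make_dot

model = FC()
out = model(torch.ones(100, 3))
# Render the graph that autograd recorded during the forward pass
make_dot(out, params=dict(model.named_parameters())).render('fc_graph')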

Thank you for the clarification regarding requires_grad. It is really helpful.

I checked x (or input):

input.requires_grad False

The whole thing:

# Inherit from Function
class LinearFunction(torch.autograd.Function):

    # Note that both forward and backward are @staticmethods
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, weight, bias=None):
        print('input.requires_grad', input.requires_grad)
        ctx.save_for_backward(input, weight, bias)
        print('input forward', input)
        print('weight forward', weight)
        print('bias forward', bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        print('grad_output', grad_output)
        print('grad_output.shape', grad_output.shape)
        print('type(ctx)', type(ctx))
        input, weight, bias = ctx.saved_tensors
        print('ctx.saved_tensors', ctx.saved_tensors)
        print('input.shape', input.shape)
        print('weight.shape', weight.shape)
        print('bias.shape', bias.shape)
        grad_input = grad_weight = grad_bias = None

        # These needs_input_grad checks are optional and there only to
        # improve efficiency. If you want to make your code simpler, you can
        # skip them. Returning gradients for inputs that don't require it is
        # not an error.
        print('ctx.needs_input_grad', ctx.needs_input_grad)
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
            print('grad_input', grad_input)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
            print('grad_weight', grad_weight)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0).squeeze(0)
            print('grad_bias', grad_bias)

        return grad_input, grad_weight, grad_bias
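The custom Linear module that wraps this Function follows the same tutorial; roughly (a sketch, initialization details may differ):

class Linear(torch.nn.Module):
    def __init__(self, input_features, output_features, bias=True):
        super(Linear, self).__init__()
        self.weight = torch.nn.Parameter(torch.empty(output_features, input_features).uniform_(-0.1, 0.1))
        if bias:
            self.bias = torch.nn.Parameter(torch.empty(output_features).uniform_(-0.1, 0.1))
        else:
            self.register_parameter('bias', None)

    def forward(self, input):
        # needs_input_grad in the backward is driven by the requires_grad
        # flags of exactly these three arguments, in this order
        return LinearFunction.apply(input, self.weight, self.bias)

And the output: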
input.requires_grad False
input forward tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        ...,
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
weight forward Parameter containing:
tensor([[-5.7689e-02, -7.1829e-02, -2.6084e-02],
        [-2.0893e-03,  7.3108e-02, -9.8120e-02],
        [ 7.9525e-02, -9.9841e-02,  2.9965e-02],
        [ 5.8336e-02, -6.8684e-02,  8.0132e-02],
        [-7.8534e-02, -6.5824e-02,  6.9786e-02],
        [ 3.1624e-02, -4.9671e-02,  9.5418e-02],
        [ 5.4930e-02, -3.5600e-02, -7.9666e-02],
        [-1.9675e-02,  5.5125e-02,  2.1277e-02],
        [ 7.9899e-02, -6.3545e-03,  7.5164e-02],
        [ 6.5107e-02, -1.1355e-02,  6.1582e-02],
        [ 5.7898e-02,  8.2228e-02, -2.8934e-04],
        [ 3.8152e-02, -8.8827e-02, -9.7715e-02],
        [ 5.0026e-02, -5.9546e-02,  6.7430e-03],
        [-4.7775e-03, -7.8560e-02,  6.5737e-02],
        [ 6.6842e-02, -4.8822e-02, -2.9546e-02],
        [ 1.7129e-02, -1.9333e-02, -6.0500e-03],
        [ 2.3933e-02,  1.8603e-02, -7.6528e-02],
        [-8.5194e-02,  2.3240e-03, -8.4081e-02],
        [ 6.4385e-02, -9.2458e-02, -2.6505e-02],
        [ 9.1981e-02,  7.0836e-02,  3.5892e-02],
        [ 9.9067e-02,  3.2538e-02, -4.0784e-02],
        [-1.9370e-02, -2.8337e-02,  7.1533e-02],
        [ 8.7615e-02,  7.4066e-02,  8.4684e-03],
        [-9.8254e-03, -6.5946e-02,  7.2230e-02],
        [ 4.4700e-02,  7.9217e-02, -8.4470e-02],
        [-7.5376e-02, -6.5382e-02,  2.3836e-02],
        [ 5.7385e-02, -1.5403e-02,  3.6721e-02],
        [ 5.8476e-02, -5.9563e-02, -8.3188e-02],
        [ 7.3825e-02, -7.6582e-02, -3.1223e-04],
        [-4.3375e-03,  5.8998e-02, -6.2707e-03],
        [ 9.0194e-02, -6.9641e-02, -6.6580e-02],
        [ 3.5148e-02, -7.3449e-02, -3.4645e-02],
        [-5.3004e-02,  8.5227e-02,  6.0224e-02],
        [ 6.7936e-02, -5.7037e-02, -1.3519e-02],
        [-2.3542e-02, -6.3962e-02,  7.6364e-02],
        [-1.1940e-02,  5.3724e-03, -5.2544e-02],
        [-2.9361e-02,  2.4697e-02,  9.7657e-02],
        [ 3.1170e-02,  5.1035e-03, -2.3094e-02],
        [-1.0187e-02,  5.2062e-02,  2.5158e-02],
        [-3.7078e-02, -1.0134e-02,  3.8032e-03],
        [ 8.8694e-02, -3.0750e-03,  2.4333e-02],
        [ 9.2952e-02, -7.6131e-02,  6.7218e-02],
        [-3.4485e-02, -1.3299e-02, -6.3930e-02],
        [-4.7276e-02, -3.5576e-02,  7.2038e-02],
        [ 5.9960e-02,  2.2295e-02, -4.2306e-02],
        [-9.7502e-02,  5.1908e-02,  3.0460e-02],
        [ 7.1249e-02, -3.8853e-02,  7.7737e-02],
        [ 6.4047e-02,  1.7193e-02,  6.2109e-03],
        [ 1.6491e-02, -8.0544e-02,  2.6576e-02],
        [-6.6234e-02,  9.5245e-02,  8.1419e-02],
        [-8.5292e-02, -5.5366e-02, -1.9801e-02],
        [-2.5768e-02,  6.0976e-02, -9.0075e-03],
        [-6.0146e-02,  8.7594e-02,  3.5786e-02],
        [ 6.1822e-02,  9.8115e-02, -2.8956e-02],
        [ 5.5727e-02, -7.2281e-03, -3.3081e-02],
        [ 9.9295e-02, -2.2229e-02,  1.7998e-02],
        [ 5.3005e-02,  3.3666e-03, -4.3881e-02],
        [-5.8789e-02,  6.7245e-02,  9.3645e-02],
        [ 1.5318e-03,  4.7252e-05,  3.1342e-02],
        [-9.4592e-02,  1.6203e-03,  6.7871e-02],
        [-8.1662e-02,  4.9619e-02, -3.4235e-02],
        [-4.7592e-03,  8.8230e-02, -7.0480e-02],
        [-8.4769e-02, -4.5581e-02, -9.5608e-02],
        [ 5.2312e-02, -3.2134e-02, -9.9308e-02]], requires_grad=True)
bias forward Parameter containing:
tensor([ 0.0021, -0.0651,  0.0892, -0.0630,  0.0621,  0.0624,  0.0788, -0.0088,
         0.0874, -0.0747, -0.0461, -0.0674,  0.0427,  0.0846, -0.0922, -0.0300,
         0.0652,  0.0913,  0.0199, -0.0351,  0.0483,  0.0904,  0.0747, -0.0173,
        -0.0394,  0.0115,  0.0767,  0.0386,  0.0785,  0.0339, -0.0191,  0.0577,
        -0.0916,  0.0067,  0.0178,  0.0637, -0.0571, -0.0420,  0.0207, -0.0514,
        -0.0505, -0.0063, -0.0777,  0.0068,  0.0712,  0.0565,  0.0319, -0.0343,
        -0.0751,  0.0895,  0.0987,  0.0345, -0.0811, -0.0717, -0.0697, -0.0729,
        -0.0520,  0.0426, -0.0916, -0.0618, -0.0259,  0.0824, -0.0973,  0.0086],
       requires_grad=True)
input.requires_grad True
input forward tensor([[-0.1535, -0.0922,  0.0989,  ...,  0.0953, -0.3233, -0.0706],
        [-0.1535, -0.0922,  0.0989,  ...,  0.0953, -0.3233, -0.0706],
        [-0.1535, -0.0922,  0.0989,  ...,  0.0953, -0.3233, -0.0706],
        ...,
        [-0.1535, -0.0922,  0.0989,  ...,  0.0953, -0.3233, -0.0706],
        [-0.1535, -0.0922,  0.0989,  ...,  0.0953, -0.3233, -0.0706],
        [-0.1535, -0.0922,  0.0989,  ...,  0.0953, -0.3233, -0.0706]],
       grad_fn=<LinearFunctionBackward>)
weight forward Parameter containing:
tensor([[-0.0373, -0.0337, -0.0244,  ...,  0.0944, -0.0116,  0.0215],
        [ 0.0040, -0.0831,  0.0885,  ...,  0.0085, -0.0741,  0.0127],
        [ 0.0498, -0.0686,  0.0238,  ..., -0.0737, -0.0593, -0.0929],
        ...,
        [ 0.0007,  0.0510, -0.0142,  ...,  0.0279, -0.0415, -0.0512],
        [-0.0253, -0.0744, -0.0856,  ..., -0.0384, -0.0035,  0.0923],
        [-0.0838,  0.0584,  0.0247,  ..., -0.0630,  0.0615,  0.0827]],
       requires_grad=True)
bias forward Parameter containing:
tensor([-6.3925e-02,  9.5133e-02,  6.3453e-03, -9.6132e-02, -7.5031e-02,
         7.1918e-02,  6.5596e-02,  2.9120e-02, -3.0924e-02, -3.1459e-02,
         8.3751e-02, -5.3198e-02,  2.2847e-02,  5.0141e-02,  1.7165e-02,
         9.5857e-02,  4.8013e-02, -1.5712e-02, -9.6951e-02,  6.9367e-03,
         6.5624e-03,  2.5855e-02, -7.5122e-02, -3.0202e-02, -6.7543e-02,
        -4.3585e-02,  9.0823e-06, -5.5203e-02, -9.5422e-02,  2.1666e-02,
        -7.3739e-03, -5.1144e-02, -1.3289e-03, -8.6629e-02, -9.3888e-02,
        -5.1129e-02, -3.5457e-02, -1.9761e-02, -4.0732e-02, -7.1337e-02,
         8.7817e-03, -8.5516e-02, -9.5054e-02, -9.8342e-02,  5.7102e-02,
        -9.9832e-02, -5.9283e-02,  2.8818e-02,  1.3624e-02,  3.7801e-02,
        -3.4223e-02, -3.0157e-02,  1.8073e-02,  7.5750e-02,  7.9041e-03,
        -8.0859e-02, -9.5542e-02, -1.1279e-02,  1.0384e-02, -5.2687e-02,
         4.8081e-02, -2.6380e-02,  6.3792e-02, -3.4629e-02],
       requires_grad=True)

An interesting thing happens when printing inside backward:

grad_output tensor([[-0.0003, -0.0003, -0.0003,  ..., -0.0003, -0.0003, -0.0003],
        [-0.0003, -0.0003, -0.0003,  ..., -0.0003, -0.0003, -0.0003],
        [-0.0003, -0.0003, -0.0003,  ..., -0.0003, -0.0003, -0.0003],
        ...,
        [-0.0003, -0.0003, -0.0003,  ..., -0.0003, -0.0003, -0.0003],
        [-0.0003, -0.0003, -0.0003,  ..., -0.0003, -0.0003, -0.0003],
        [-0.0003, -0.0003, -0.0003,  ..., -0.0003, -0.0003, -0.0003]])
grad_output.shape torch.Size([100, 64])
type(ctx) <class 'torch.autograd.function.LinearFunctionBackward'>
ctx.saved_tensors (tensor([[-0.0130,  0.1299, -0.1079,  ..., -0.0327,  0.0201,  0.1370],
        [-0.0130,  0.1299, -0.1079,  ..., -0.0327,  0.0201,  0.1370],
        [-0.0130,  0.1299, -0.1079,  ..., -0.0327,  0.0201,  0.1370],
        ...,
        [-0.0130,  0.1299, -0.1079,  ..., -0.0327,  0.0201,  0.1370],
        [-0.0130,  0.1299, -0.1079,  ..., -0.0327,  0.0201,  0.1370],
        [-0.0130,  0.1299, -0.1079,  ..., -0.0327,  0.0201,  0.1370]],
       grad_fn=<LinearFunctionBackward>), tensor([[-0.0343,  0.0935,  0.0341,  ...,  0.0397, -0.0944,  0.0636],
        [ 0.0004,  0.0325,  0.0200,  ..., -0.0412, -0.0044,  0.0804],
        [-0.0924,  0.0477,  0.0244,  ...,  0.0977,  0.0955,  0.0971],
        ...,
        [-0.0924, -0.0543, -0.0400,  ..., -0.0085, -0.0072,  0.0443],
        [-0.0646,  0.0025, -0.0068,  ...,  0.0974, -0.0356,  0.0807],
        [ 0.0744, -0.0510, -0.0750,  ...,  0.0472,  0.0138,  0.0920]],
       requires_grad=True), tensor([ 0.0030, -0.0111, -0.0229, -0.0249, -0.0989,  0.0892,  0.0206,  0.0536,
         0.0978,  0.0473, -0.0724, -0.0930,  0.0703,  0.0183, -0.0407,  0.0490,
         0.0304, -0.0751, -0.0962,  0.0294, -0.0906, -0.0313,  0.0924, -0.0462,
        -0.0834, -0.0942, -0.0217,  0.0581,  0.0379, -0.0543,  0.0411, -0.0707,
        -0.0958,  0.0415, -0.0499, -0.0942, -0.0491,  0.0283,  0.0752,  0.0812,
         0.0318,  0.0540,  0.0340,  0.0257,  0.0744, -0.0112,  0.0081, -0.0952,
         0.0921, -0.0628,  0.0688, -0.0160,  0.0923,  0.0252,  0.0945, -0.0840,
        -0.0888, -0.0278, -0.0439,  0.0531,  0.0374, -0.0352,  0.0247,  0.0621],
       requires_grad=True))
input.shape torch.Size([100, 64])
input.requires_grad True
weight.shape torch.Size([64, 64])
bias.shape torch.Size([64])
ctx.needs_input_grad (True, True, True)
grad_input tensor([[-1.7603e-04, -6.1180e-06,  1.2589e-04,  ..., -4.7342e-05,
         -1.6849e-04, -1.2383e-04],
        [-1.7603e-04, -6.1180e-06,  1.2589e-04,  ..., -4.7342e-05,
         -1.6849e-04, -1.2383e-04],
        [-1.7603e-04, -6.1180e-06,  1.2589e-04,  ..., -4.7342e-05,
         -1.6849e-04, -1.2383e-04],
        ...,
        [-1.7603e-04, -6.1180e-06,  1.2589e-04,  ..., -4.7342e-05,
         -1.6849e-04, -1.2383e-04],
        [-1.7603e-04, -6.1180e-06,  1.2589e-04,  ..., -4.7342e-05,
         -1.6849e-04, -1.2383e-04],
        [-1.7603e-04, -6.1180e-06,  1.2589e-04,  ..., -4.7342e-05,
         -1.6849e-04, -1.2383e-04]])
grad_weight tensor([[ 0.0004, -0.0039,  0.0032,  ...,  0.0010, -0.0006, -0.0041],
        [ 0.0004, -0.0040,  0.0033,  ...,  0.0010, -0.0006, -0.0042],
        [ 0.0004, -0.0041,  0.0034,  ...,  0.0010, -0.0006, -0.0043],
        ...,
        [ 0.0004, -0.0043,  0.0036,  ...,  0.0011, -0.0007, -0.0046],
        [ 0.0004, -0.0040,  0.0033,  ...,  0.0010, -0.0006, -0.0042],
        [ 0.0004, -0.0036,  0.0030,  ...,  0.0009, -0.0006, -0.0038]])
grad_bias tensor([-0.0298, -0.0306, -0.0315, -0.0316, -0.0330, -0.0283, -0.0304, -0.0291,
        -0.0278, -0.0285, -0.0349, -0.0339, -0.0281, -0.0317, -0.0338, -0.0310,
        -0.0316, -0.0342, -0.0292, -0.0290, -0.0329, -0.0333, -0.0320, -0.0319,
        -0.0328, -0.0366, -0.0312, -0.0322, -0.0283, -0.0373, -0.0311, -0.0345,
        -0.0319, -0.0297, -0.0338, -0.0322, -0.0320, -0.0278, -0.0304, -0.0298,
        -0.0314, -0.0299, -0.0295, -0.0319, -0.0276, -0.0316, -0.0314, -0.0332,
        -0.0275, -0.0317, -0.0302, -0.0317, -0.0313, -0.0312, -0.0276, -0.0332,
        -0.0362, -0.0335, -0.0309, -0.0311, -0.0269, -0.0334, -0.0308, -0.0278])
grad_output tensor([[-1.7603e-04, -6.1180e-06,  1.2589e-04,  ..., -4.7342e-05,
         -1.6849e-04, -1.2383e-04],
        [-1.7603e-04, -6.1180e-06,  1.2589e-04,  ..., -4.7342e-05,
         -1.6849e-04, -1.2383e-04],
        [-1.7603e-04, -6.1180e-06,  1.2589e-04,  ..., -4.7342e-05,
         -1.6849e-04, -1.2383e-04],
        ...,
        [-1.7603e-04, -6.1180e-06,  1.2589e-04,  ..., -4.7342e-05,
         -1.6849e-04, -1.2383e-04],
        [-1.7603e-04, -6.1180e-06,  1.2589e-04,  ..., -4.7342e-05,
         -1.6849e-04, -1.2383e-04],
        [-1.7603e-04, -6.1180e-06,  1.2589e-04,  ..., -4.7342e-05,
         -1.6849e-04, -1.2383e-04]])
grad_output.shape torch.Size([100, 64])
type(ctx) <class 'torch.autograd.function.LinearFunctionBackward'>
ctx.saved_tensors (tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        ...,
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]]), tensor([[ 0.0842, -0.0503, -0.0517],
        [ 0.0767,  0.0256,  0.0085],
        [-0.0636,  0.0073, -0.0771],
        [ 0.0128,  0.0598,  0.0479],
        [-0.0554, -0.0358, -0.0874],
        [-0.0494,  0.0081, -0.0187],
        [-0.0284, -0.0936, -0.0620],
        [-0.0569, -0.0838, -0.0695],
        [-0.0769,  0.0653,  0.0156],
        [ 0.0631,  0.0331, -0.0328],
        [ 0.0566,  0.0832,  0.0893],
        [ 0.0138,  0.0158, -0.0760],
        [-0.0720, -0.0756,  0.0478],
        [ 0.0579, -0.0173,  0.0116],
        [ 0.0295,  0.0456,  0.0078],
        [-0.0311,  0.0026, -0.0419],
        [ 0.0587,  0.0382,  0.0799],
        [ 0.0206,  0.0552, -0.0481],
        [ 0.0311,  0.0883,  0.0899],
        [ 0.0028,  0.0687, -0.0498],
        [-0.0646, -0.0223,  0.0790],
        [ 0.0773,  0.0483,  0.0244],
        [-0.0832, -0.0044,  0.0500],
        [ 0.0530, -0.0554, -0.0592],
        [ 0.0825, -0.0022, -0.0089],
        [-0.0434,  0.0201, -0.0196],
        [-0.0718, -0.0063, -0.0274],
        [-0.0669, -0.0193,  0.0297],
        [-0.0522,  0.0782, -0.0273],
        [ 0.0184, -0.0488,  0.0038],
        [-0.0179, -0.0462,  0.0904],
        [-0.0429,  0.0611, -0.0567],
        [ 0.0525, -0.0011,  0.0973],
        [-0.0022, -0.0830, -0.0755],
        [ 0.0401,  0.0580,  0.0095],
        [ 0.0870,  0.0267,  0.0700],
        [ 0.0374,  0.0876,  0.0032],
        [ 0.0487, -0.0419,  0.0846],
        [-0.0027, -0.0206,  0.0068],
        [ 0.0997, -0.0110,  0.0023],
        [ 0.0763,  0.0883,  0.0147],
        [ 0.0381, -0.0788,  0.0771],
        [-0.0689,  0.0763, -0.0097],
        [ 0.0846, -0.0871,  0.0942],
        [ 0.0914, -0.0359, -0.0436],
        [ 0.0562, -0.0068, -0.0124],
        [ 0.0080, -0.0139, -0.0580],
        [-0.0250, -0.0011, -0.0274],
        [-0.0194,  0.0851, -0.0085],
        [-0.0947,  0.0152,  0.0870],
        [ 0.0986, -0.0572,  0.0172],
        [ 0.0699, -0.0794,  0.0199],
        [-0.0658, -0.0908, -0.0678],
        [ 0.0060,  0.0376, -0.0140],
        [-0.0589,  0.0665, -0.0053],
        [ 0.0355, -0.0619, -0.0941],
        [ 0.0111,  0.0337,  0.0725],
        [-0.0520, -0.0429, -0.0760],
        [-0.0894,  0.0772, -0.0128],
        [-0.0798, -0.0291,  0.0563],
        [-0.0549,  0.0833,  0.0315],
        [ 0.0215,  0.0446, -0.0987],
        [-0.0185, -0.0308,  0.0578],
        [-0.0027, -0.0376,  0.0975]], requires_grad=True), tensor([ 0.0048,  0.0191,  0.0255, -0.0865, -0.0007,  0.0818, -0.0146, -0.0640,
        -0.0144, -0.0680,  0.0739, -0.0047,  0.0962, -0.0705, -0.0117, -0.0876,
         0.0772, -0.0727,  0.0529, -0.0311,  0.0123,  0.0245,  0.0298,  0.0037,
         0.0168, -0.0061,  0.0558,  0.0338,  0.0988,  0.0884, -0.0939,  0.0782,
         0.0447,  0.0091,  0.0015, -0.0365, -0.0928,  0.0182, -0.0843,  0.0652,
         0.0750,  0.0602,  0.0864, -0.0776,  0.0879, -0.0737,  0.0300,  0.0027,
         0.0180, -0.0602, -0.0158, -0.0497, -0.0657, -0.0122, -0.0545, -0.0912,
        -0.0768,  0.0527,  0.0294,  0.0801, -0.0027, -0.0001,  0.0116,  0.0798],
       requires_grad=True))
input.shape torch.Size([100, 3])
input.requires_grad False
weight.shape torch.Size([64, 3])
bias.shape torch.Size([64])
ctx.needs_input_grad (False, True, True)
grad_weight tensor([[-0.0176, -0.0176, -0.0176],
        [-0.0006, -0.0006, -0.0006],
        [ 0.0126,  0.0126,  0.0126],
        [-0.0095, -0.0095, -0.0095],
        [-0.0203, -0.0203, -0.0203],
        [ 0.0125,  0.0125,  0.0125],
        [-0.0236, -0.0236, -0.0236],
        [-0.0098, -0.0098, -0.0098],
        [-0.0102, -0.0102, -0.0102],
        [ 0.0024,  0.0024,  0.0024],
        [-0.0149, -0.0149, -0.0149],
        [ 0.0075,  0.0075,  0.0075],
        [ 0.0161,  0.0161,  0.0161],
        [ 0.0039,  0.0039,  0.0039],
        [ 0.0031,  0.0031,  0.0031],
        [ 0.0025,  0.0025,  0.0025],
        [ 0.0088,  0.0088,  0.0088],
        [ 0.0129,  0.0129,  0.0129],
        [-0.0175, -0.0175, -0.0175],
        [ 0.0151,  0.0151,  0.0151],
        [-0.0161, -0.0161, -0.0161],
        [-0.0131, -0.0131, -0.0131],
        [-0.0135, -0.0135, -0.0135],
        [ 0.0094,  0.0094,  0.0094],
        [ 0.0132,  0.0132,  0.0132],
        [ 0.0022,  0.0022,  0.0022],
        [-0.0160, -0.0160, -0.0160],
        [-0.0109, -0.0109, -0.0109],
        [ 0.0171,  0.0171,  0.0171],
        [-0.0049, -0.0049, -0.0049],
        [ 0.0328,  0.0328,  0.0328],
        [ 0.0063,  0.0063,  0.0063],
        [ 0.0046,  0.0046,  0.0046],
        [-0.0100, -0.0100, -0.0100],
        [-0.0134, -0.0134, -0.0134],
        [ 0.0206,  0.0206,  0.0206],
        [ 0.0056,  0.0056,  0.0056],
        [-0.0013, -0.0013, -0.0013],
        [ 0.0190,  0.0190,  0.0190],
        [ 0.0124,  0.0124,  0.0124],
        [-0.0058, -0.0058, -0.0058],
        [-0.0196, -0.0196, -0.0196],
        [ 0.0174,  0.0174,  0.0174],
        [ 0.0144,  0.0144,  0.0144],
        [-0.0022, -0.0022, -0.0022],
        [-0.0100, -0.0100, -0.0100],
        [-0.0004, -0.0004, -0.0004],
        [-0.0023, -0.0023, -0.0023],
        [-0.0202, -0.0202, -0.0202],
        [-0.0018, -0.0018, -0.0018],
        [-0.0095, -0.0095, -0.0095],
        [ 0.0012,  0.0012,  0.0012],
        [-0.0105, -0.0105, -0.0105],
        [ 0.0061,  0.0061,  0.0061],
        [ 0.0087,  0.0087,  0.0087],
        [-0.0132, -0.0132, -0.0132],
        [-0.0198, -0.0198, -0.0198],
        [ 0.0028,  0.0028,  0.0028],
        [ 0.0237,  0.0237,  0.0237],
        [ 0.0028,  0.0028,  0.0028],
        [ 0.0136,  0.0136,  0.0136],
        [-0.0047, -0.0047, -0.0047],
        [-0.0168, -0.0168, -0.0168],
        [-0.0124, -0.0124, -0.0124]])
grad_bias tensor([-0.0176, -0.0006,  0.0126, -0.0095, -0.0203,  0.0125, -0.0236, -0.0098,
        -0.0102,  0.0024, -0.0149,  0.0075,  0.0161,  0.0039,  0.0031,  0.0025,
         0.0088,  0.0129, -0.0175,  0.0151, -0.0161, -0.0131, -0.0135,  0.0094,
         0.0132,  0.0022, -0.0160, -0.0109,  0.0171, -0.0049,  0.0328,  0.0063,
         0.0046, -0.0100, -0.0134,  0.0206,  0.0056, -0.0013,  0.0190,  0.0124,
        -0.0058, -0.0196,  0.0174,  0.0144, -0.0022, -0.0100, -0.0004, -0.0023,
        -0.0202, -0.0018, -0.0095,  0.0012, -0.0105,  0.0061,  0.0087, -0.0132,
        -0.0198,  0.0028,  0.0237,  0.0028,  0.0136, -0.0047, -0.0168, -0.0124])

In short:

input.requires_grad True
input.requires_grad False

Printing here:

    def backward(ctx, grad_output):
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        print('grad_output', grad_output)
        print('grad_output.shape', grad_output.shape)
        print('type(ctx)', type(ctx))
        input, weight, bias = ctx.saved_tensors
        print('ctx.saved_tensors', ctx.saved_tensors)
        print('input.shape', input.shape)
        print('input.requires_grad', input.requires_grad)
        print('weight.shape', weight.shape)
        print('bias.shape', bias.shape)
        grad_input = grad_weight = grad_bias = None

So, while inside PyTorch, x.requires_grad changes depending on some circumstances.
How does this algorithm work? I'd like to know so that I can safely make custom autograd.Functions.


So the forward method runs inside a torch.no_grad() block. So even though the inputs may be marked as requiring gradients, all the elements computed inside the forward won't require gradients.
The backward will run as any other Python function (with grad enabled). So there the elements will require gradients depending on how they were computed.
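For example (a tiny illustration of that rule, independent of autograd.Function):

w = torch.rand(3, requires_grad=True)

with torch.no_grad():
    y = w * 2
print(y.requires_grad)   # False: computed while grad mode is off, just like inside forward()

z = w * 2
print(z.requires_grad)   # True: with grad mode on, the result is attached to the graph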

Could you reformulate "How does this algorithm work? I'd like to know so that I can safely make custom autograd.Functions"? I'm not sure what this means.


Thank you for the answer @albanD.

It is really helpful.

My question is:

So, while inside PyTorch, x.requires_grad changes depending on some circumstances.

What are those circumstances?

Inside a Function, we don't change the .requires_grad field of existing Tensors. But the forward runs in no-grad mode. This means that the output of an operation on a Tensor that does require gradients won't itself require gradients.
The backward runs with grad mode enabled, so it should behave like any other Python function.
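To illustrate with a toy Function (my own example, not from this thread):

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        print('grad enabled inside forward:', torch.is_grad_enabled())  # False
        y = x * 3
        print('y.requires_grad inside forward:', y.requires_grad)       # False: computed in no-grad mode
        return y

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 3

x = torch.rand(4, requires_grad=True)
out = Scale.apply(x)
print('out.requires_grad:', out.requires_grad)  # True: autograd attaches the graph node after forward returns
print('out.grad_fn:', out.grad_fn)              # the ScaleBackward node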

Do you see changes that you didn't expect in your case?

Yes, yes @albanD!
Printing input.requires_grad in backward, we can see that it changes state:

input.requires_grad True - second layer.
input.requires_grad False - first layer.

I am not changing it; PyTorch is changing it.

Oh sorry, the input that you get is the input you pass.
In your case, in the first layer, the input is (in my example above) inp, which does not require grad.
In the second layer, the input is the output of the first layer, and this Tensor does require gradients.

You called it with two Tensors with different requires_grad fields, hence the difference.
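A minimal illustration of that last point (a toy Function of my own; the same backward sees different needs_input_grad depending on what you pass in):

class Mul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(x, y)
        return x * y

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        print('needs_input_grad:', ctx.needs_input_grad)
        grad_x = grad_output * y if ctx.needs_input_grad[0] else None
        grad_y = grad_output * x if ctx.needs_input_grad[1] else None
        return grad_x, grad_y

w = torch.rand(3, requires_grad=True)
data = torch.rand(3)                    # plain data, requires_grad is False
Mul.apply(data, w).sum().backward()     # prints (False, True)
Mul.apply(w * 2, w).sum().backward()    # prints (True, True): the first input now requires grad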


I just put my fears to the test in code, and it seems to work anyway. What I did:

  1. Introduced a dummy variable ("dumb") that does nothing at all as one of the inputs: ctx.save_for_backward(dumb, weight, input, bias), and changed the return in backward accordingly.
  2. Changed the positions of the variables.
  3. Deleted the calculation of the gradient for the input.
    And everything still works perfectly:
ctx.needs_input_grad (False, True, True, True) - look at the 3rd position: it is the input.
ctx.needs_input_grad (False, True, False, True) - look at the 3rd position: it is the input.

I mean, the logic in backward is now completely broken, but I can conclude that PyTorch knows that x is the input data and handles everything correctly. It is not tied to the position of the input or to whether backward is implemented correctly (although it does check the shape of the returned gradient for correctness). A reconstruction of this experiment is sketched below.
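A sketch of that experiment (my reconstruction of the description above, keeping the needs_input_grad guards):

class DumbLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, dumb, weight, input, bias):
        # 'dumb' is the do-nothing placeholder; note the shuffled argument order
        ctx.save_for_backward(dumb, weight, input, bias)
        return input.mm(weight.t()) + bias

    @staticmethod
    def backward(ctx, grad_output):
        dumb, weight, input, bias = ctx.saved_tensors
        print('ctx.needs_input_grad', ctx.needs_input_grad)
        # One entry per forward() argument, in argument order -- the 3rd slot is 'input'
        grad_weight = grad_output.t().mm(input) if ctx.needs_input_grad[1] else None
        grad_input = grad_output.mm(weight) if ctx.needs_input_grad[2] else None
        grad_bias = grad_output.sum(0) if ctx.needs_input_grad[3] else None
        return None, grad_weight, grad_input, grad_bias

dumb = torch.zeros(1)                       # never used in the math
w1 = torch.rand(64, 3, requires_grad=True)
b1 = torch.rand(64, requires_grad=True)
w2 = torch.rand(64, 64, requires_grad=True)
b2 = torch.rand(64, requires_grad=True)
x = torch.ones(100, 3)

h = DumbLinear.apply(dumb, w1, x, b1)       # its backward prints (False, True, False, True)
out = DumbLinear.apply(dumb, w2, h, b2)     # its backward prints (False, True, True, True)
out.mean().backward()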

I can sleep well now. Dear @albanD, thank you very much for your help! Have a very productive day!

Andrei


I have done some small research on the issue:

  1. ctx.needs_input_grad is True for an argument whenever that tensor's requires_grad == True.
  2. x in wx + b gets requires_grad set to True by PyTorch during the "forward pass", which is confirmed by the print statements. It might happen in __call__ or elsewhere. During the "forward pass" it also sets x.grad_fn.
    I will confirm this theory by looking at the PyTorch source code. Thank you very much to all of you!

Andrei

Hi,

  1. Yes.
  2. What might enlighten you is the function ctx.mark_non_differentiable(), which you can use during the forward pass to mark the outputs that are not differentiable (the outputs that won't require gradients). All the other outputs of your autograd.Function will have requires_grad=True. A small example is sketched below.
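A small sketch of how it is used (a toy sort Function, similar in spirit to the example in the extending-autograd doc):

class Sort(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        sorted_x, idx = x.sort()
        ctx.mark_non_differentiable(idx)   # idx is integer-valued: no gradient ever flows through it
        ctx.save_for_backward(idx)
        return sorted_x, idx

    @staticmethod
    def backward(ctx, grad_sorted, grad_idx):   # a grad slot still exists for idx, but it can be ignored
        idx, = ctx.saved_tensors
        # route the gradient of the sorted values back to their original positions
        return torch.zeros_like(grad_sorted).scatter(0, idx, grad_sorted)

x = torch.rand(5, requires_grad=True)
values, indices = Sort.apply(x)
print(values.requires_grad, indices.requires_grad)   # True False
values.sum().backward()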