Understanding gradients when .detach() is used

I am trying to understand the impact of .detach() on gradients and have set up an example as follows:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        
        self.features = nn.Sequential(
            nn.Linear(784, 100),
            nn.ReLU(inplace=True),
            nn.Linear(100, 100),
            nn.ReLU(inplace=True),
        )
        
    def forward(self, x):
        # The original forward function
        x = self.features(x)

        # The modified forward function including the .detach() method
        #x = self.features(x) + (1 - self.features(x)).detach()
        
        return x

torch.manual_seed(0)
x = torch.randn(784).view(1, 784).requires_grad_(True)

net = Net()
print(x.grad)

u = net(x)
loss = u.abs().sum()
print(loss)
loss.backward()

print(x.grad)

Regardless of whether I run the original forward function or the modified one (which is commented out above), I get the same values for x.grad. This is despite the fact that the loss is about 10 in one case and 100 in the other.

I understand that, the way my forward method is defined, backward effectively sees self.features(x) as the forward computation in both cases. The part I’m confused about is why two different loss values combined with the same backward graph result in the same gradient values for x.

Is this the expected behavior, or should a different loss value result in a different x.grad value?
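
For reference, here is a stripped-down, standalone version of the comparison I am describing. The small Linear + ReLU stand-in for self.features and the sizes are hypothetical, chosen only so it runs quickly:

import torch
import torch.nn as nn

torch.manual_seed(0)
# small stand-in for self.features, keeping the Linear + ReLU structure
f = nn.Sequential(nn.Linear(32, 32), nn.ReLU())
x = torch.randn(1, 32, requires_grad=True)

# Variant 1: the original forward
loss1 = f(x).abs().sum()
loss1.backward()
g1 = x.grad.clone()

# Variant 2: the forward with the detached term added
x.grad = None
out2 = f(x) + (1 - f(x)).detach()
loss2 = out2.abs().sum()
loss2.backward()
g2 = x.grad.clone()

print(loss1.item(), loss2.item())   # different loss values
print(torch.allclose(g1, g2))       # True: identical gradients w.r.t. x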

Hi,

As soon as you use detach, you don’t get “real” gradients and the whole thing becomes quite hard to reason about.

The main reason you get the same value is that, in the lower part of your function, the gradient does not depend on the value. Try changing your loss to a squared norm, and you will see a different behavior.

You can check this work-in-progress revamp of the autograd doc, which tries to define detach in a more mathematical way: here.
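
To illustrate, here is a minimal sketch, again with a small hypothetical Linear + ReLU stand-in for the feature extractor, comparing an abs-based loss with a squared loss for the two forward variants:

import torch
import torch.nn as nn

torch.manual_seed(0)
f = nn.Sequential(nn.Linear(32, 32), nn.ReLU())   # hypothetical stand-in for self.features
x = torch.randn(1, 32, requires_grad=True)

def input_grad(use_detach, loss_fn):
    x.grad = None
    out = f(x)
    if use_detach:
        out = out + (1 - f(x)).detach()
    loss_fn(out).backward()
    return x.grad.clone()

abs_loss = lambda u: u.abs().sum()   # gradient w.r.t. u is sign(u): it ignores the magnitude of u
sq_loss = lambda u: (u ** 2).sum()   # gradient w.r.t. u is 2*u: it depends on the value of u

# abs-based loss: same x.grad for both variants
print(torch.allclose(input_grad(False, abs_loss), input_grad(True, abs_loss)))
# squared loss: different x.grad for the two variants
print(torch.allclose(input_grad(False, sq_loss), input_grad(True, sq_loss)))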


Thanks for your answer. I modified the loss function to torch.norm() and observed different gradients at the input.

However, I have difficulty understanding the following part of your answer:
“The main reason you get the same value is that, in the lower part of your function, the gradient does not depend on the value.”
Do you mind sharing pointers to resources that explain this part in more detail?

Thank you!

The idea is that the gradient of 2*x is 2. So whatever the value of x is during the forward pass, it does not change the gradient.
On the other hand, for x**2 the gradient is 2*x, so depending on the value of x used in the forward pass, the gradient will change.
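
A quick way to see this with autograd on a scalar (the values 1.0 and 3.0 are arbitrary, just for illustration):

import torch

for val in (1.0, 3.0):
    x = torch.tensor(val, requires_grad=True)

    (2 * x).backward()
    print(x.grad)    # tensor(2.) for both values: the gradient ignores the value of x

    x.grad = None
    (x ** 2).backward()
    print(x.grad)    # tensor(2.) then tensor(6.): the gradient is 2*x, so it depends on x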
