I am trying to understand the impact of `.detach()` on gradients and have set up an example as follows:

```
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.features = nn.Sequential(
            nn.Linear(784, 100),
            nn.ReLU(inplace=True),
            nn.Linear(100, 100),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # The original forward function
        x = self.features(x)
        # The modified forward function including the .detach() method
        # x = self.features(x) + (1 - self.features(x)).detach()
        return x
```

```
torch.manual_seed(0)
x = torch.randn(784).view(1, 784).requires_grad_(True)
net = Net()
print(x.grad)
u = net(x)
loss = u.abs().sum()
print(loss)
loss.backward()
print(x.grad)
```

Regardless of whether I run the original `forward` function or the modified one (commented out above), I get the same values for `x.grad`. This is despite the fact that the loss is about 10 in one case and 100 in the other.

I understand that, the way my `forward` method is defined, backward sees `self.features(x)` as the forward computation in both cases. The part I'm confused about is **why two different loss values combined with the same backward function result in the same gradient values for x**.

Is this the expected behavior, or should a different loss value result in a different `x.grad` value?