Hi,
I am trying to use the gradient of my neural network's output w.r.t. its inputs as part of my loss calculation. However, I noticed that this causes the bias of my final layer to receive no gradient (its .grad stays None).
I created the following example code to demonstrate this behavior:
import torch
import torch.nn as nn


class test_nn(nn.Module):
    def __init__(self):
        super(test_nn, self).__init__()
        hidden_layer_size = 6
        n_hidden_layers = 4
        self.in_fcn = nn.Linear(16, hidden_layer_size)
        self.hidden_layers = nn.ModuleList(
            modules=[
                nn.Linear(hidden_layer_size, hidden_layer_size)
                for _ in range(n_hidden_layers)
            ]
        )
        self.out_fcn = nn.Linear(hidden_layer_size, 1)
        self.activation = nn.Softplus()

    def forward(self, x):
        x = self.activation(self.in_fcn(x))
        for layer in self.hidden_layers:
            x = self.activation(layer(x))
        x = self.out_fcn(x).squeeze(0)
        return x
if __name__ == "__main__":
    net = test_nn()
    x = torch.rand(32, 16, requires_grad=True)

    # Case 1: loss computed on the network output itself
    out = net(x)
    loss = nn.MSELoss()(out, torch.rand(32, 1))
    loss.backward()
    print("Grads of normal loss")
    print("out bias grad: ", net.out_fcn.bias.grad)
    print("in bias grad: ", net.in_fcn.bias.grad)

    net.zero_grad()

    # Case 2: loss computed on the gradient of the output w.r.t. the input
    out = net(x)
    grad = torch.autograd.grad(
        out,
        x,
        grad_outputs=torch.ones_like(out),
        create_graph=True,
        retain_graph=True,
    )[0]
    print(grad.shape)
    loss = nn.MSELoss()(grad, torch.rand(32, 16))
    loss.backward()
    print("---------------------")
    print("Grads of grad loss")
    print("out bias grad: ", net.out_fcn.bias.grad)
    print("in bias grad: ", net.in_fcn.bias.grad)
Running this produces the following output (the actual gradient values vary between runs, of course):
Grads of normal loss
out bias grad: tensor([0.4833])
in bias grad: tensor([-7.7400e-05, 3.6427e-04, 2.8540e-04, -9.1523e-05, -8.3902e-04,
3.4275e-06])
torch.Size([32, 16])
---------------------
Grads of grad loss
out bias grad: None
in bias grad: tensor([-7.0170e-07, 6.5941e-06, -1.0075e-05, -5.6637e-06, 2.2724e-05,
2.1236e-06])
Does anybody know if this is expected behavior that I am overlooking? Or is there something else going wrong?
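To isolate the effect, I also put together an even smaller sketch with a single nn.Linear (this is a separate minimal repro, not part of the script above). As far as I can tell, it exhibits the same thing: d(out)/d(x) for a linear layer reduces to the weight matrix, so the bias never enters the graph that grad is built from.

import torch
import torch.nn as nn

# Single linear layer: out = x @ W.T + b
layer = nn.Linear(4, 1)
x = torch.rand(8, 4, requires_grad=True)
out = layer(x)

# d(out)/d(x) reduces to W (broadcast over the batch); b does not appear in it
grad = torch.autograd.grad(
    out, x, grad_outputs=torch.ones_like(out), create_graph=True
)[0]

# Backprop through a loss built from grad
grad.sum().backward()
print("weight grad:", layer.weight.grad)  # non-None
print("bias grad:  ", layer.bias.grad)    # None, just like out_fcn.bias above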