Hi, I’m trying to compute a Hessian-vector product of a network with respect to its parameters (for TRPO), but the code below doesn’t work as expected. Am I missing something important? Does anybody know how to fix it? Thanks in advance.
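For reference, this is the double-backward technique the snippets below try to use, as a minimal sketch. Differentiate once with `create_graph=True`, take the dot product of the flattened gradient with a vector `v`, and differentiate that scalar again. The helper name `hessian_vector_product` is hypothetical (not a PyTorch API). The quadratic test function is an assumption chosen only because its Hessian is known: for a symmetric `A`, the Hessian of `0.5 * w @ A @ w` is `A`.

```python
import torch
from torch import autograd


def hessian_vector_product(loss, params, v):
    """Hypothetical helper: return H @ v, H being the Hessian of `loss` w.r.t. `params`.

    `v` must be a flat tensor with as many elements as all `params` combined.
    """
    params = list(params)
    # First backward pass; create_graph=True records the graph of the gradient itself.
    grads = autograd.grad(loss, params, create_graph=True)
    flat_grad = torch.cat([g.reshape(-1) for g in grads])
    grad_dot_v = flat_grad @ v
    if not grad_dot_v.requires_grad:
        # The gradient does not depend on the parameters, so the Hessian is zero.
        return torch.zeros_like(flat_grad)
    # Second backward pass: d(grad . v) / d(params) = H @ v.
    hvps = autograd.grad(grad_dot_v, params, allow_unused=True)
    # Parameters with no second-order path get zeros instead of None.
    return torch.cat([
        (h if h is not None else torch.zeros_like(p)).reshape(-1)
        for h, p in zip(hvps, params)
    ])


torch.manual_seed(0)
A = torch.randn(4, 4)
A = A + A.t()  # make A symmetric so the Hessian of the quadratic equals A
w = torch.randn(4, requires_grad=True)
loss = 0.5 * w @ A @ w
v = torch.randn(4)

hv = hessian_vector_product(loss, [w], v)
print(torch.allclose(hv, A @ v, atol=1e-5))  # True
```

Filling `None` entries with zeros keeps the result the same length as the flattened gradient, which is the shape a conjugate-gradient step would work with.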
Environment
- Mac and Ubuntu
- Python 3.6
- PyTorch 0.4.0
Conv case
```python
import torch
from torch import nn, autograd

conv = nn.Conv2d(3, 64, 3)
input = torch.randn(1, 3, 32, 32)
out = conv(input).sum()
grads = autograd.grad([out], conv.parameters(), create_graph=True)
flatten = torch.cat([g.reshape(-1) for g in grads if g is not None])
x = torch.randn_like(flatten)
print(flatten.shape)  # torch.Size([1792])
hvps = autograd.grad([flatten @ x], conv.parameters(), allow_unused=True)
print(hvps[1])  # None
flatten2 = torch.cat([g.reshape(-1) for g in hvps if g is not None])
print(flatten2.shape)  # torch.Size([1728])
```
In this case, the second-order gradient with respect to `conv.bias` is `None`, so `flatten2` has only 1728 elements (the weight) instead of 1792.
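A quick way to check one possible explanation (a sketch, not tested on 0.4.0): `out` is a plain sum of the convolution output, so it is linear in both the weight and the bias. The first gradient with respect to each bias entry is then just the number of output positions (30 × 30 = 900 for a 32×32 input and a 3×3 kernel without padding), a constant that depends on no parameter, so there is nothing to differentiate a second time. The `if` guard and the zero-filling below are my own additions; they keep the Hessian-vector product at the full 1792 elements, which should all be zero because the Hessian of a linear function is zero.

```python
import torch
from torch import nn, autograd

torch.manual_seed(0)
conv = nn.Conv2d(3, 64, 3)
inp = torch.randn(1, 3, 32, 32)
out = conv(inp).sum()

params = list(conv.parameters())
grads = autograd.grad(out, params, create_graph=True)

# d out / d bias[c] counts the 30 * 30 output positions of channel c.
print(grads[1].unique())

flat_grad = torch.cat([g.reshape(-1) for g in grads])
v = torch.randn_like(flat_grad)
grad_dot_v = flat_grad @ v

if grad_dot_v.requires_grad:
    hvps = autograd.grad(grad_dot_v, params, allow_unused=True)
else:
    # No part of the gradient depends on the parameters.
    hvps = [None] * len(params)

# Replace None with zeros so there is one entry per parameter element.
flat_hvp = torch.cat([
    (h if h is not None else torch.zeros_like(p)).reshape(-1)
    for h, p in zip(hvps, params)
])
print(flat_hvp.shape)  # torch.Size([1792])
```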
Linear case
```python
import torch
from torch import nn, autograd

linear = nn.Linear(10, 20)
input = torch.randn(1, 10)
out = linear(input).sum()
grads = autograd.grad([out], linear.parameters(), create_graph=True)
flatten = torch.cat([g.reshape(-1) for g in grads if g is not None])
x = torch.randn_like(flatten)
print(flatten.shape)  # torch.Size([220])
hvps = autograd.grad([flatten @ x], linear.parameters(), allow_unused=True)
```
Here, the last line fails with the following error:
```
Traceback (most recent call last):
  File "fvp.py", line 24, in <module>
    hvps = autograd.grad([flatten @ x], linear.parameters(), allow_unused=True)
  File "/Users/foo/.miniconda/lib/python3.6/site-packages/torch/autograd/__init__.py", line 144, in grad
    inputs, allow_unused)
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
```
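A check along the same lines (again a sketch, not verified on 0.4.0): `linear(input).sum()` is linear in the weight and bias, so its gradient is built only from `input` and the constant upstream gradient, neither of which requires grad. `flatten` would then carry no graph at all, which matches the error message. Passing the output through a nonlinearity such as `torch.tanh` (an arbitrary choice for illustration, not the TRPO objective) makes the gradient depend on the parameters, and the second `autograd.grad` call goes through.

```python
import torch
from torch import nn, autograd

torch.manual_seed(0)
linear = nn.Linear(10, 20)
inp = torch.randn(1, 10)
params = list(linear.parameters())

# Linear objective: the gradient is built from `inp` and constants only.
grads = autograd.grad(linear(inp).sum(), params, create_graph=True)
print([g.requires_grad for g in grads])  # expected: [False, False]

# Nonlinear objective: the gradient of tanh depends on the weight and bias.
out = torch.tanh(linear(inp)).sum()
grads = autograd.grad(out, params, create_graph=True)
flat_grad = torch.cat([g.reshape(-1) for g in grads])
v = torch.randn_like(flat_grad)
hvps = autograd.grad(flat_grad @ v, params, allow_unused=True)

# Replace None with zeros so there is one entry per parameter element.
flat_hvp = torch.cat([
    (h if h is not None else torch.zeros_like(p)).reshape(-1)
    for h, p in zip(hvps, params)
])
print(flat_hvp.shape)  # torch.Size([220])
```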