I’m trying to implement the Curveball optimizer in PyTorch 0.4.0, but I’m running into unexpected behavior from autograd.grad
.
Below is a minimal working example that illustrates the problem I’ve encountered.
The algorithm requires that I compute the product of a vector and the Hessian of the loss function with respect to the network’s output, which means using higher order gradients if I want code that is agnostic about the loss function. To do this, I need to retain the gradient for the network output and then use calls like HlJfTz = ag.grad(output.grad,output,JfTz,retain_graph=True)
to do the multiplication.
The first time I do the Hessian-vector product, it all works as expected. However, when I try to compute a second Hessian-vector product, it fails saying that output.grad.requires_grad
is False
.
It appears that output.grad.requires_grad
changes from True
to False
after the statement Gz = ag.grad(output,self.model.parameters(),HlJfTz,retain_graph=True)
.
I’m wondering if this is intended behavior, and if anyone can suggest a workaround.
import torch
import torch.autograd as ag
import torch.nn as nn
class StochasticRosenbrock(nn.Module):
def __init__(self,u,v,lo=0.0,hi=1.0):
super(StochasticRosenbrock, self).__init__()
self.epsilon = torch.Tensor([1.0])
self.lo = lo
self.hi = hi
self.w = nn.Parameter(torch.Tensor([u,v]),requires_grad=True)
return
def forward(self,sample=True):
out = torch.zeros(2)
u = self.w[0]
v = self.w[1]
if sample:
self.epsilon.uniform_(self.lo,self.hi)
out[0] = 1.0 - u
out[1] = 10.0 * self.epsilon * (v - u**2)
return out
class CurveBall(object):
def __init__(self,model,criterion,lambda_=1.0):
self.model = model
self.z = [torch.zeros_like(p) for p in model.parameters()]
self.criterion = criterion
self.lambda_ = lambda_
return
def step(self):
target = torch.zeros(2)
output = self.model()
output.retain_grad()
loss = self.criterion(output,target)
loss.backward(create_graph=True)
J = [p.grad for p in self.model.parameters()]
JfTz = ag.grad(J,output,self.z,retain_graph=True)
HlJfTz = ag.grad(output.grad,output,JfTz,retain_graph=True)
# this line mucks up the gradient tracking for output, and breaks the computation of HlJfTdeltaz
print(output.grad.requires_grad)
Gz = ag.grad(output,self.model.parameters(),HlJfTz,retain_graph=True)
print(output.grad.requires_grad)
deltaz = [j.detach() + g + self.lambda_ * z for j, g, z in zip(J, Gz, self.z)]
JfTdeltaz = ag.grad(J,output,deltaz,retain_graph=True)
# fails because output.grad.requires_grad is False
HlJfTdeltaz, = ag.grad(output.grad,loss,JfTdeltaz,retain_graph=True)
Gdeltaz, = ag.grad(output,self.model.parameters(),HlJfTdeltaz,retain_graph=True)
# cut here because this is where it breaks
cnn = StochasticRosenbrock(-0.25,-0.25,1.0,1.0)
cnn.train()
criterion = nn.MSELoss(size_average=False)
optimizer = CurveBall(cnn,criterion)
loss = optimizer.step()
Thanks!