Why does autograd.grad() clear the requires_grad attribute of outputs?

I’m trying to implement the Curveball optimizer in PyTorch 0.4.0, but I’m running into unexpected behavior from autograd.grad.

Below is a minimal working example that illustrates the problem I’ve encountered.

The algorithm requires that I compute the product of a vector and the Hessian of the loss function with respect to the network’s output, which means using higher order gradients if I want code that is agnostic about the loss function. To do this, I need to retain the gradient for the network output and then use calls like HlJfTz = ag.grad(output.grad,output,JfTz,retain_graph=True) to do the multiplication.

The first time I do the Hessian-vector product, it all works as expected. However, when I try to compute a second Hessian-vector product, it fails saying that output.grad.requires_grad is False.

It appears that output.grad.requires_grad changes from True to False after the statement Gz = ag.grad(output,self.model.parameters(),HlJfTz,retain_graph=True).

I’m wondering if this is intended behavior, and if anyone can suggest a workaround.

import torch
import torch.autograd as ag
import torch.nn as nn

class StochasticRosenbrock(nn.Module):
    def __init__(self,u,v,lo=0.0,hi=1.0):
        super(StochasticRosenbrock, self).__init__()
        self.epsilon = torch.Tensor([1.0])
        self.lo = lo
        self.hi = hi
        self.w  = nn.Parameter(torch.Tensor([u,v]),requires_grad=True)
        return

    def forward(self,sample=True):
        out = torch.zeros(2)
        u = self.w[0]
        v = self.w[1]
        if sample:
            self.epsilon.uniform_(self.lo,self.hi)
        out[0] = 1.0 - u
        out[1] = 10.0 * self.epsilon * (v - u**2)
        return out

class CurveBall(object):
    def __init__(self,model,criterion,lambda_=1.0):
        self.model = model
        self.z = [torch.zeros_like(p) for p in model.parameters()]
        self.criterion = criterion
        self.lambda_ = lambda_
        return

    def step(self):
        target = torch.zeros(2)
        output = self.model()
        output.retain_grad()
        loss = self.criterion(output,target)
        loss.backward(create_graph=True)
        J = [p.grad for p in self.model.parameters()]
        JfTz = ag.grad(J,output,self.z,retain_graph=True)
        HlJfTz = ag.grad(output.grad,output,JfTz,retain_graph=True)
        # this line mucks up the gradient tracking for output, and breaks the computation of HlJfTdeltaz
        print(output.grad.requires_grad)
        Gz = ag.grad(output,self.model.parameters(),HlJfTz,retain_graph=True)
        print(output.grad.requires_grad)
        deltaz = [j.detach() + g + self.lambda_ * z for j, g, z in zip(J, Gz, self.z)]
        JfTdeltaz = ag.grad(J,output,deltaz,retain_graph=True)
        # fails because output.grad.requires_grad is False
        HlJfTdeltaz, = ag.grad(output.grad,loss,JfTdeltaz,retain_graph=True)
        Gdeltaz, = ag.grad(output,self.model.parameters(),HlJfTdeltaz,retain_graph=True)
        # cut here because this is where it breaks

cnn = StochasticRosenbrock(-0.25,-0.25,1.0,1.0)
cnn.train()
criterion = nn.MSELoss(size_average=False)
optimizer = CurveBall(cnn,criterion)
loss = optimizer.step()

Thanks!

Hi,

I think the option you want to use is create_graph not retain_graph. Doc is here. I guess output is part of the graph you go through and so it messes things up as it’s not creating a higher order graph, just keeping the existing one.