Understanding how to use autograd.grad when computing gradients wrt inputs

Hi all, I’m trying to use autograd to calculate the gradient of some outputs wrt some inputs on a
pretrained neural network. Here is the code I use:

net = Custom_NN(input_size, hidden_layers, output_size)
net.load_state_dict(torch.load('network_weights.pt'))

x = torch.rand(input_size, requires_grad=True)
y = net(x)
gradient = torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y))

This code gives me the following error:

RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

I’ve consulted similar questions and it seems like I’m doing the same thing they are, but applying their solutions doesn’t fix the error, so there must be something I’m fundamentally misunderstanding about how autograd works. Can anyone help me understand what autograd is doing, why it cares about unused Tensors, and how to get the network to use my input Tensor in its graph? Thanks.

Could you try,

net = net.load_state_dict(torch.load('network_weights.pt'))

and try again? (and also share what Custom_NN is)

Sure. Running it the way you have it throws

TypeError: '_IncompatibleKeys' object is not callable

after the line

y = net(x)

(In this case type(net) is torch.nn.modules.module._IncompatibleKeys).
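
So it looks like load_state_dict loads the weights into net in place and just returns that _IncompatibleKeys object, which reports missing/unexpected keys rather than being the module itself. Something like this (just my quick check, same variable names as before):

net = Custom_NN(input_size, hidden_layers, output_size)
result = net.load_state_dict(torch.load('network_weights.pt'))  # loads weights in place
print(type(result))                                  # torch.nn.modules.module._IncompatibleKeys
print(result.missing_keys, result.unexpected_keys)   # both empty lists when everything matches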

The class definition for my model is as follows:

from itertools import chain

import torch


class Custom_NN(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size, scalers=None):
        super(Custom_NN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        if scalers is None:
            # Here I've hardcoded the mean_'s and scale_'s of input/output scalers.
            self.inp_mean = torch.Tensor([
                -4.59936018e+00,  1.31706962e+02,  3.80556213e+05,  1.05726545e+05,
                1.22070403e+01,  1.49188727e+02,  1.48334349e+01,  1.73222231e+04
            ])
            self.inp_std = torch.Tensor([
                5.35557343e-01, 2.00745698e-01, 1.22149250e+05, 2.86161126e+04,
                3.68246696e+00, 2.23304735e+01, 3.15135316e+00, 5.14069803e+03
            ])
            
            self.out_mean = torch.Tensor([ 
                1.70409856,  2.36874371, 14.85654644,  5.24019232,  1.51261909,
                1.0671637 , 39.82251636,  4.39834804,  0.99539551,  1.37684262,
                2.83503784
            ])
            self.out_std = torch.Tensor([
                0.54293611, 0.33864151, 3.57840622, 1.48489459, 0.31817987,
                0.20210712, 2.12708213, 0.6290501 , 0.15368718, 0.21272585,
                0.61394896
            ])
        else:
            inp_scaling, out_scaling = scalers
            self.inp_mean, self.inp_std = inp_scaling
            self.out_mean, self.out_std = out_scaling
        
        self.normalize_input = lambda x: (x - self.inp_mean)/self.inp_std
        self.unnormalize_output = lambda x: x*self.out_std + self.out_mean
        
        self.input_layer = torch.nn.Linear(self.input_size, self.hidden_size[0])
        self.input_act = torch.nn.ReLU()
        self.hidden_layers = torch.nn.Sequential(
            *chain.from_iterable(
                [torch.nn.Linear(a,b), torch.nn.ReLU()] for a,b in zip(hidden_size[:-1], hidden_size[1:])
            )
        )
        self.output_layer = torch.nn.Linear(self.hidden_size[-1], self.output_size)

        
    def forward(self, x):
        with torch.no_grad():
            normalized_x = self.normalize_input(x)
        y = self.input_act(self.input_layer(normalized_x))
        z = self.hidden_layers(y)
        return self.output_layer(z)
    
    
    def forward_unnormalized(self, x):
        with torch.no_grad():
            return self.unnormalize_output(self.forward(x))

This is the issue. You’re breaking your computation graph inside forward, which is why you get,

RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

Remove the torch.no_grad() context manager from forward and try again (and be careful of potential divide-by-zero errors if self.inp_std has any zero entries).
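
For reference, a minimal sketch of what the corrected forward could look like (keeping your normalize_input as-is; the clamp mentioned in the comment is just an illustrative guard, not something from your original code):

def forward(self, x):
    # Normalization now stays inside the autograd graph, so gradients can flow back to x.
    # (If any entry of self.inp_std could be zero, consider clamping it first,
    #  e.g. self.inp_std.clamp_min(1e-12), to avoid the divide-by-zero mentioned above.)
    normalized_x = self.normalize_input(x)
    y = self.input_act(self.input_layer(normalized_x))
    z = self.hidden_layers(y)
    return self.output_layer(z)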

That was it! Thank you so much!

Hi @jcallahan4,
Good to see your error is already solved.

Since you wanted to understand what autograd is doing, and how to get the network to use the input tensors in the computation graph, I’m adding some details in that regard:

torch.autograd is PyTorch’s automatic differentiation engine that, as the name suggests, automatically calculates gradients over any “computational graph”.

Computational graphs are built by autograd as and when tensors are subjected to mathematical operations. While building these graphs, autograd also saves the tensors that will be required to calculate gradients wrt tensors whose requires_grad attribute is set to True.

(So, when you call torch.autograd.grad or .backward, these saved tensors are used.)
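
As a tiny self-contained illustration (the function here is made up, but the call pattern is the same as in your code), differentiating a non-scalar output wrt its input looks like this:

import torch

x = torch.rand(3, requires_grad=True)    # leaf tensor; autograd tracks operations on it
y = x ** 2 + 2 * x                       # these operations build the computational graph
# y is non-scalar, so grad_outputs (the vector to backpropagate) is needed, as in your code.
(grad_x,) = torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y))
print(grad_x)      # equals 2*x + 2
print(y.grad_fn)   # something like <AddBackward0 object at ...>, showing y is part of a graph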

Now, torch.no_grad() basically tells autograd to look away. Used as a context manager, it ensures that for any code occurring within the context, autograd builds no graph (and does not further populate any graph that’s already there), i.e. it does not track any operations.
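
A quick sketch of what that means in practice:

import torch

x = torch.rand(3, requires_grad=True)

with torch.no_grad():
    a = x * 2          # computed "off the record": not tracked by autograd

b = x * 2              # computed normally: tracked

print(a.requires_grad, a.grad_fn)   # False None  -- a has no connection back to x
print(b.requires_grad, b.grad_fn)   # True <MulBackward0 object at ...>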

Now, for your code, you are differentiating y (the output) wrt x,
where y = net(x), which essentially means y = net.forward(x).

Inside forward, the returned output_layer(z) (which is essentially what gets stored in y = net(x)) is the result of operations on normalized_x, but normalized_x itself is created from x under torch.no_grad().

This means that even though
y = self.input_act(self.input_layer(normalized_x)) and
z = self.hidden_layers(y) are part of the computation graph, that graph only reaches back to normalized_x; the step from x to normalized_x is not in it.

And so when you tried to differentiate y (which is returned by forward) wrt x, it produced the error
One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

So here, the differentiated Tensor that “appears to not have been used in the graph” is x itself: the graph behind y only reaches back to normalized_x, and x never appears in it.

Note: It doesn’t matter that normalized_x is created as a result of operations on x, whose requires_grad is set to True. Under torch.no_grad(), nothing is tracked by autograd, so all resulting tensors have requires_grad=False.
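
To tie it together, here is a minimal sketch reproducing the same situation outside of your network (the names and operations are made up for illustration; w stands in for the network’s weights):

import torch

x = torch.rand(4, requires_grad=True)
w = torch.rand(4, requires_grad=True)         # plays the role of the network's parameters

with torch.no_grad():
    normalized_x = (x - x.mean()) / x.std()   # the graph is cut here: no link back to x

y = normalized_x * w                          # tracked (because of w), but only back to normalized_x

# torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y)) raises:
#   RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. ...
# With allow_unused=True it runs, but the gradient wrt x is simply None:
(grad_x,) = torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y), allow_unused=True)
print(grad_x)   # None -- x is not connected to y's graph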

Hope this helps,
Srishti