Gradient of module output with respect to module input

Hi everyone,

First question here, I hope it’s the correct place.

I have a chain of nn.Sequential modules and I'm trying to extract two gradients from my model: the gradient of a module's output with respect to that module's input, and the gradient of the module's output with respect to x (the initial input).

In other words, given:

y1 = f1(x)
y2 = f2(y1)
# and so on

where f_i is an nn.Sequential (or any nn.Module), I’d like to extract dy2/dy1 (module output wrt module input) and dy2/dx (module output wrt x).

Thanks!
Alex

Hi Alex,
You might be looking for torch.autograd.grad:
https://pytorch.org/docs/stable/generated/torch.autograd.grad.html
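Roughly, something like this (a minimal sketch with a made-up scalar function, not your model):

import torch

x = torch.randn(3, requires_grad=True)
y = (x ** 2).sum()  # scalar output

# returns dy/dx directly, without populating x.grad
(dydx,) = torch.autograd.grad(y, x)
print(dydx)  # equals 2 * x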

Feel free to post in the thread if you face any errors.

Hi Srishi,

Thanks for your help. After going over that doc page, I came up with this tiny piece of code to test it out:

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

model = nn.Sequential(
    nn.Sequential(
        nn.Linear(12, 14),
        nn.Linear(14, 16),
    ),
    nn.Sequential(
        nn.Linear(16, 16),
        nn.Linear(16, 20),
    ),
    nn.Sequential(
        nn.Linear(20, 20),
        nn.Linear(20, 20),
    ),
).to(device)

grads = []
def hook2(module, inputs, outputs):
    # grab d(outputs)/d(inputs) as each submodule finishes its forward
    grads.append(torch.autograd.grad(outputs, inputs[0]))

grads.clear()
for module in model.modules():
    if isinstance(module, nn.Sequential):
        module.register_forward_hook(hook2)

x = torch.randn(1, 12, device=device, requires_grad=True)
model(x)  # the hooks fire here, which is where the error below comes from

However, I’m getting an error:
RuntimeError: grad can be implicitly created only for scalar outputs

Not sure how to proceed from here.

Hi @alexandrumeterez,
Yes, that's expected behaviour: autograd can only create the initial gradient implicitly when the output is a scalar, so calling backward (or torch.autograd.grad) on a non-scalar tensor without supplying that gradient raises this error.

If you want to compute the gradient of a multi-dimensional tensor, pass the initial gradient explicitly:

import torch
inp = torch.tensor([4.0, 3.0, 2.0], requires_grad=True)
x = (inp*2)**3  # x is not a scalar
x.backward(torch.ones_like(x))
print(inp.grad) # tensor([384., 216.,  96.])

In fact, when x is a scalar, x.backward() is just a shortcut for x.backward(torch.tensor(1.)).
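The same applies to torch.autograd.grad: for a non-scalar output you pass grad_outputs explicitly. A sketch using the variables from your hook (note this gives you a vector-Jacobian product with a ones vector, not the full Jacobian matrix):

# inside your forward hook (sketch)
grads.append(torch.autograd.grad(
    outputs, inputs[0],
    grad_outputs=torch.ones_like(outputs),  # explicit initial gradient
    retain_graph=True,  # keep the graph alive for later grad calls
))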
Hope this helps,
S

This is what I was looking for, thank you!
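For anyone who finds this later, here is the version that ended up working for me (a sketch; the input x and its shape are made up for testing):

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

model = nn.Sequential(
    nn.Sequential(nn.Linear(12, 14), nn.Linear(14, 16)),
    nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 20)),
    nn.Sequential(nn.Linear(20, 20), nn.Linear(20, 20)),
).to(device)

grads = []

def hook2(module, inputs, outputs):
    # d(module output)/d(module input) as a ones-vector VJP
    grads.append(torch.autograd.grad(
        outputs, inputs[0],
        grad_outputs=torch.ones_like(outputs),
        retain_graph=True,  # later grad calls reuse parts of the graph
    ))

# children() instead of modules() so the hook fires only on the three
# inner Sequentials and not on the outer model as well
for module in model.children():
    if isinstance(module, nn.Sequential):
        module.register_forward_hook(hook2)

x = torch.randn(1, 12, device=device, requires_grad=True)
y = model(x)  # hooks fill grads with each module's input gradient

# dy2/dx: gradient of the final output wrt the initial input
(dydx,) = torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y))
print(len(grads), dydx.shape)  # 3 torch.Size([1, 12])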