Print intermediate gradient values during backward pass in PyTorch using hooks

I am trying to print the value of each intermediate gradient during the backward pass of a model, using register_full_backward_hook:

import torch

class func_NN(torch.nn.Module):
    def __init__(self):
        super().__init__()  # must run before any Parameter is assigned
        self.a = torch.nn.Parameter(torch.rand(1))
        self.b = torch.nn.Parameter(torch.rand(1))

    def forward(self, inp):
        mul_x = torch.cos(self.a.view(-1,1)*inp)
        sum_x = mul_x - self.b
        return sum_x

# hook function
def backward_hook(module, grad_input, grad_output):
    print("module: ", module)
    print("inp: ", grad_input)
    print("out: ", grad_output) 

# Training
# Generate labels
a = torch.tensor([0.5])
b = torch.tensor([0.8])
x = torch.linspace(-1, 1, 10)
y = a*x + (0.1**0.5)*torch.randn_like(x)*(0.001) + b
inp = torch.linspace(-1, 1, 10)
foo = func_NN()
handle_ = foo.register_full_backward_hook(backward_hook)
loss = torch.nn.MSELoss()
optim = torch.optim.Adam(foo.parameters(),lr=0.001)

t_l = []
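for i in range(2):
    optim.zero_grad()
    # Call the module itself rather than .forward() so registered hooks run
    l = loss(foo(inp), y)
    l.backward()  # backward hooks fire only during the backward pass
    optim.step()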

But this does not provide the desired result.

My objective is to print the gradients of the non-leaf nodes like sum_x and mul_x.
Please help.

You can print the gradient directly via the tensor's .grad attribute once backward has run, provided you call .retain_grad() on these tensors before the backward pass.
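As a minimal sketch of the retain_grad() approach, the snippet below repeats the computation from forward() outside the module so the intermediate tensors are directly accessible. The squared-mean loss is a placeholder chosen for illustration, not the loss from the question.

```python
import torch

torch.manual_seed(0)

# Stand-ins for the module's parameters and input from the question
a = torch.nn.Parameter(torch.rand(1))
b = torch.nn.Parameter(torch.rand(1))
inp = torch.linspace(-1, 1, 10)

# Same computation as forward() in the question
mul_x = torch.cos(a.view(-1, 1) * inp)
sum_x = mul_x - b

# Non-leaf tensors do not keep .grad by default; ask autograd to retain it
mul_x.retain_grad()
sum_x.retain_grad()

# Placeholder scalar loss (an assumption for this sketch)
loss = sum_x.pow(2).mean()
loss.backward()

print("d loss / d sum_x:", sum_x.grad)
print("d loss / d mul_x:", mul_x.grad)
```

Inside a module, the same idea should work if forward() stores the intermediates as attributes (for example self.mul_x) and calls retain_grad() on them before returning.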

You are using module backward hooks, which only provide gradient information at the granularity of a module. If you want to look at the gradients with respect to specific operations, use tensor or grad_fn hooks instead. See "Autograd mechanics" in the PyTorch 2.1 documentation.
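A rough sketch of the tensor-hook route, using Tensor.register_hook on the intermediates inside forward(). The class name FuncNN, the _log helper, the seen dictionary, and the all-zeros target are hypothetical names and values introduced only for this example.

```python
import torch


class FuncNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Parameter(torch.rand(1))
        self.b = torch.nn.Parameter(torch.rand(1))
        self.seen = {}  # hypothetical store: name -> gradient seen by the hook

    def _log(self, name):
        # Build a tensor hook that prints and records the gradient it receives
        def hook(grad):
            self.seen[name] = grad.clone()
            print(name, "grad:", grad)
        return hook

    def forward(self, inp):
        mul_x = torch.cos(self.a.view(-1, 1) * inp)
        sum_x = mul_x - self.b
        # Hooks can only be attached to tensors that require grad
        if mul_x.requires_grad:
            mul_x.register_hook(self._log("mul_x"))
            sum_x.register_hook(self._log("sum_x"))
        return sum_x


torch.manual_seed(0)
model = FuncNN()
inp = torch.linspace(-1, 1, 10)
target = torch.zeros(1, 10)  # placeholder target, shaped like the output

loss = torch.nn.functional.mse_loss(model(inp), target)
loss.backward()  # each hook fires when the gradient for its tensor is computed
```

Because the hooks return None, they only observe the gradients and leave them unchanged. Unlike retain_grad(), nothing is stored on the tensors themselves.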