Weird backward hook behavior

I was just trying a minimal example to see how the forward and backward passes work on an LSTM, using the code below:

# coding: utf-8

import torch.nn as nn

import torch

def fhook(self, inputs, outputs):
    # forward hook: print the module name and the shapes of its inputs
    print()
    print('module:', self.__class__.__name__)
    print()
    print('inputs:')
    print([None if i is None else i.shape for i in inputs])
    print()
    
def printgrad(self, grad_input, grad_output):
    # backward hook: print the gradients flowing into and out of the module
    print()
    print('module:', self.__class__.__name__)
    print()
    # the gradients of inputs to the module
    print('grad_inputs: ')
    print([None if gi is None else gi.shape for gi in grad_input])
    print('grad_input: ')
    print(grad_input)
    # grads of outputs of the module
    print('grad_outputs: ')
    print([None if go is None else go.shape for go in grad_output])
    print('grad_output: ')
    print(grad_output)
    print()
    print('grad_input norm:', grad_input[0].norm())
    print()


class M(nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.rnn = nn.LSTM(5, 10, batch_first=True, bidirectional=False)
        
    def forward(self, inputs):
        _, (out, _) = self.rnn(inputs)
        return out

model = M()

for m in model.modules():
    print(m)
    m.register_forward_hook(fhook)
    m.register_backward_hook(printgrad)

x = torch.randn(1, 30, 5)

loss = model(x).sum()

loss.backward()

print(model.rnn.__dict__)

print(model.rnn.weight_ih_l0.grad)

print(model.rnn.weight_hh_l0.grad)

Strangely, at the loss.backward() step the input and output gradients of the LSTM were not printed, yet the gradients were clearly computed, as verified by the last two prints. Any insights on this? Thanks!

Hi,

Things get printed for me, though you should not use register_backward_hook() here. As per the docs, these module backward hooks are a bit broken at the moment and will give incorrect results for complex modules such as yours.
To get proper hooks, you should register hooks on the Tensors directly with Tensor.register_hook() (see the doc). You can do that on both the inputs and the outputs of your module to get all the gradients you need.
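
For concreteness, here is a minimal sketch of that approach applied to the model above, using Tensor.register_hook(). The wiring below is just one way to set it up, not the only one:

import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.rnn = nn.LSTM(5, 10, batch_first=True, bidirectional=False)

    def forward(self, inputs):
        _, (out, _) = self.rnn(inputs)
        return out

model = M()

# the input needs requires_grad=True for a gradient (and hence a hook call) to exist for it
x = torch.randn(1, 30, 5, requires_grad=True)
x.register_hook(lambda grad: print('grad wrt input:', grad.shape, grad.norm()))

out = model(x)
# hook on the module's output Tensor
out.register_hook(lambda grad: print('grad wrt LSTM output:', grad.shape, grad.norm()))

loss = out.sum()
loss.backward()  # both Tensor hooks fire here

print(model.rnn.weight_ih_l0.grad.norm())
print(model.rnn.weight_hh_l0.grad.norm())

Unlike the module-level backward hook, these hooks attach to the autograd graph nodes of those exact Tensors, so they fire reliably when the corresponding gradients are computed.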