PyTorch forward hook cannot capture all input variables

I have a program with a sub neural network whose forward method takes three input arguments:

def forward(self, input, hidden,
                transformed_input=None):

But in practice it was only fed meaningful values for the latter two arguments, namely hidden and transformed_input.
I registered a forward hook and found that only two inputs were captured:

def hook(module, input, output):
    """Forward hook: report how many positional inputs the module received."""
    received = len(input)
    print(received)  # in the case described above this prints 2

When I print input in my forward hook:

def hook(module, input, output):
    """Forward hook: print the tuple of positional inputs the module was called with."""
    positional_args = input
    print(positional_args)

It outputs a tuple of size 2 with (None, (some tensor, some tensor)).

Note that in my program hidden does have this (some tensor, some tensor) structure. My best guess so far is that PyTorch ignores arguments whose default value is None: in my forward pass input was fed None, and transformed_input's default value is None, which would explain the observation above.

However, I ran another experiment:

class Test(nn.Module):
    """Toy module used to check which arguments a forward hook receives.

    ``forward`` returns ``input1 + input2`` when ``input1`` is given and
    ``input2`` otherwise. ``input3`` is accepted but never used; it only
    provides a third, defaulted parameter for the hook experiment.
    """

    def __init__(self):
        # Must be spelled ``__init__`` (a bare ``__init`` is just an unused,
        # name-mangled method) and must call ``nn.Module.__init__`` first,
        # otherwise the module's parameter/hook bookkeeping is never created.
        super().__init__()

    def forward(self, input1, input2, input3=None):
        if input1 is None:
            return input2
        return input1 + input2

a = Test()


def hook(module, input, output):
    """Print the hooked module, its number of positional inputs, then the inputs."""
    print(module)
    input_count = len(input)
    print(input_count)
    print(input)


a.register_forward_hook(hook)
# The call below makes the hook print:
#   Test()
#   3
#   (None, tensor([1]), tensor([9]))
# and the call itself evaluates to:
#   tensor([1])
a(None, torch.tensor([1]), torch.tensor([9]))

which correctly outputs everything. So what are your thoughts on this?

My guess is that you might be registering the hook on, or looking at, the wrong module in your original code, since the small code snippet seems to work.

Try to add a name to the hook so that the print statement will show you which layer was called with which inputs and outputs.

I think I found the bug.
My new hook looks like this:

def hook(module, input, output):
    """Forward hook: show which module ran, how many positional inputs it got, and what they were."""
    print(module)
    n_positional = len(input)
    print(n_positional)
    print(input)

When the call passes transformed_input by its explicit keyword name:

# transformed_input is passed by keyword here, so it is missing from the
# hook's `input` tuple: the hook reports only 2 inputs below.
cell(None, (torch.zeros(1,1150), torch.zeros(1,115,10)), transformed_input=cell.ih(torch.ones(1,400)))

It outputs:

ONLSTMCell(
(ih): Sequential(
(0): Linear(in_features=400, out_features=4830, bias=True)
)
(hh): LinearDropConnect(in_features=1150, out_features=4830, bias=True)
)
2
(None, (tensor([[0., 0., 0., …, 0., 0., 0.]]), tensor([[[0., 0., 0., …, 0., 0., 0.],
[0., 0., 0., …, 0., 0., 0.],
[0., 0., 0., …, 0., 0., 0.],
…,
[0., 0., 0., …, 0., 0., 0.],
[0., 0., 0., …, 0., 0., 0.],
[0., 0., 0., …, 0., 0., 0.]]])))

When the keyword name is dropped from the call (all arguments passed positionally):

# Same call with every argument passed positionally: the hook now reports
# all 3 inputs below, including the transformed_input tensor.
cell(None, (torch.zeros(1,1150), torch.zeros(1,115,10)), cell.ih(torch.ones(1,400)))

ONLSTMCell(
(ih): Sequential(
(0): Linear(in_features=400, out_features=4830, bias=True)
)
(hh): LinearDropConnect(in_features=1150, out_features=4830, bias=True)
)
3
(None, (tensor([[0., 0., 0., …, 0., 0., 0.]]), tensor([[[0., 0., 0., …, 0., 0., 0.],
[0., 0., 0., …, 0., 0., 0.],
[0., 0., 0., …, 0., 0., 0.],
…,
[0., 0., 0., …, 0., 0., 0.],
[0., 0., 0., …, 0., 0., 0.],
[0., 0., 0., …, 0., 0., 0.]]])), tensor([[ 0.2503, 0.5609, -0.3321, …, 0.2259, -0.2337, -0.3209]],
grad_fn=<AddmmBackward>))

1 Like

I have also already filed this bug on PyTorch's GitHub repo.

1 Like

I was able to generate an example that reproduces the problem and after some digging …

It is a reflection of the problems discussed here:

Issues in the PyTorch GitHub repo:

  • 59923
  • 35643

and it seems that a fix letting hooks handle kwargs correctly is in the works.

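TL;DR: keyword arguments don't show up in forward hooks, so pass the arguments you need to inspect positionally until they do. (Since PyTorch 2.0, register_forward_hook(hook, with_kwargs=True) passes them to the hook, whose signature then becomes hook(module, args, kwargs, output).)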

import torch
from torch import nn


class BorinModel(nn.Module):
    """Sum of two linear projections, with an optional debug printout.

    ``x`` goes through ``Linear(10 -> 1)`` and ``y`` through ``Linear(20 -> 1)``;
    the two results are added.
    """

    def __init__(self) -> None:
        super().__init__()
        # linear1 is created before linear2 so seeded initialisation is unchanged.
        self.linear1 = nn.Linear(10, 1)
        self.linear2 = nn.Linear(20, 1)

    def forward(self, x, y, debug=True):
        """Return ``linear1(x) + linear2(y)``; print a message first when *debug* is truthy."""
        if debug:
            print("im debuggin")
        from_x = self.linear1(x)
        from_y = self.linear2(y)
        return from_x + from_y

class EvenMoreBoringModel(nn.Module):
    """Wrapper around a ``BorinModel`` child, called either positionally or by keyword."""

    def __init__(self):
        super().__init__()
        self.boring_stuff = BorinModel()

    def forward(self, x, y):
        """Call the child with ``x`` and ``y`` positionally and ``debug`` by keyword."""
        child = self.boring_stuff
        return child(x, y, debug=True)

    def forward2(self, x, y):
        """Call the child with every argument, including ``x`` and ``y``, passed by keyword."""
        child = self.boring_stuff
        return child(x=x, y=y, debug=True)

def _describe(value):
    """Return a compact printable summary of one argument seen by a hook.

    Anything with a ``.shape`` (tensors) is replaced by its shape; tuples and
    lists are summarised element-wise; other values (``None``, bools, ...) are
    returned unchanged so their repr is printed.
    """
    shape = getattr(value, "shape", None)
    if shape is not None:
        return shape
    if isinstance(value, tuple):
        return tuple(_describe(item) for item in value)
    if isinstance(value, list):
        return [_describe(item) for item in value]
    return value


def hook(m, i, o, *rest):
    """Print a module together with the shapes of its inputs and output.

    Registered with ``register_forward_hook(hook)``, PyTorch calls
    ``hook(module, args, output)``; ``args`` holds only the *positional*
    arguments, so anything passed by keyword is not shown.

    Registered with ``register_forward_hook(hook, with_kwargs=True)``,
    PyTorch calls ``hook(module, args, kwargs, output)`` instead; the keyword
    arguments are then printed as well.

    Non-tensor values (e.g. ``None`` or a tuple of tensors) no longer raise
    ``AttributeError``: tuples/lists are summarised element by element and
    other objects are printed as-is.

    Returns ``None``, so the module's output is left unchanged.
    """
    kwargs = None
    if rest:
        # with_kwargs=True layout: (module, args, kwargs, output)
        kwargs, o = o, rest[0]
    print(m)
    print(f"Inputs: {[_describe(x) for x in i]}")
    if kwargs is not None:
        described_kwargs = {name: _describe(v) for name, v in kwargs.items()}
        print(f"Keyword inputs: {described_kwargs}")
    print(f"Outputs: {_describe(o)}")

boring_instance = EvenMoreBoringModel()

# The hook is registered on the inner BorinModel only. `handle` is kept so the
# hook can later be unregistered (handle.remove()).
handle = boring_instance.boring_stuff.register_forward_hook(hook)

# forward() calls the inner module with x and y positionally, so both tensors
# appear in the hook's input tuple (debug=True is passed by keyword and is
# not shown).
boring_instance.forward(torch.ones(5,10), torch.ones(5, 20))
# im debuggin
# BorinModel(
#   (linear1): Linear(in_features=10, out_features=1, bias=True)
#   (linear2): Linear(in_features=20, out_features=1, bias=True)
# )
# Inputs: [torch.Size([5, 10]), torch.Size([5, 20])]
# Outputs: torch.Size([5, 1])
# forward2() calls the inner module with every argument by keyword, so the
# hook's input tuple is empty even though the module ran normally. Passing
# `y=` by keyword here only affects the outer forward2 call, which has no hook.
boring_instance.forward2(torch.ones(5,10), y = torch.ones(5, 20))
# im debuggin
# BorinModel(
#   (linear1): Linear(in_features=10, out_features=1, bias=True)
#   (linear2): Linear(in_features=20, out_features=1, bias=True)
# )
# Inputs: []
# Outputs: torch.Size([5, 1])