gslaller
(gsl)
October 11, 2019, 3:52pm
1
Given the following code:
class Example(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(1, 3)
        self.l2 = nn.Linear(3, 2)

    def forward(self, x):
        mid = self.l1(x)
        res = self.l2(mid)
        return res
is there a way to attach a hook to the mid variable in the forward call? Something like:
example.forward.mid.register_hook(lambda x: print(x))
albanD
(Alban D)
October 12, 2019, 10:34pm
2
Hi,
Why not add mid.register_hook(lambda x: print(x)) to your forward function? You can also return mid from the function so that the caller can add the hook.
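A minimal sketch of that suggestion (the model class is taken from the question; Tensor.register_hook registers a backward hook, so it fires with the gradient of mid during backward(), not during the forward pass):

```python
import torch
import torch.nn as nn

class Example(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(1, 3)
        self.l2 = nn.Linear(3, 2)

    def forward(self, x):
        mid = self.l1(x)
        # Hook on the intermediate tensor itself: called with grad of mid
        # during the backward pass.
        mid.register_hook(lambda grad: print(grad))
        res = self.l2(mid)
        return res

model = Example()
out = model(torch.randn(1))
out.sum().backward()  # the hook prints the gradient of mid here
```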
gslaller
(gsl)
October 20, 2019, 3:49pm
3
I was going through someone else's code; this is what I ended up doing.
SunQpark
(Seonkyu Park)
October 21, 2019, 10:20am
4
How about example.l1.register_hook(lambda x: print(x)), since example.l1 is also an nn.Module?
albanD
(Alban D)
October 21, 2019, 3:25pm
5
register_hook does not exist for nn.Module. And the closest thing that exists, register_backward_hook(), is not working properly at the moment and should not be used.
SunQpark
(Seonkyu Park)
October 22, 2019, 2:19pm
6
register_forward_hook takes a more complex function than lambda x: print(x) as input: the hook's signature is hook(module, input, output). Try this example.
import torch
import torch.nn as nn

class Example(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(1, 3)
        self.l2 = nn.Linear(3, 2)

    def forward(self, x):
        mid = self.l1(x)
        res = self.l2(mid)
        return res

if __name__ == "__main__":
    model = Example()
    dummy_input = torch.randn([1])
    print(dummy_input)

    def hook(module, input, output):
        print(output)

    model.l1.register_forward_hook(hook)
    out = model(dummy_input)