Tensor forward hook


Does “register_forward_hook” exist for the tensor type?

I couldn’t find it.



There is no such thing, no. Why would you want such a hook? If you have the Tensor, you have its value, so there is no need for a forward hook.

In my use case,

import torch
import torch.nn as nn

class M(nn.Module):
  def __init__(self):
    super().__init__()  # required before assigning submodules
    self.l = nn.Linear(10, 5)
    self.l2 = nn.Linear(5, 1)

  def forward(self, x):
    x = self.l(x)
    x = torch.exp(x)   # Want to hook this tensor in forward pass
    x = self.l2(x)
    return x

Actually, I could get a forward hook by splitting the first part out into its own module, something like this.

class FirstLayer(nn.Module):
  def __init__(self):
    super().__init__()
    self.l1 = nn.Linear(10, 5)

  def forward(self, x):
    x = self.l1(x)
    x = torch.exp(x)
    return x

m = FirstLayer()
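# A forward hook is called with (module, inputs, output); returning None leaves the output unchanged.
m.register_forward_hook(lambda module, inputs, output: None)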

However, I want to keep everything in a single module so that I don’t have to manage extra modules.


But this line x = torch.exp(x) # Want to hook this tensor in forward pass is actually executed every time the forward is called. So you can put whatever code you want next to that line and it will run during the forward.


I think I should explain my example more precisely. First, the reason I created this thread is that I want to make my code more flexible.

I just want to collect some activations for visualization.

Something like this,

m = Model()
act = dict()
def f(m, i, o):
  act["m"] = o.detach().cpu()
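For completeness, here is a minimal sketch of how a hook like f is usually attached to a submodule and removed again. The toy nn.Sequential model and the key "linear" are assumptions made up for this illustration, not part of the original code.

```python
import torch
import torch.nn as nn

# Toy model, assumed only for this illustration.
model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 1))

act = {}

def f(module, inputs, output):
    # Store a detached CPU copy so the autograd graph is not kept alive.
    act["linear"] = output.detach().cpu()

# register_forward_hook returns a handle that can remove the hook later.
handle = model[0].register_forward_hook(f)
out = model(torch.rand(3, 10))
handle.remove()

print(act["linear"].shape)
```

Removing the handle after use keeps the model free of leftover hooks once the visualization is done.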


If the target activation comes from a module, there is no problem, because register_forward_hook is available there. But how can I do the same thing for a tensor? As far as I know, this is the best option…

class Model(nn.Module):
  def __init__(self, act):
    super().__init__()
    self.act = act

  def register_act(self, act):
    self.act = act

  def forward(self, x):
    x = torch.exp(x)  # stands in for any tensor operation
    self.act["m"] = x.detach()
    return x

While it works, I don’t like it :blush: because the forward pass depends on self.act. Whenever I use this module, I have to think about act. I think a hook would remove this irritation. I couldn’t come up with a convincing use case for a tensor version of register_forward_hook, which is probably why the experts haven’t implemented it.

Thanks anyway, your advice gave me a lot of useful insight.

My point is more that such an API cannot exist. By the time you have the Tensor and could register a hook on it, its content is already available. So there is no need for a hook anymore.
And if you want to place a hook beforehand, you have no Tensor to place it on, because the Tensor does not exist yet.
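A small sketch to illustrate the point (variable names are made up for the example). Once the tensor exists, its forward value can simply be read. The hook that tensors do offer, Tensor.register_hook, is a backward hook: it is called with the gradient, which is not yet available when the tensor is created.

```python
import torch

seen = {}

x = torch.ones(3, requires_grad=True)
y = torch.exp(x)

# The forward value is already here; no hook is needed to read it.
seen["value"] = y.detach().clone()

# Tensor.register_hook fires during backward and receives the gradient w.r.t. y.
# The lambda returns None, so the gradient itself is left unchanged.
y.register_hook(lambda grad: seen.__setitem__("grad", grad.clone()))

y.sum().backward()
```

After backward, seen holds both the forward value and the gradient of the sum with respect to y.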


As @albanD explained, you cannot register a hook on a tensor, as it’s already available.
Based on your previous code snippet I guess you could use an nn.Identity layer, pass x to it after the exp operation and register a hook on this “fake” layer. I don’t think it would be cleaner than storing the tensor as in your last code snippet, but you might prefer this approach.


That is a fresh idea, I wasn’t aware that nn.Identity existed. :blush:

Maybe the code will look like this… (This is for my future self. I usually revisit my threads to copy snippets.)

class Model(nn.Module):
  def __init__(self):
    super().__init__()
    self.w = nn.Linear(10, 5)
    # nn.Identity ignores its constructor arguments; the string is only a label.
    self.tracker_act = nn.Identity("exp_after_linear")

  def forward(self, x):
    x = self.w(x)
    x = self.tracker_act(torch.exp(x))
    return x

act = {}
def collect(m, i, o):
  act["exp_after_linear"] = o.detach().cpu()

model = Model()
model.tracker_act.register_forward_hook(collect)

I think it is nice. :blush:
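If a model ends up with several such trackers, the same idea can be generalized. The following is only a sketch: the helper attach_trackers is hypothetical (not a PyTorch API), and it assumes that every nn.Identity submodule in the model is meant to be a tracker, keyed by its attribute name.

```python
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Linear(10, 5)
        self.tracker_act = nn.Identity()

    def forward(self, x):
        x = self.w(x)
        return self.tracker_act(torch.exp(x))

def attach_trackers(model, store):
    """Hypothetical helper: hook every nn.Identity submodule, return the handles."""
    handles = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Identity):
            # Bind `name` per iteration through a default argument.
            def collect(m, i, o, name=name):
                store[name] = o.detach().cpu()
            handles.append(module.register_forward_hook(collect))
    return handles

act = {}
model = Model()
handles = attach_trackers(model, act)
out = model(torch.rand(2, 10))
for h in handles:
    h.remove()  # leave the model without leftover hooks
```

With this, forward no longer refers to act at all; the dictionary only exists on the caller's side.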
Thanks for the nice discussion, @albanD and @ptrblck.


You can also implement a wrapper class for a function.
This also lets you keep track of the connections between modules.
When nn.Identity is called after torch.exp, its hook receives two identical tensors: *input_tensors and output_tensor.
The link to the previous module, which shows up as the previous module's output being the same tensor as the current module's input, is lost, because torch.exp creates a new tensor in between.
Using a wrapper module keeps the chain continuous.
*Until the previous tensors are released from memory and new ones get identical pointers.

import torch

class Wrapper(torch.nn.Module):
    def __init__(self, operation, description:str=""):
        super().__init__()
        self.operation = operation
        self.description = description

    def extra_repr(self,):
        return f"operation={self.description}, description=\"{self.description}\""

    def forward(self, *x):
        return self.operation(*x)

class Example(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l = torch.nn.Linear(10, 5)
        self.exp = Wrapper(torch.exp,"torch.exp")
        self.l2 = torch.nn.Linear(5, 1)
        self.square = Wrapper(torch.square,"torch.square")

    def forward(self, x):
        x = self.l(x)
        x = self.exp(x)
        x = self.l2(x)
        x = self.square(x)
        return x

def hook(module,in_tensors,out_tensor):
    print(module,[t.data_ptr() for t in in_tensors], out_tensor.data_ptr())

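if __name__ == '__main__':
    m = Example()
    hooks = []
    for ch in m.children():
        # Register the printing hook on every direct child module.
        hooks.append(ch.register_forward_hook(hook))
    t = torch.rand(1,10)
    print("Input t:", t.data_ptr())
    o = m(t)
    print("Output o:", o.data_ptr())
    for h in hooks:
        h.remove()

Output:
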
Input t: 102179200
Linear(in_features=10, out_features=5, bias=True) [102179200] 102739072
Wrapper(operation=torch.exp, description="torch.exp") [102739072] 103698048
Linear(in_features=5, out_features=1, bias=True) [103698048] 103786048
Wrapper(operation=torch.square, description="torch.square") [103786048] 104697856
Output o: 104697856
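The pointer matching visible in the output above can also be done programmatically. Below is a rough sketch, assuming a small nn.Sequential model and a stripped-down Wrapper; the names producer, edges, keep_alive and make_hook are made up for the example. The outputs are kept alive in a list so that their pointers cannot be reused while tracing, which is the caveat mentioned earlier.

```python
import torch

class Wrapper(torch.nn.Module):
    def __init__(self, operation):
        super().__init__()
        self.operation = operation

    def forward(self, *x):
        return self.operation(*x)

model = torch.nn.Sequential(
    torch.nn.Linear(10, 5),
    Wrapper(torch.exp),
    torch.nn.Linear(5, 1),
)

producer = {}    # data_ptr of an output -> name of the module that produced it
edges = []       # (producer name, consumer name)
keep_alive = []  # hold outputs so their memory is not reused during tracing

def make_hook(name):
    def hook(module, in_tensors, out_tensor):
        for t in in_tensors:
            src = producer.get(t.data_ptr())
            if src is not None:
                edges.append((src, name))
        producer[out_tensor.data_ptr()] = name
        keep_alive.append(out_tensor)
    return hook

handles = [m.register_forward_hook(make_hook(n)) for n, m in model.named_children()]
model(torch.rand(1, 10))
for h in handles:
    h.remove()

print(edges)
```

Each edge pairs the child that produced a tensor with the child that consumed it, using the names nn.Sequential assigns to its children ("0", "1", "2").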