Instrumenting and modifying the forward pass

Dear all,

I have a question, and I hope this is the right forum for it (if not, please let me know). I searched but did not find an existing answer, so apologies if this has already been addressed.

Let’s say that I have a network backbone, VGG for instance, but really any network. I want to intercept the output between two known layers, modify it, and reinject it at the same point so that the forward computation continues from there. I looked at callbacks, but they do not seem to be the right tool.

Is this feasible without modifying the original code? That is my main concern: I want to be conservative with the original code.

Best regards. Doms.

Hi, you can do this without touching the model code by registering a forward hook on the layer whose output you want to modify.

For reference, here is the relevant part of the documentation for register_forward_hook:

The hook will be called every time after forward() has computed an output.
The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after forward() is called.
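As the quoted note says, changing the input inside a forward hook comes too late to affect that layer. If what you need to alter is what a layer receives rather than what it produces, PyTorch also provides register_forward_pre_hook, whose hook runs before forward(). A minimal sketch, using a standalone nn.Linear as a stand-in for any layer and replacing its input with zeros purely for illustration:

```python
import torch
import torch.nn as nn

# Stand-in for any layer whose input you want to change before it runs.
layer = nn.Linear(4, 2)
x = torch.randn(3, 4)
baseline = layer(x)


def zero_input(module, args):
    # args is the tuple of positional arguments passed to forward().
    # Returning a tuple replaces those arguments before forward() runs.
    (inp,) = args
    return (torch.zeros_like(inp),)


handle = layer.register_forward_pre_hook(zero_input)
zeroed = layer(x)
handle.remove()  # later calls see the real input again

# With an all-zero input, nn.Linear returns only its bias for each row.
print(torch.allclose(zeroed, layer.bias.expand_as(zeroed)))  # True
print(torch.allclose(layer(x), baseline))  # True
```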

I attached a demo that modifies the output of the last layer (net.fc3). Replace net.fc3 with whichever layer you want to intercept.

```python
import torch
from torch.nn import Module
import torch.nn as nn


class Net(Module):
    def __init__(self) -> None:
        super().__init__()  # required before assigning submodules
        self.fc1 = nn.Linear(4, 5)
        self.fc2 = nn.Linear(5, 4)
        self.fc3 = nn.Linear(4, 2)

    def forward(self, x):
        x1 = self.fc1(x)
        x2 = self.fc2(x1)
        x3 = self.fc3(x2)
        return x3


net = Net()

x = torch.randn(2, 4)
y = net(x)
print(f"Original output: \n {y}")


def modify_tensor(module, input, output):
    # Runs right after module.forward(); returning a tensor replaces
    # the module's output for the rest of the forward pass.
    return output * 100


hook_handle = net.fc3.register_forward_hook(modify_tensor)
out = net(x)
print(f"Modified output: \n {out}")

hook_handle.remove()  # restores the unmodified forward pass

# Example output (values differ between runs: weights and input are random):
# Original output:
#  tensor([[-0.1258, -0.1067],
#         [-0.3273, -0.3938]], grad_fn=<AddmmBackward0>)
# Modified output:
#  tensor([[-12.5774, -10.6738],
#         [-32.7307, -39.3759]], grad_fn=<MulBackward0>)
```
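The demo hooks the last layer, but the question is about a point between two layers. The same mechanism works there: whatever the forward hook returns is what the next layer receives, so the modification propagates through the rest of the forward pass. A minimal sketch, assuming a hypothetical small VGG-style model (TinyVGG is made up for illustration) and zeroing channel 0 of the activation after features.1 as an arbitrary example modification. The layer is looked up by its name with named_modules(), so the model code stays untouched, and the hook is removed afterwards:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class TinyVGG(nn.Module):
    """Hypothetical VGG-style model, only for illustration."""

    def __init__(self) -> None:
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),  # features.0
            nn.ReLU(),                                  # features.1
            nn.Conv2d(8, 8, kernel_size=3, padding=1),  # features.2
            nn.ReLU(),                                  # features.3
            nn.AdaptiveAvgPool2d(1),                    # features.4
        )
        self.classifier = nn.Linear(8, 2)

    def forward(self, x):
        x = self.features(x)
        return self.classifier(torch.flatten(x, 1))


model = TinyVGG().eval()
x = torch.randn(1, 3, 16, 16)
with torch.no_grad():
    baseline = model(x)

captured = {}


def zero_channel(module, inputs, output):
    # Keep a copy of the untouched activation for inspection.
    captured["before"] = output.detach().clone()
    # Return a modified copy; features.2 receives this instead of the original.
    modified = output.clone()
    modified[:, 0] = 0.0
    return modified


# Look the layer up by its qualified name instead of editing the model code.
layers = dict(model.named_modules())
handle = layers["features.1"].register_forward_hook(zero_channel)
with torch.no_grad():
    patched = model(x)
handle.remove()  # the model behaves as before from here on

# Recompute the tail of the network by hand from the modified activation.
with torch.no_grad():
    act = captured["before"].clone()
    act[:, 0] = 0.0
    expected = model.classifier(torch.flatten(model.features[2:](act), 1))
    print(torch.allclose(patched, expected))  # True
    print(torch.allclose(model(x), baseline))  # True
```

torchvision's VGG models also keep their convolutional layers in an nn.Sequential attribute named features, so names of the form "features.N" should work the same way there; print(model) shows which index corresponds to which layer.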


Thank you very much! I will try this. I suspected I was not searching with the right keyword ("hook", in this case).

Thank you again for your help.

Best regards.