How to clear a module's output history when using register_forward_hook?

First I load a pretrained ResNet from PyTorch, then I use this code to get the hidden features.

feat_out = []

def hook_fn_forward(module, input, output):
    feat_out.append(output)
    print(output)

modules = model.named_children()
for name, module in modules:
    module.register_forward_hook(hook_fn_forward)

pred = model(x)

But the first time I run this code, len(feat_out) gives me 10 and the print in the hook function prints 10 lines. If I run the code again, len(feat_out) gives me 20 and the hook prints 20 lines. Every time I re-run it, the hook fires one more time per module, so the output is this run's output plus all the past output. Only if I reinitialize the model and run the code again is the past output history removed.

How can I clear the output every time I run the model?

I used this code in Colab to reproduce the problem in minimal form (5 lines to load data, 2 lines to initialize the model, 8 lines for the problem itself).

Since you are initializing the feat_out list as a global object, this behavior is expected.
You could either reinitialize it after each forward pass or use a dict instead, which makes it easy to replace the outputs.
Also note that if you are not running the forward pass in a with torch.no_grad() block, the output tensors will stay attached to the computation graph, which will prevent PyTorch from freeing the graph after the backward call. If you don't need these output tensors for a gradient calculation, you could alternatively store the detached version via feat_out.append(output.detach()).
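As a minimal sketch of the dict approach (activation and get_activation are just placeholder names, not from your code), you could let each hook know its layer name via a closure and store the detached output:

activation = {}

def get_activation(name):
    # return a hook that stores the detached output under the given layer name
    def hook(module, input, output):
        activation[name] = output.detach()
    return hook

for name, module in model.named_children():
    module.register_forward_hook(get_activation(name))

with torch.no_grad():
    pred = model(x)
# each key now holds only the activation from the last forward pass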


This is my full code, in case you don't want to go to Colab.

!wget https://download.pytorch.org/tutorial/hymenoptera_data.zip
!unzip -q hymenoptera_data.zip

import os
import torch
import torch.nn as nn
from torchvision import datasets, models, transforms

data_transforms = transforms.Compose([
                             transforms.RandomResizedCrop(224),
                             transforms.ToTensor()])
data_dir = 'hymenoptera_data/train'
image_datasets = datasets.ImageFolder(data_dir, transform=data_transforms)
dataloaders = torch.utils.data.DataLoader(image_datasets, batch_size=4, num_workers=4)

model_ft = models.resnet18(pretrained=True)
model_ft.fc = nn.Linear(model_ft.fc.in_features, 2) #  size of each output is set to 2

dat, lab = next(iter(dataloaders))


total_feat_out = []

def hook_fn_forward(module, input, output):
    total_feat_out.append(output) 
    print(output.shape)

modules = model_ft.named_children()
for name, module in modules:
    module.register_forward_hook(hook_fn_forward)

model_ft.eval() 
with torch.no_grad():
    pred = model_ft(dat)

Thank you for your attention. This code shows the same problem: every time I run it, the number of outputs per module increases by 1, so the print in the hook function prints 10 more lines.

You are still appending to the global list, so the list will grow with each forward pass.
To store only the activations of the last forward pass you need to reinitialize the list or clear it in some other way:

with torch.no_grad():
    out = model_ft(torch.randn(1, 3, 224, 224))
    print(len(total_feat_out)) # prints 10
    total_feat_out = [] # reinitialize list to delete old activations
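If you prefer to keep the same list object (for example if other references to it exist), calling clear() on it instead of rebinding the name works as well:

with torch.no_grad():
    out = model_ft(torch.randn(1, 3, 224, 224))
    print(len(total_feat_out)) # prints 10
    total_feat_out.clear() # empties the list in place instead of rebinding it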

Thank you very much. I finally got what I wanted!

I avoided the problem by using a dict to replace the outputs, as you suggested.

import numpy as np

model_ft = models.resnet18(pretrained=True)

feature_out = {}
layers_name = list(model_ft._modules.keys())
layers = list(model_ft._modules.values())


def hook_fn_forward(module, input, output):
    # find which top-level layer this module is and store its output under that name
    layer = layers_name[np.argwhere([module == m for m in layers])[0, 0]]
    feature_out[layer] = output
    

modules = model_ft.named_children()
for name, module in modules:
    module.register_forward_hook(hook_fn_forward)

model_ft.eval() 
with torch.no_grad():
    pred = model_ft(dat)
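For example, the stored activations can then be looked up by the names of the top-level children (the shapes below assume the 224x224 crops and batch size 4 from above):

print(feature_out['layer4'].shape)   # expected: torch.Size([4, 512, 7, 7])
print(feature_out['avgpool'].shape)  # expected: torch.Size([4, 512, 1, 1])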

Also, from one of your past answers, I found that I didn't remove the hooks, which is what essentially caused the problem.

total_feat_out = []

def hook_fn_forward(module, input, output):
    total_feat_out.append(output) 
    print(output.shape)

modules = model_ft.named_children()
handles = {}
for name, module in modules:
    handles[name] = module.register_forward_hook(hook_fn_forward)

model_ft.eval() 
with torch.no_grad():
    pred = model_ft(dat)
    for handle in handles.values():
        handle.remove()  # remove the hooks so re-running the registration doesn't stack them
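Based on that, here is a small sketch I could also use to make sure the hooks are always removed; register_hooks is just a helper name I made up, everything else is the same model and hook function as above:

from contextlib import contextmanager

@contextmanager
def register_hooks(model, hook_fn):
    # register the hook on every top-level child and remove all handles on exit
    handles = [m.register_forward_hook(hook_fn) for _, m in model.named_children()]
    try:
        yield handles
    finally:
        for handle in handles:
            handle.remove()

with register_hooks(model_ft, hook_fn_forward), torch.no_grad():
    pred = model_ft(dat)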