First I load a pretrained ResNet from PyTorch, then I use the following code to get the hidden features.
feat_out = []

def hook_fn_forward(module, input, output):
    feat_out.append(output)
    print(output)

modules = model.named_children()
for name, module in modules:
    module.register_forward_hook(hook_fn_forward)

pred = model(x)
But when I run this code the first time, len(feat_out) gives me 10, and the print in the hook function prints 10 lines. If I run the code again, len(feat_out) gives me 20, and the print in the hook function prints 20 lines. Every time I run it, the stored output grows by one full forward pass: the list holds the outputs from this run plus all past runs. Only if I reinitialize the model and run the code again is the past output history removed.
How can I clear the outputs every time I run the model?
I use this code in Colab to reproduce the problem in minimal length (5 lines to load data, 2 lines to initialize the model, 8 lines for the problem itself).
Since you are initializing the feat_out list as a global object, this behavior is expected.
You could either reinitialize it after each forward pass or use a dict instead, which makes it easy to replace the outputs rather than accumulate them.
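A minimal sketch of the dict approach, using a small nn.Sequential as a stand-in for the ResNet (the toy model and names here are illustrative only):

```python
import torch
import torch.nn as nn

# Toy model standing in for the pretrained ResNet
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

feat_out = {}  # name -> latest activation; overwritten on every forward pass

def make_hook(name):
    # The closure captures the layer name, so each hook writes to its own key
    def hook(module, input, output):
        feat_out[name] = output.detach()
    return hook

for name, module in model.named_children():
    module.register_forward_hook(make_hook(name))

with torch.no_grad():
    model(torch.randn(2, 8))
    model(torch.randn(2, 8))  # second pass replaces the entries, not appends

print(len(feat_out))  # one entry per hooked child, regardless of run count
```

Because each key is overwritten, the dict always holds exactly one activation per hooked module, no matter how often you run the model.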
Also note, if you are not running the forward pass in a with torch.no_grad() block, the output tensors will stay attached to the computation graph, which will prevent PyTorch from deleting this graph after the backward call. If you don't need these output tensors for a gradient calculation, you could alternatively store the detached version via feat_out.append(output.detach()).
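A small illustration of that difference (a sketch on a single hooked layer; the names are made up for the example):

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
stored = []

# Store a detached copy so the saved tensor does not keep the graph alive
layer.register_forward_hook(lambda m, i, o: stored.append(o.detach()))

out = layer(torch.randn(1, 4))   # normal training-style forward, no no_grad()
assert out.grad_fn is not None   # the model output is attached to the graph
assert stored[0].grad_fn is None # the stored copy is detached from it
```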
This is my full code, in case you don't want to go to Colab.
!wget https://download.pytorch.org/tutorial/hymenoptera_data.zip
!unzip -q hymenoptera_data.zip

import os
import torch
import torch.nn as nn
from torchvision import datasets, models, transforms

data_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.ToTensor()])
data_dir = 'hymenoptera_data/train'
image_datasets = datasets.ImageFolder(os.path.join(data_dir), transform=data_transforms)
dataloaders = torch.utils.data.DataLoader(image_datasets, batch_size=4, num_workers=4)

model_ft = models.resnet18(pretrained=True)
model_ft.fc = nn.Linear(model_ft.fc.in_features, 2)  # size of each output is set to 2
dat, lab = next(iter(dataloaders))

total_feat_out = []

def hook_fn_forward(module, input, output):
    total_feat_out.append(output)
    print(output.shape)

modules = model_ft.named_children()
for name, module in modules:
    module.register_forward_hook(hook_fn_forward)

model_ft.eval()
with torch.no_grad():
    pred = model_ft(dat)
Thank you for your attention. This problem also makes the output in the hook function grow: every time it runs, it grows by one forward pass, so the print in the hook function prints 10 more lines.
You are still appending to the global list, so the list will grow with each forward pass.
To store only the activations of the last forward pass, you would need to reinitialize the list or clear it in some other way:
with torch.no_grad():
    out = model_ft(torch.randn(1, 3, 224, 224))
print(len(total_feat_out))  # prints 10
total_feat_out = []  # reinitialize list to delete old activations
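One such "other way" is list.clear(), called right before each forward pass; a minimal sketch on a toy two-layer model (not the actual ResNet):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU())  # toy stand-in model
acts = []

for _, m in model.named_children():
    m.register_forward_hook(lambda mod, inp, out: acts.append(out))

with torch.no_grad():
    for _ in range(3):
        acts.clear()            # drop activations from the previous pass
        model(torch.randn(1, 4))
        print(len(acts))        # only the current pass's activations: 2, one per child
```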
Thank you very much. I finally got what I wanted!
I avoided the problem by using a dict to replace the outputs, as you suggested.
import numpy as np

model_ft = models.resnet18(pretrained=True)
feature_out = {}
layers_name = list(model_ft._modules.keys())
layers = list(model_ft._modules.values())

def hook_fn_forward(module, input, output):
    # look up the name of the module that fired this hook
    layer = layers_name[np.argwhere([module == m for m in layers])[0, 0]]
    feature_out[layer] = output

modules = model_ft.named_children()
for name, module in modules:
    module.register_forward_hook(hook_fn_forward)

model_ft.eval()
with torch.no_grad():
    pred = model_ft(dat)
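The reverse lookup with np.argwhere can be avoided by binding the layer name into the hook itself, e.g. with functools.partial. A sketch on a toy model (resnet18 would work the same way; the names here are illustrative):

```python
import functools
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
feature_out = {}

def hook_fn_forward(name, module, input, output):
    # name is bound via partial, so no reverse lookup over the module list is needed
    feature_out[name] = output

for name, module in model.named_children():
    module.register_forward_hook(functools.partial(hook_fn_forward, name))

with torch.no_grad():
    model(torch.randn(4, 8))

print(sorted(feature_out))  # one key per child module
```

register_forward_hook calls the hook as hook(module, input, output), so partial supplies the name as the leading argument.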
Also, from one of your past answers, I found that I didn't remove the hooks, which was essentially causing the problem.
total_feat_out = []

def hook_fn_forward(module, input, output):
    total_feat_out.append(output)
    print(output.shape)

modules = model_ft.named_children()
handles = {}
for name, module in modules:
    handles[name] = module.register_forward_hook(hook_fn_forward)

model_ft.eval()
with torch.no_grad():
    pred = model_ft(dat)

for k, v in handles.items():
    v.remove()