Hi,
I am running an experiment that uses the intermediate outputs of the batchnorm layers of a resnet50, so I register forward hooks at different layers and work on their outputs. The weird thing I observed is that the minimum value of the batchnorm output is clamped to zero at every hooked layer. To reproduce this I wrote a toy program:
import torch
from torchvision import models
from tqdm import trange

import datasets  # local module providing amazon_load

activation_student = {}

def get_activation_student(name):
    def hook(module, input_, output):
        # detach() instead of .data: shares the same storage, but is safer
        activation_student[name] = output.detach()
    return hook

amazonData = datasets.amazon_load('train', 16)
model = models.resnet50(pretrained=True)
model.eval()

# note: `n == 'a' or 'b' or 'c'` is always truthy in Python and would hook
# every module; use a membership test instead
for n, m in model.named_modules():
    if n in ('layer2.1.bn1', 'layer3.2.bn2', 'layer4.1.bn3'):
        m.register_forward_hook(get_activation_student(n))

data_iter = iter(amazonData)  # create the iterator once, not every step
for i in trange(5, leave=False):
    source_x, source_y = next(data_iter)
    out = model(source_x)

print("min: ", torch.min(activation_student['layer2.1.bn1']))
print("min: ", torch.min(activation_student['layer3.2.bn2']))
print("min: ", torch.min(activation_student['layer4.1.bn3']))
Output:

    min:  tensor(0.)
    min:  tensor(0.)
    min:  tensor(0.)
Can someone help me fix this, or is this expected?
@ptrblck any idea?
P.S. I apologize for tagging.
The activation values are a bit tricky in this use case.
As you can see in this line of code, the Bottleneck module uses self.relu directly after the batch norm layers. Since self.relu is defined as an inplace operation, the output of the batchnorm layers will be manipulated inplace, which is why you actually see the "output" of the self.relu layer.
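To see the mechanism in isolation, here is a minimal, self-contained sketch (separate from the ResNet code): the tensor a hook would capture is the very tensor the inplace ReLU later overwrites.

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3).eval()
relu = nn.ReLU(inplace=True)

x = torch.randn(1, 3, 4, 4)
bn_out = bn(x)       # with randn input, roughly half the values are negative
print(bn_out.min())  # prints some negative value

relu(bn_out)         # mutates bn_out in place...
print(bn_out.min())  # ...so this now (almost surely) prints tensor(0.)

A hook registered on bn would have stored a reference to bn_out, so it would report the clamped values as well.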
Hi,
Thanks for your reply. I have gone through the docs. Had it been 'layer2.1.bn3' I would be convinced by your reasoning, since a relu operation follows it, but after the batchnorm layer 'layer2.1.bn1' there is no relu operation. In resnet the pattern is 'conv-bn-conv-bn-conv-bn-relu', and I am looking at the output of the first batchnorm, which is not followed by a ReLU(inplace=True). I hope that is clear?
But I still get a minimum value of 0. Am I missing something?
The output of the Bottleneck layer might be misleading:

model.layer2[1]
Out[42]:
Bottleneck(
  (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
)
This output does not mean that these modules will be executed sequentially, as they are not wrapped in an nn.Sequential module. In fact, the output just lists all modules as they were initialized in __init__, while the forward method defines the execution. As you can see from the link I've posted, self.relu is used three times inside the Bottleneck module (every time operating on the activation inplace).
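For reference, the forward pass looks roughly like this (paraphrased from torchvision's Bottleneck implementation, so treat it as a sketch rather than the exact source):

def forward(self, x):
    identity = x

    out = self.relu(self.bn1(self.conv1(x)))    # inplace ReLU overwrites bn1's output
    out = self.relu(self.bn2(self.conv2(out)))  # same for bn2's output
    out = self.bn3(self.conv3(out))

    if self.downsample is not None:
        identity = self.downsample(x)

    out += identity
    return self.relu(out)                       # third use, after the residual add

So a hook on bn1 fires before the subsequent self.relu call, but since that ReLU works inplace on the stored tensor, the captured values end up clamped anyway.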
That was very clear and to the point. Thanks a lot.
One final question: is there a way I could hook in and get the output of the batchnorm layer rather than the relu output? I cannot use the conv layer's output, because my loss function involves a margin, and I can only set the margin to one standard deviation if I have the batchnorm layer's output.
Thanks again.
The clean approach would be to derive a custom class from ResNet and use vanilla nn.ReLU modules instead of the inplace version.
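If subclassing feels too heavy, a lighter variant of the same idea would be to flip the flag on every ReLU in the model (a sketch, assuming nn.ReLU reads its inplace attribute at call time, which it does in current PyTorch):

import torch.nn as nn

for module in model.modules():
    if isinstance(module, nn.ReLU):
        module.inplace = False  # every ReLU now returns a new tensor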
The hacky way would be to just replace the "unwanted" inplace ReLUs:
import torch.nn as nn

model = models.resnet50(pretrained=True)
model.layer2[1].relu = nn.ReLU()
model.layer3[2].relu = nn.ReLU()
model.layer4[1].relu = nn.ReLU()

for n, m in model.named_modules():
    if n in ('layer2.1.bn1', 'layer3.2.bn2', 'layer4.1.bn3'):  # membership test, not chained `or`
        m.register_forward_hook(get_activation_student(n))
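A quick sanity check after the replacement (the input shape here assumes an ImageNet-style model):

model.eval()
with torch.no_grad():
    _ = model(torch.randn(2, 3, 224, 224))

# the hooked batchnorm outputs should now contain negative values
print(activation_student['layer2.1.bn1'].min())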
Thanks, my issue is solved. I will mark the above comment as the solution.