Minimum value of Intermediate BatchNorm layer output clamps to 0

Hi,

I am running an experiment that uses the intermediate outputs of batchnorm layers in resnet50. I register forward hooks on several layers and work with their outputs. The odd thing I observed is that the minimum value of the batchnorm output is clamped to zero at every hooked layer. To reproduce the result I wrote a toy program:

import torch
from torchvision import models
import datasets  # presumably the poster's own module providing amazon_load, not torchvision.datasets
from tqdm import trange


activation_student = {}
def get_activation_student(name):
    def hook(module, input_, output):
        # Stores a reference to the output tensor, not a copy
        activation_student[name] = output.detach()
    return hook

hooked_layers = ('layer2.1.bn1', 'layer3.2.bn2', 'layer4.1.bn3')

amazonData = datasets.amazon_load('train', 16)
model = models.resnet50(pretrained=True)
for n, m in model.named_modules():
    # `n == a or b or c` is always true, so test membership instead
    if n in hooked_layers:
        m.register_forward_hook(get_activation_student(n))

for i in trange(5, leave=False):
    source_x, source_y = next(iter(amazonData))
    out = model(source_x)
    print("min: ", torch.min(activation_student['layer2.1.bn1']))
    print("min: ", torch.min(activation_student['layer3.2.bn2']))
    print("min: ", torch.min(activation_student['layer4.1.bn3']))

output:
min: tensor(0.)
min: tensor(0.)
min: tensor(0.)

Can someone help me fix this, or is this expected?

@ptrblck any idea?

P.S. I apologize for tagging.

The activation values are a bit tricky in this use case.
As you can see in this line of code, the Bottleneck module applies self.relu directly after the batchnorm layers.
Since self.relu is defined as an in-place operation, it overwrites the batchnorm output tensor that your hook stored a reference to, which is why you actually see the “output” of the self.relu layer.
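
The effect can be reproduced without downloading ResNet weights. The following is a minimal sketch, assuming a hypothetical `TinyBlock` (conv, batchnorm, in-place ReLU) as a stand-in for the real Bottleneck; the names `TinyBlock`, `captured`, and `save_reference` are made up for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class TinyBlock(nn.Module):
    """Hypothetical stand-in for a residual block: conv -> bn -> in-place ReLU."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


captured = {}


def save_reference(module, inputs, output):
    # detach() returns a tensor that shares storage with `output`,
    # so a later in-place op on `output` is visible here too.
    captured["bn"] = output.detach()


block = TinyBlock()
block.bn.register_forward_hook(save_reference)
block(torch.randn(4, 3, 16, 16))

min_seen = captured["bn"].min().item()
print("min seen through the hook:", min_seen)
```

The hook runs before the ReLU, but the stored tensor is inspected after the forward pass has finished, by which time the ReLU has already overwritten it.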

Hi,

Thanks for your reply. I have gone through the docs. If it had been ‘layer2.1.bn3’, your explanation would convince me, since a ReLU operation follows it, but there is no ReLU operation after the batchnorm layer ‘layer2.1.bn1’. As far as I can tell, ResNet follows the pattern “conv-bn-conv-bn-conv-bn-relu”, and I am looking at the output of the first batchnorm, which has no ReLU(inplace=True) after it. I hope that is clear.

But I still get a minimum value of “0”. Am I missing something?

The printed representation of the Bottleneck layer might be misleading:

model.layer2[1]
Out[42]: 
Bottleneck(
  (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
)

This output does not mean that these modules are executed sequentially, since they are not wrapped in an nn.Sequential module.
In fact, the output just lists all modules in the order they were initialized in __init__, while the forward method defines the execution order.
As you can see from the link I’ve posted, self.relu is used three times inside the Bottleneck module (each time operating on the activation in place).
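
To illustrate the difference between registration order and execution order, here is a small sketch with a hypothetical `Reordered` module (not part of torchvision) whose forward method calls its children in the opposite order from the one in which they were registered. Forward hooks record the actual call order.

```python
import torch
import torch.nn as nn


class Reordered(nn.Module):
    """Hypothetical module whose forward order differs from its registration order."""

    def __init__(self):
        super().__init__()
        self.act = nn.ReLU()       # registered first
        self.fc = nn.Linear(4, 4)  # registered second

    def forward(self, x):
        return self.act(self.fc(x))  # executed: fc first, then act


model = Reordered()
registered = [name for name, _ in model.named_children()]

call_order = []
for name, child in model.named_children():
    # Bind `name` as a default argument so each hook records its own module.
    child.register_forward_hook(lambda m, i, o, name=name: call_order.append(name))

model(torch.randn(2, 4))
print("registration order:", registered)  # the order print(model) follows
print("execution order:   ", call_order)  # the order forward() actually uses
```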

That was very clear and to the point. Thanks a lot.

One final question: is there a way to hook the batchnorm layer and get its output rather than the ReLU output? I cannot use the conv layer’s output, because my loss function involves a margin that I can set to one standard deviation only if I have the batchnorm output.

Thanks again.

The clean approach would be to derive a custom class from ResNet and use vanilla nn.ReLU modules instead of the inplace version.

The hacky way would be to just replace the “unwanted” inplace ReLUs:

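import torch.nn as nn

model = models.resnet50(pretrained=True)
# Swap the in-place ReLU of each hooked block for a regular one
model.layer2[1].relu = nn.ReLU()
model.layer3[2].relu = nn.ReLU()
model.layer4[1].relu = nn.ReLU()
for n, m in model.named_modules():
    # `n == a or b or c` is always true, so test membership instead
    if n in ('layer2.1.bn1', 'layer3.2.bn2', 'layer4.1.bn3'):
        m.register_forward_hook(get_activation_student(n))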
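
Another option, not proposed in the replies above, is to leave the model untouched and copy the tensor inside the hook, since the hook runs before any later in-place operation. This may also matter for bn3: in the torchvision Bottleneck implementations I am aware of, the shortcut is added in place (`out += identity`) right after bn3, so a stored reference could pick up the shortcut even with a regular ReLU. The sketch below uses a hypothetical `TinyResidual` block rather than the real ResNet; all names in it are made up for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class TinyResidual(nn.Module):
    """Hypothetical block: conv -> bn -> in-place shortcut add -> in-place ReLU."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(8, 8, kernel_size=3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.bn(self.conv(x))
        out += x  # in-place add on the batchnorm output
        return self.relu(out)


captured = {}


def save_reference(module, inputs, output):
    # Shares storage with `output`; later in-place ops show up here.
    captured["reference"] = output.detach()


def save_copy(module, inputs, output):
    # clone() makes an independent copy at the moment the hook fires.
    captured["copy"] = output.detach().clone()


block = TinyResidual()
block.bn.register_forward_hook(save_reference)
block.bn.register_forward_hook(save_copy)
block(torch.randn(4, 8, 16, 16))

ref_min = captured["reference"].min().item()
copy_min = captured["copy"].min().item()
print("min via stored reference:", ref_min)
print("min via cloned copy:     ", copy_min)
```

The copy costs extra memory per hooked layer. If the captured activation has to take part in a loss, dropping `.detach()` and keeping only `output.clone()` should keep it attached to the autograd graph, though I have not verified that against the full ResNet.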

Thanks, my issue is solved. I will mark the above comment as the solution.
