The order of ReLU and BatchNorm in resnet50 during backprop

Hi, I was playing with the resnet50 model pre-defined in torchvision.models, and I used the following code to register hooks:

import torch
import torchvision.models as tmodels

# get_supported_modules() is defined elsewhere in my code; it returns a dict
# whose values are collections of the module types I want to hook
def is_supported_instance(module):
    for typed_modules in get_supported_modules().values():
        if isinstance(module, tuple(typed_modules)):
            return True
    return False

def fw_hook(module, input, output):
    print(f'forward hook in module {module.__full_name__}')

    # tensor hook on the module's output: fires when the gradient
    # w.r.t. that output tensor is computed during backward
    def bw_hook(grad):
        print(f'backward hook in module {module.__full_name__}')

    handle = output.register_hook(bw_hook)
    module.__bw_hook_handle__ = handle

def register_hook(module, hook_fn):
    if is_supported_instance(module):
        handle = module.register_forward_hook(hook_fn)
        module.__fw_hook_handle__ = handle
    
# tm = test_model()
tm = tmodels.resnet50()

for n, m in tm.named_modules():
    m.__full_name__ = n.replace('.', '/')
    
tm.apply(lambda m: register_hook(m, fw_hook))

input = torch.ones([2, 3, 224, 224], dtype=torch.float, requires_grad=True)
out = tm(input)
out = out.mean()
out.backward()
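
I keep the returned handles on the modules so the hooks can be cleaned up afterwards, e.g. something like:

for m in tm.modules():
    if hasattr(m, '__fw_hook_handle__'):
        m.__fw_hook_handle__.remove()
    if hasattr(m, '__bw_hook_handle__'):
        m.__bw_hook_handle__.remove()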

The output looks right for the forward pass; however, during the backward pass I noticed that the order of ReLU and BatchNorm2d was reversed, as shown below:


forward hook in module layer4/2/conv2
forward hook in module layer4/2/bn2
forward hook in module layer4/2/relu
forward hook in module layer4/2/conv3
forward hook in module layer4/2/bn3
forward hook in module layer4/2/relu
forward hook in module avgpool
forward hook in module fc
backward hook in module fc
backward hook in module avgpool
backward hook in module layer4/2/bn3
backward hook in module layer4/2/relu
backward hook in module layer4/2/conv3
backward hook in module layer4/2/bn2
backward hook in module layer4/2/relu
backward hook in module layer4/2/conv2

Besides, I’ve checked the code for resnet50 in torchvision.models, and it seems that a ReLU always follows a BatchNorm2d layer (the Bottleneck forward is roughly sketched below). So why did this happen?
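
This is roughly what the Bottleneck block's forward looks like, paraphrased from the torchvision source (details may differ between versions):

# paraphrased from torchvision.models.resnet.Bottleneck.forward
def forward(self, x):
    identity = x

    out = self.conv1(x)
    out = self.bn1(out)
    out = self.relu(out)

    out = self.conv2(out)
    out = self.bn2(out)
    out = self.relu(out)

    out = self.conv3(out)
    out = self.bn3(out)

    if self.downsample is not None:
        identity = self.downsample(x)

    out += identity
    out = self.relu(out)

    return out
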
Thanks in advance.

Hi, so I am no expert, but I assume this is an optimization decision. If you think about what the ReLU operation (i.e. max(0, x)) does to the gradients, you will notice that for negative inputs it kills the gradient (turns it to 0), and for positive inputs it just lets it flow through, since dx/dx = 1. So if you run the ReLU's backward first and a lot of its inputs were negative, you kill those gradients early and save computation in the rest of the backward pass.
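
A tiny example of what I mean (plain torch, independent of your hook code):

import torch

# ReLU's backward lets the upstream gradient through where the input was
# positive and zeroes it where the input was negative
x = torch.tensor([-2.0, -0.5, 0.5, 2.0], requires_grad=True)
torch.relu(x).sum().backward()
print(x.grad)  # tensor([0., 0., 1., 1.])
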
Either that, or there is a bug in your code and I'm just making this up :sweat_smile:. Anyway, I am not sure about this, so maybe someone else can clarify further.

Hey, thanks for the reply. I’ve thought about that too, but if it was done to reduce computation, shouldn’t the order be

backward hook in module layer4/2/relu
backward hook in module layer4/2/bn3
backward hook in module layer4/2/conv3

which is the expected order as the gradient propagates backward.
I’ve tried the hook-registration code on another toy model:

class test_model(nn.Module):
    def __init__(self):
        super(test_model, self).__init__()
        self.conv1 = nn.Conv2d(3, 8, [3, 3])
        self.bn1 = nn.BatchNorm2d(8)
        self.conv2 = nn.Conv2d(8, 3, [3, 3]) 
        self.bn2 = nn.BatchNorm2d(3) 
        self.relu = nn.ReLU()
        self.avgpool = nn.AdaptiveAvgPool2d(3)
        self.linear = nn.Linear(3, 1)
    
    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.avgpool(out)
        out = self.linear(out)
        return out

For this model the result is as expected:

forward hook in module conv1
forward hook in module bn1
forward hook in module relu
forward hook in module conv2
forward hook in module bn2
forward hook in module relu
forward hook in module avgpool
forward hook in module linear
backward hook in module linear
backward hook in module avgpool
backward hook in module relu
backward hook in module bn2
backward hook in module conv2
backward hook in module relu
backward hook in module bn1
backward hook in module conv1

So I really have no idea what’s going on there…

You are right. Since it was the backward pass, I read the prints from bottom to top and got confused :smile:

Turns out it’s because the ReLU used in torchvision’s resnet50 is an in-place operation (nn.ReLU(inplace=True)), and the residual addition out += identity just before the last ReLU in each block is also in-place, so the autograd engine does some extra checks and reorders operations during the backward pass.
I still don’t have a clue how the reordering is actually done, but I changed the ReLU in my toy model above to the in-place version and found that the order of ReLU and BatchNorm was indeed reversed during the backward pass. So I guess that confirms the explanation.
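
For anyone who wants to reproduce it, the only change to the toy model was

self.relu = nn.ReLU(inplace=True)

and you can also see that an in-place ReLU returns the very tensor it was given, so with the hook code above, the bn and relu forward hooks end up registering their tensor hooks on the same output tensor:

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
relu = nn.ReLU(inplace=True)

x = torch.randn(2, 3, 4, 4)
bn_out = bn(x)
relu_out = relu(bn_out)
print(relu_out is bn_out)  # True: the BatchNorm output was overwritten in place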
