Pretrained VGG: in-place ReLU prevents access to layer outputs preceding the ReLU

I am trying to access an intermediate layer's activations in a pretrained VGG. Specifically, I'd like to get the activations right after the BN layers. But since each BN layer is followed by an in-place ReLU, hooking the BN layer yields ReLU-activated outputs, i.e. negative values are clamped to zero. Why is this the case? See code:

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision.models import vgg19_bn as vgg_net


class VGG_acts(nn.Module):

    def __init__(self, layer_numbers):
        super(VGG_acts, self).__init__()

        self.mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        self.stdv = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

        self.vgg = vgg_net(pretrained=True).eval()

        print(self.vgg.features)

        self.layer_acts = []
        self.forward_hooks = []
        
        #register hooks
        for layer_num, layer_name in enumerate(self.vgg.features._modules.keys()):
            # print(layer_num, layer_name)
            if layer_num in layer_numbers:
                self.forward_hooks.append(getattr(self.vgg.features, layer_name).register_forward_hook(self.get_activations(layer_name)))


    #Defining hook to get intermediate features     
    def get_activations(self, layer_name):
        def hook(module, input, output):
            self.layer_acts.append(output)
        return hook


    def forward(self, x):
        
        # range norm
        x = x/255.
        
        # input norm
        x = (x - self.mean) / self.stdv
        out = self.vgg(x)

        return self.layer_acts

layer_numbers = [1,]

vgg = VGG_acts(layer_numbers=layer_numbers)

t = torch.rand(8,3,256,256)

out = vgg(t)
print(out[0].min())


Output:
Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): ReLU(inplace=True)
  (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (9): ReLU(inplace=True)
  (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (12): ReLU(inplace=True)
  (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (16): ReLU(inplace=True)
  (17): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (18): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (19): ReLU(inplace=True)
  (20): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (21): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (22): ReLU(inplace=True)
  (23): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (24): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (25): ReLU(inplace=True)
  (26): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (27): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (28): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (29): ReLU(inplace=True)
  (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (31): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (32): ReLU(inplace=True)
  (33): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (34): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (35): ReLU(inplace=True)
  (36): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (37): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (38): ReLU(inplace=True)
  (39): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (40): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (41): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (42): ReLU(inplace=True)
  (43): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (44): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (45): ReLU(inplace=True)
  (46): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (47): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (48): ReLU(inplace=True)
  (49): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (50): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (51): ReLU(inplace=True)
  (52): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
tensor(0., grad_fn=<MinBackward1>)

OK, I could solve it by manually replacing the ReLU modules with standard nn.ReLU instances without the in-place option (sketch below).
Are in-place operations fused with the previous operation in PyTorch, or what is the reason for this behaviour?
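For reference, roughly how I swapped them out before running the forward pass (a minimal sketch against the VGG_acts instance above):

# replace every in-place ReLU in vgg.features with an out-of-place one,
# so the BN outputs captured by the hooks are no longer overwritten
for name, module in vgg.vgg.features.named_children():
    if isinstance(module, nn.ReLU):
        setattr(vgg.vgg.features, name, nn.ReLU(inplace=False))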

No, in-place operations are not fused (in-place ops would often even prevent fusion); they manipulate the tensor's data directly instead of creating a new output tensor.
This is why the output of the BN layer appears changed: the hook stored a reference to the very tensor that the in-place ReLU then overwrites.
Here is a small example:

a = torch.tensor(1.)

# out-of-place: creates a new tensor b, a is untouched
b = a + 1
print(a, b)

# in-place: modifies a's data directly, no new tensor is created
a += 1
print(a)

# also in-place
a.add_(1.)
print(a)
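Alternatively, you could keep the in-place ReLUs and instead clone the output inside the hook, so you store a snapshot of the BN activation before the in-place ReLU overwrites the same tensor (a minimal sketch of your hook factory):

    def get_activations(self, layer_name):
        def hook(module, input, output):
            # clone() copies the data, so the in-place ReLU that runs next
            # can no longer modify the stored activation
            self.layer_acts.append(output.clone())
        return hook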

I see, thank you. So the tensor produced by the BN layer is manipulated directly by the in-place ReLU, i.e. the hook stores a reference to that tensor, and by the time I inspect it, it has already been overwritten by the in-place ReLU.

Another thing I was wondering: why is the in-place ReLU not a problem for backprop, while other in-place operations can cause issues? Is it because there is no gradient flow through the clamped outputs anyway, due to the zero slope of the ReLU for negative inputs?

It can be a problem and depends on the surrounding operations, in particular on whether the in-place ReLU overwrites a tensor that is still needed in its original form for the gradient computation. In the VGG case it is fine, because ReLU's backward only needs its own (clamped) output and batchnorm's backward does not need its output, so nothing that autograd saved gets overwritten.
This post gives a small example using an in-place div operation.
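The failure mode looks roughly like this (a minimal sketch, not the exact example from that post): exp saves its output for the backward pass, so overwriting that output in place makes autograd complain:

x = torch.randn(4, requires_grad=True)

y = torch.exp(x)   # exp saves its output y for the backward pass
y.div_(2.)         # the in-place div overwrites that saved tensor

# raises: RuntimeError: one of the variables needed for gradient
# computation has been modified by an inplace operation
y.sum().backward()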