nn.BatchNorm2d with shared weights

Hello

Let's say that in my CNN I have a feature map with dimensions N, C, H, W,
with H=50 and W=4.
How can I pass this feature map through an nn.BatchNorm2d layer in regions of H=10 and W=4, i.e. 5 consecutive regions along H?

I do not want to instantiate 5 distinct nn.BatchNorm2d.
Instead I want the learnable parameters of nn.BatchNorm2d shared.

best regards

Hello

Here is my solution

def forward(self, x):
    out = self.cnn1(x)
    h = out.shape[2] // self.n  # height of each region along H

    # normalize the first region, then concatenate each further region along H (dim=2)
    outt = self.batchnorm1(out[:, :, 0:h, :])
    for i in range(1, self.n):
        outt = torch.cat((outt, self.batchnorm1(out[:, :, i*h:(i+1)*h, :])), dim=2)
    # as posted, `out` rather than the normalized `outt` is passed on here
    out = self.Dropout(self.LeReLu(out))


It works in my training script.
Does anyone see any practical problems with it?

Your approach should work. I would probably append the outputs to a list and call torch.stack or torch.cat once afterwards, to avoid concatenating the intermediate results in each iteration.
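A minimal sketch of that suggestion as a standalone module, assuming H is divisible by the number of regions n. The class name `SharedBatchNorm2d` and its arguments are hypothetical, not part of PyTorch. `split` returns views of the input; the normalized regions are collected in a list and concatenated once along dim=2.

```python
import torch
import torch.nn as nn


class SharedBatchNorm2d(nn.Module):
    """Hypothetical helper: one BatchNorm2d applied to n equal regions along H."""

    def __init__(self, num_features: int, n_regions: int):
        super().__init__()
        self.n = n_regions
        self.bn = nn.BatchNorm2d(num_features)  # a single set of weight/bias

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        H = x.shape[2]
        assert H % self.n == 0, "H must be divisible by n_regions"
        # split along H into n views, normalize each with the same layer,
        # collect the results in a list and concatenate once along H (dim=2)
        regions = x.split(H // self.n, dim=2)
        return torch.cat([self.bn(r) for r in regions], dim=2)


# usage with the shapes from the question: H=50, W=4, regions of H=10
layer = SharedBatchNorm2d(num_features=8, n_regions=5)
x = torch.randn(16, 8, 50, 4)
out = layer(x)
print(out.shape)  # torch.Size([16, 8, 50, 4])
```

Note that the layer is called n times per forward pass, so its running statistics are also updated n times (here `num_batches_tracked` grows by 5 per call).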

Thank you for the answer

I ran 3 tests with the architectures below.

def forward(self, x):
    out = self.cnn1(x)
    # standard usage: BatchNorm2d normalizes the whole feature map
    out = self.Dropout(self.LeReLu(self.batchnorm1(out)))
    ...

This is the standard connection for BatchNorm2d.
GPU memory usage is 1.8GB during training.
Both training and validation losses of the network decrease as expected.

def forward(self, x):
    out = self.cnn1(x)
    h = out.shape[2] // self.n  # height of each region along H

    outt = self.batchnorm1(out[:, :, 0:h, :])
    for i in range(1, self.n):
        outt = torch.cat((outt, self.batchnorm1(out[:, :, i*h:(i+1)*h, :])), dim=2)
    out = self.Dropout(self.LeReLu(out))  # `out`, not `outt`: the normalized regions are unused
    ...

This is my approach for connecting BatchNorm2d with shared weights.
At the beginning of training, GPU memory usage jumps to 3.8GB and then stabilizes at 2.2GB.
The training loss decreases, but the validation loss does not.

def forward(self, x):
    out = self.cnn1(x)
    h = out.shape[2] // self.n  # height of each region along H

    outt = []
    for i in range(self.n):
        outt.append(self.batchnorm1(out[:, :, i*h:(i+1)*h, :]))
    out = self.Dropout(self.LeReLu(torch.cat(outt)))  # no dim given: concatenates along dim 0 (batch)

This is the approach you suggested for sharing the BatchNorm2d weights.
GPU memory usage is stable at 2GB during training.
Neither the training loss nor the validation loss decreases.

Your approach is clearly better in terms of GPU memory usage.
However, with both my approach and yours, the validation loss does not decrease.

Mathematically these approaches are of course not equivalent: the first one normalizes the whole out tensor at once, while the others normalize slices of it. This changes the normalization of the forward activations and thus also the running stats of the layer.
I understood that you want to use “weight sharing” on different tensors, but it seems you are now comparing different slicing approaches applied to the same input tensor?
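To make the difference concrete (the notation here is ours): in training mode, BatchNorm2d normalizes each channel c with statistics computed over all dimensions except C. Normalizing the whole tensor uses one mean and (biased) variance per channel, while normalizing region i, the rows R_i along H, uses the statistics of that region only, with the same shared γ_c and β_c:

$$\hat{x}_{n,c,h,w} = \gamma_c \, \frac{x_{n,c,h,w} - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}} + \beta_c, \qquad \mu_c = \frac{1}{NHW} \sum_{n,h,w} x_{n,c,h,w}$$

$$\hat{x}_{n,c,h,w} = \gamma_c \, \frac{x_{n,c,h,w} - \mu_{c,i}}{\sqrt{\sigma_{c,i}^2 + \epsilon}} + \beta_c, \qquad \mu_{c,i} = \frac{1}{N \lvert R_i \rvert W} \sum_{n,\; h \in R_i,\; w} x_{n,c,h,w}, \qquad h \in R_i$$

Because the layer is called n times per forward pass, the running mean r_c (momentum m, 0.1 by default) is updated n times in sequence and ends up as

$$r_c \leftarrow (1-m)^n \, r_c + m \sum_{i=1}^{n} (1-m)^{\,n-i} \, \mu_{c,i}$$

so later regions get a larger weight, and in eval mode this single mixture is used for every region. The running variance is updated the same way with the unbiased per-region variance. With n = 2, m = 0.1 and region means 0 and 1, this gives the running_mean of 0.1 and running_var of 0.81 from the window example in this thread.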

Thank you for the answer

I am trying to use the same nn.BatchNorm2d layer to normalize different regions of a feature map.
Let's say the feature map has dimensions N, C, H, W and contains repetitive data along H. When normalizing the feature map, I do not want the statistics to be mixed across these repetitive patterns. At the same time, I want to apply nn.BatchNorm2d with the same parameters to every repetition along H. Therefore I pass consecutive regions of the feature map through the same nn.BatchNorm2d.

Collecting the outputs in a list and concatenating them once, as you suggested, is the right way to do that.

In the 3rd test, I realized that the dim argument is missing in the torch.cat() call.
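A quick shape check illustrates the effect: without `dim`, `torch.cat` concatenates along dim 0, so the regions end up stacked along the batch dimension instead of being reassembled along H. The shapes follow the question (H=50, W=4, regions of H=10); N=16 and C=8 are arbitrary choices for this sketch.

```python
import torch

x = torch.randn(16, 8, 50, 4)        # N, C, H, W
regions = list(x.split(10, dim=2))   # 5 tensors of shape (16, 8, 10, 4)

wrong = torch.cat(regions)           # default dim=0: regions stacked along N
right = torch.cat(regions, dim=2)    # regions reassembled along H

print(wrong.shape)  # torch.Size([80, 8, 10, 4])
print(right.shape)  # torch.Size([16, 8, 50, 4])
print(torch.equal(right, x))  # True
```

With the missing `dim`, the layers after the normalization receive a tensor with 5 times the batch size and a fifth of the height.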

Yes, concatenating the inputs allows you to reuse the same layer, but you cannot expect to see the same results as previously explained.
Normalizing the “full” image in [H, W] will not yield the same result as normalizing 4 patches of the image in [H//2, W//2].
Here is a small artificial example which shows the completely different results:

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# setup
x1 = torch.zeros(1, 1, 24, 24)
x2 = torch.ones(1, 1, 24, 24)
x = torch.cat((x1, x2), dim=2)

# full image
bn = nn.BatchNorm2d(1)

out_all = bn(x)
plt.imshow(out_all[0, 0].detach().numpy())
print(out_all.min(), out_all.max(), out_all.mean())
# tensor(-1.0000, grad_fn=<MinBackward1>) tensor(1.0000, grad_fn=<MaxBackward1>) tensor(0., grad_fn=<MeanBackward0>)

print(bn.running_mean)
# tensor([0.0500])
print(bn.running_var)
# tensor([0.9250])

# window approach
bn = nn.BatchNorm2d(1)
out = torch.cat([bn(x_) for x_ in x.split(24, dim=2)], dim=2)
plt.imshow(out[0, 0].detach().numpy())
print(out.min(), out.max(), out.mean())
# tensor(0., grad_fn=<MinBackward1>) tensor(0., grad_fn=<MaxBackward1>) tensor(0., grad_fn=<MeanBackward0>)

print(bn.running_mean)
# tensor([0.1000])
print(bn.running_var)
# tensor([0.8100])
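Continuing the artificial example, a short sketch of what the window approach does after switching to evaluation mode: BatchNorm2d then normalizes with `running_mean` and `running_var` instead of per-window batch statistics, so both windows share the single set of running stats (0.1 and 0.81 here) and the output no longer matches the all-zero training output.

```python
import torch
import torch.nn as nn

x = torch.cat((torch.zeros(1, 1, 24, 24), torch.ones(1, 1, 24, 24)), dim=2)

bn = nn.BatchNorm2d(1)
# training mode: each window uses its own batch statistics and updates the running stats
out_train = torch.cat([bn(x_) for x_ in x.split(24, dim=2)], dim=2)

bn.eval()  # from now on running_mean=0.1 and running_var=0.81 are used for both windows
with torch.no_grad():
    out_eval = torch.cat([bn(x_) for x_ in x.split(24, dim=2)], dim=2)

print(out_train.abs().max().item())  # 0.0
print(out_eval[0, 0, 0, 0].item(), out_eval[0, 0, -1, 0].item())  # approx. -0.111 and 1.0
```

The zeros window maps to (0 - 0.1) / 0.9 and the ones window to (1 - 0.1) / 0.9, i.e. the training-time and evaluation-time outputs of the same input differ.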

Thank you very much for the answer.

You are absolutely right.
Different architectures will generate different outputs, as shown in your example.

When interpreting my tests, I only looked at whether the training and validation losses converged. That was a very rough evaluation.

Simply changing the normalization regions of a single activation map in a network can easily disrupt its convergence behavior completely.