Center entire rows with the same mean and var using BatchNorm2d in a 2D image?

Hello,

So basically, I am wondering whether there is an easy way to adapt the BatchNorm2d layer so that it computes the mean and variance not for each pixel of an image across the batch, but for each row instead.

My understanding is that, normally, the input has shape N,C,H,W, and for each channel C and each pixel (i,j) (with i = 1,…,H and j = 1,…,W) we compute the mean (and the variance) of pixel (i,j) across all examples n = 1,…,N in the batch, and then center all pixels in the batch using these statistics.

Instead, I would like to center the batch using the mean and variance of each row i.
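If "each row" is meant across the whole batch (one mean and variance per channel c and row h, computed over all N samples and all W columns), one option that reuses a built-in layer is to merge C and H into a single feature axis and apply nn.BatchNorm1d, which normalizes an (N, C, L) input over N and L separately for each of its C features. This is a sketch under that interpretation; the class name RowBatchNorm is made up for illustration.

```python
import torch
import torch.nn as nn


class RowBatchNorm(nn.Module):
    """Hypothetical helper: one mean/variance per (channel, row), shared over the batch.

    C and H are merged into C*H features, so nn.BatchNorm1d reduces over
    the batch dimension N and the width dimension W for each (c, h) pair.
    """

    def __init__(self, channels, rows, **bn_kwargs):
        super().__init__()
        # affine weight/bias have C*H entries, i.e. one per (channel, row)
        self.bn = nn.BatchNorm1d(channels * rows, **bn_kwargs)

    def forward(self, x):
        n, c, h, w = x.shape
        y = self.bn(x.reshape(n, c * h, w))  # N x (C*H) x W
        return y.reshape(n, c, h, w)


torch.manual_seed(0)
x = torch.randn(4, 3, 5, 10)  # N x C x H x W
layer = RowBatchNorm(channels=3, rows=5)
out = layer(x)  # training mode: normalizes with the current batch statistics

# Manual reference: statistics over N and W for every (c, h)
mean = x.mean(dim=(0, 3), keepdim=True)
var = x.var(dim=(0, 3), unbiased=False, keepdim=True)
manual = (x - mean) / torch.sqrt(var + layer.bn.eps)
match = torch.allclose(out, manual, atol=1e-5)
print(match)
```

In eval mode the wrapped layer would use its running mean and variance instead, as any batchnorm layer does. Note that this differs from normalizing each sample's rows with that sample's own statistics.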

Thanks!

Best, JZ

Yes, you could implement your custom normalization layer using this approach.
Take a look at this manual implementation of batchnorm to see how the input activation is normalized and adapt it to your use case.
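As a rough illustration of that idea (not the implementation referred to above), here is a sketch of a batchnorm-style module whose reduction dimensions are a parameter; ManualNorm, stat_shape and reduce_dims are made-up names. With reduce_dims=(0, 2, 3) it should match nn.BatchNorm2d, which keeps one mean and variance per channel, reduced over N, H and W, rather than one per pixel. Running statistics are updated with momentum in training mode and used in eval mode.

```python
import torch
import torch.nn as nn


class ManualNorm(nn.Module):
    """Hypothetical batch-norm-style layer with configurable reduction dims.

    stat_shape is the shape of the statistics after reducing x over
    reduce_dims with keepdim=True, e.g. (1, C, 1, 1) for BatchNorm2d.
    """

    def __init__(self, stat_shape, reduce_dims, eps=1e-5, momentum=0.1):
        super().__init__()
        self.reduce_dims = reduce_dims
        self.eps = eps
        self.momentum = momentum
        # affine parameters, one per statistic
        self.weight = nn.Parameter(torch.ones(stat_shape))
        self.bias = nn.Parameter(torch.zeros(stat_shape))
        # running estimates used in eval mode
        self.register_buffer("running_mean", torch.zeros(stat_shape))
        self.register_buffer("running_var", torch.ones(stat_shape))

    def forward(self, x):
        if self.training:
            # batch statistics; the biased variance is used for normalizing
            mean = x.mean(dim=self.reduce_dims, keepdim=True)
            var = x.var(dim=self.reduce_dims, unbiased=False, keepdim=True)
            n = x.numel() // mean.numel()  # elements per statistic
            with torch.no_grad():
                # moving average; running_var stores the unbiased estimate
                m = self.momentum
                self.running_mean.mul_(1 - m).add_(m * mean)
                self.running_var.mul_(1 - m).add_(m * var * n / (n - 1))
        else:
            mean, var = self.running_mean, self.running_var
        xn = (x - mean) / torch.sqrt(var + self.eps)
        return xn * self.weight + self.bias


torch.manual_seed(0)
x = torch.randn(8, 3, 5, 7)  # N x C x H x W
ref = nn.BatchNorm2d(3)
mine = ManualNorm(stat_shape=(1, 3, 1, 1), reduce_dims=(0, 2, 3))

train_match = torch.allclose(ref(x), mine(x), atol=1e-5)  # training mode
var_match = torch.allclose(ref.running_var, mine.running_var.flatten(), atol=1e-6)
ref.eval()
mine.eval()
eval_match = torch.allclose(ref(x), mine(x), atol=1e-5)  # uses running stats
print(train_match, var_match, eval_match)
```

For per-row statistics over the batch, stat_shape=(1, C, H, 1) with reduce_dims=(0, 3) would be the corresponding choice. Normalizing each sample's rows with that sample's own statistics does not need running statistics, because those statistics are available at inference time as well.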

Nice, thanks, I’ll look into it and report back soon!

Dear @ptrblck,

As promised, I'm back with a simple implementation of a custom norm layer. Maybe you have time to take a quick look at it.

As described above, the goal is to normalize each row in each channel. I therefore compute the mean and std along each row (i.e. along dimension W) and rescale accordingly. Additionally, I implemented an affine (linear) transform as in your batchnorm. It should apply an individual transform to each channel C and row H, so I initialize weight and bias with shape (C,H) each.
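With weight and bias of shape (C, H), indexing them as weight[None, :, :, None] gives shape (1, C, H, 1), which broadcasts against a (B, C, H, W) input so that every row (c, h) of every sample gets its own scale and shift, shared along W. A small check of that broadcasting, with arbitrary sizes:

```python
import torch

B, C, H, W = 2, 3, 4, 6
x = torch.randn(B, C, H, W)
weight = torch.randn(C, H)
bias = torch.randn(C, H)

out = x * weight[None, :, :, None] + bias[None, :, :, None]
print(weight[None, :, :, None].shape)  # torch.Size([1, 3, 4, 1])

# Every (sample, channel, row) is scaled and shifted by its own pair of scalars
for n in range(B):
    for c in range(C):
        for h in range(H):
            expected = x[n, c, h] * weight[c, h] + bias[c, h]
            assert torch.allclose(out[n, c, h], expected)
print("broadcasting matches a per-row affine transform")
```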

My main questions are: is this the correct shape to get an individual transform for each channel and row? And, more generally, how important is the linear transform in your experience? If the goal is to center the data and the model has many layers anyway, the extra transformation might not really be required, or is it usually needed?

```python
import torch
import torch.nn as nn


class CustomNorm(nn.Module):

    def __init__(self, bchw, affine=True):
        super().__init__()
        self.bchw = bchw
        self.eps = 1e-5
        self.affine = affine
        if self.affine:
            # one scale and one shift per (channel, row): shape (C, H)
            self.weight = nn.Parameter(torch.ones((bchw[1], bchw[2])))
            self.bias = nn.Parameter(torch.zeros((bchw[1], bchw[2])))

    def forward(self, x):
        # expects: B x C x H x W
        # B: batch size
        # C: convolutional channels
        # H: rows 1...H
        # W: time dimension

        # mean and var of each row (along W); keepdim gives B x C x H x 1
        # (torch.var is unbiased by default; nn.BatchNorm2d normalizes with the biased variance)
        mu = torch.mean(x, dim=-1, keepdim=True)
        var = torch.var(x, dim=-1, keepdim=True)

        # normalize data
        xn = (x - mu) / torch.sqrt(var + self.eps)

        # linear transformation, broadcast as 1 x C x H x 1
        if self.affine:
            xn = xn * self.weight[None, :, :, None] + self.bias[None, :, :, None]

        return xn
```

Below is a short example run. Random samples t are drawn from a standard normal distribution and transformed to y = weight * t + bias. With affine=True, the CustomNorm layer should be able to recover weight and bias through its internal parameters self.weight and self.bias, which it does in this simple example.

```python
import torch

# shape: B x C x H x W
bchw = (10, 3, 5, 100)

# generate scaling and bias for computing y below
weight = torch.randint(1, 10, (3, 5))
bias = torch.randint(-10, 10, (3, 5))

model = CustomNorm(bchw=bchw, affine=True)

def loss_fn(y, pred):
    # mean squared error
    return torch.mean((y - pred) ** 2)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

losses = []
for i in range(2500):
    t = torch.randn(bchw)  # standard normal samples
    y = weight[None, :, :, None] * t + bias[None, :, :, None]
    pred = model(t)
    loss = loss_fn(y, pred)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

with torch.no_grad():
    print('weight:')
    print(weight)
    print(model.weight.round().int())
    print()
    print('bias:')
    print(bias)
    print(model.bias.round().int())
```

This prints:

```
weight:
tensor([[8, 5, 2, 6, 5],
        [3, 6, 4, 8, 8],
        [4, 2, 2, 7, 7]])
tensor([[8, 5, 2, 6, 5],
        [3, 6, 4, 8, 8],
        [4, 2, 2, 7, 7]], dtype=torch.int32)

bias:
tensor([[ 0, -3, -9, -6,  1],
        [ 3,  0,  0, -1,  6],
        [ 2, -8, -2,  6,  7]])
tensor([[ 0, -3, -9, -6,  1],
        [ 3,  0,  0, -1,  6],
        [ 2, -8, -2,  6,  7]], dtype=torch.int32)
```

So, it seems that the layer works fine.

Best, JZ

Based on your description, I think it is the right approach for your use case, but I would write a quick test using a predefined input tensor with known mean and std in each “row” and check whether the custom layer normalizes them as expected.
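A test along those lines could look like the sketch below. It builds an input whose rows have known, different means and standard deviations, then checks that after normalization every row has a mean close to 0 and a std close to 1. The layer is restated compactly (without the affine part) so the snippet runs on its own; the sizes and tolerances are arbitrary choices.

```python
import torch
import torch.nn as nn


class CustomNorm(nn.Module):
    # Compact restatement of the row-wise normalization above (affine part omitted)
    def __init__(self, eps=1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        mu = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True)  # unbiased, the torch.var default
        return (x - mu) / torch.sqrt(var + self.eps)


torch.manual_seed(0)
B, C, H, W = 4, 2, 3, 1000

# Give every (channel, row) its own known mean and standard deviation
row_mean = torch.linspace(-5.0, 5.0, C * H).reshape(1, C, H, 1)
row_std = torch.linspace(0.5, 3.0, C * H).reshape(1, C, H, 1)
x = torch.randn(B, C, H, W) * row_std + row_mean

# Sanity check: the input rows really carry the intended means
assert torch.allclose(x.mean(dim=(0, 3)), row_mean[0, :, :, 0], atol=0.3)

out = CustomNorm()(x)

# After normalization every row should have mean ~0 and std ~1
assert torch.allclose(out.mean(dim=-1), torch.zeros(B, C, H), atol=1e-4)
assert torch.allclose(out.std(dim=-1), torch.ones(B, C, H), atol=1e-3)
print("every row normalized to mean ~0, std ~1")
```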

I would guess they are important, since they are enabled by default (affine=True in the built-in batchnorm layers).
The idea behind the affine parameters is to give this layer the ability to “unnormalize” the data if needed using the trained parameters.
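To illustrate the “unnormalize” point: if the weight of a row equals that row's standard deviation and the bias equals its mean, the affine step exactly undoes the normalization. The sketch below sets these values by hand for a single sample (B=1), purely for illustration; nothing in training forces the parameters toward them, the optimizer only moves them there if that lowers the loss.

```python
import torch

torch.manual_seed(0)
eps = 1e-5
x = torch.randn(1, 2, 3, 50) * 4 + 7  # one sample: C=2, H=3, W=50

mu = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True)
xn = (x - mu) / torch.sqrt(var + eps)  # normalized rows

# Hand-set affine parameters of shape (C, H): the rows' own std and mean
weight = torch.sqrt(var + eps)[0, :, :, 0]
bias = mu[0, :, :, 0]
restored = xn * weight[None, :, :, None] + bias[None, :, :, None]

print(torch.allclose(restored, x, atol=1e-4))
```

With several samples in a batch, each sample's rows have their own statistics, so a single (C, H) weight and bias can undo the per-sample normalization only approximately.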