BatchNorm for different-sized samples in a batch

I use BatchNorm1d on batches which are padded to the max length of the samples in the batch. It dawned on me that batch norm isn’t fed a mask, so it has no way of knowing which timesteps are valid in each sequence. Wouldn’t this mess with batch norm? And more importantly, wouldn’t the results change if I change the batch size? Is there a way around this?
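For example, a quick sketch of what I mean (a made-up single-channel batch):

import torch

bn = torch.nn.BatchNorm1d(1)

# Two sequences padded to length 10; only the first 4 timesteps
# of the second sample are valid
x = torch.randn(2, 1, 10)
x[1, 0, 4:] = 0.

out = bn(x)  # batch stats are computed over the padded zeros as well
print(x.mean())          # mean over everything, pulled towards zero by the padding
print(x[x != 0].mean())  # mean over the valid timesteps only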


Hi @Dan_Erez! Did you find anything to solve the problem? I’m facing the same issue. Is there a solution in PyTorch similar to the TensorFlow one here?


@ptrblck Please, can you help me with this?
Basically, I have a tensor padded with zeros at the end. If I feed this into torch.nn.BatchNorm1d, the padded zeros will be included in the statistics as well. I also have a binary mask for the padded tensor. Is there something in PyTorch to tackle this?
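Concretely, my setup looks something like this (the shapes here are just an example):

import torch

# Batch of 3 sequences with 5 channels, zero-padded to length 12
x = torch.randn(3, 5, 12)
lengths = torch.tensor([12, 9, 6])
mask = torch.arange(12)[None, :] < lengths[:, None]  # [3, 12], True = valid
x = x * mask.unsqueeze(1)  # zero out the padded timesteps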

As far as I understand your use case, you are creating a batch (3-dimensional: [batch_size, channels, seq_len]) where some tensors were zero-padded in the last dimension.
Is that correct?
Now you would like to ignore the padded inputs in the batchnorm layer, i.e. they should not be taken into account for the running stats. Or what would the desired behavior be?


Thanks for the reply! 🙂
Yes, precisely. This is exactly what I need. If the zeros are taken into account, the statistics will be wrong, so they need to be ignored.

I’m not aware of any built-in method, so you might need to implement it manually.

Maybe you could use this manual example of the batch norm calculation as a starting point and change the mean and var calculations to use the mask:

import torch

# Create a dummy input, zero-padded at the end of the time axis
x = torch.randn(2, 3, 10)
x[0, 0, 5:] = 0.
x[0, 1, 6:] = 0.
x[0, 2, 7:] = 0.
x[1, 0, 8:] = 0.
x[1, 1, 9:] = 0.

# Use the mask for a manual calculation of the per-channel mean
# (assumes valid values are never exactly zero)
mask = x != 0
mask_mean = x.sum([0, 2]) / mask.float().sum([0, 2])

# Alternatively rescale the plain mean (which counts the padded zeros)
# by the ratio of total elements to valid elements per channel;
# this matches mask_mean and could be applied to BatchNorm1d.running_mean
mean = x.mean([0, 2])
rescaled_mean = mean * (x.size(0) * x.size(2)) / mask.float().sum([0, 2])

The second example would work if you would like to use the PyTorch batch norm implementation (e.g. for performance reasons) and “rescale” the running estimates.
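For example, a rough sketch of that idea, reusing x and mask from above (this simply overwrites the running estimates with the masked statistics and ignores momentum, so it only illustrates the principle):

bn = torch.nn.BatchNorm1d(3)
with torch.no_grad():
    valid = mask.float().sum([0, 2])  # valid elements per channel
    masked_mean = x.sum([0, 2]) / valid
    masked_var = ((x - masked_mean[None, :, None]) ** 2 * mask).sum([0, 2]) / valid
    bn.running_mean.copy_(masked_mean)  # overwrite the biased estimates
    bn.running_var.copy_(masked_var)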

Let me know if that helps.


Thanks for the reply. Can you explicitly show how you would operate on the batch norm parameters?

I wrote a solution to do this fast, explained as comments in the code. Let me know if you find any bugs.

    def masked_batchnorm1d_forward(x, mask, bn):
        """x is the input tensor of shape [batch_size, n_channels, time_length]
            mask is of shape [batch_size, 1, time_length]
            bn is a BatchNorm1d object
        """
        if not bn.training:
            return bn(x)

        # In each example of the batch, we can have a different number of masked elements
        # along the time axis. It would have to be represented as a jagged array.
        # However, notice that the batch and time axes are handled the same in BatchNorm1d.
        # This means, we can merge the time axis into the batch axis, and feed BatchNorm1d
        # with a tensor of shape [n_valid_timesteps_in_whole_batch, n_channels, 1],
        # as if the time axis had length 1.
        #
        # So the plan is:
        #  1. Move the time axis next to the batch axis to the second place
        #  2. Merge the batch and time axes using reshape
        #  3. Create a dummy time axis at the end with size 1.
        #  4. Select the valid time steps using the mask
        #  5. Apply BatchNorm1d to the valid time steps
        #  6. Scatter the resulting values to the corresponding positions
        #  7. Unmerge the batch and time axes
        #  8. Move the time axis to the end 

        n_feature_channels = x.shape[1]
        time_length = x.shape[2]
        reshaped = x.permute(0, 2, 1).reshape(-1, n_feature_channels, 1)
        reshaped_mask = mask.reshape(-1, 1, 1) > 0
        selected = torch.masked_select(reshaped, reshaped_mask).reshape(-1, n_feature_channels, 1)
        batchnormed = bn(selected)
        scattered = reshaped.masked_scatter(reshaped_mask, batchnormed)
        backshaped = scattered.reshape(-1, time_length, n_feature_channels).permute(0, 2, 1)
        return backshaped
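A quick usage sketch (the names and shapes below are illustrative, not from the post):

import torch

bn = torch.nn.BatchNorm1d(8)
x = torch.randn(4, 8, 20)
lengths = torch.tensor([20, 15, 12, 7])
# Build the [batch_size, 1, time_length] mask from the sequence lengths
mask = (torch.arange(20)[None, :] < lengths[:, None]).unsqueeze(1)
x = x * mask  # zero out the padded timesteps
out = masked_batchnorm1d_forward(x, mask, bn)
print(out.shape)  # torch.Size([4, 8, 20])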

BatchNorm1d has learnable parameters. Is there any problem with having different-sized batches?
Thank you