"Masked" Batchnorm for variable length input?

For my variable-length input, I have tried using a mask to keep the padding from affecting the batch statistics:

x[mask] = self.norm(x[mask])

In this toy example, x has shape (batch, embedding) and mask has shape (batch,); mask is True where there is real data and False where there is padding. I have implemented this strategy, but it seems extremely unstable. Is there something wrong with the way I'm going about this? Is there a better way to do a "partial" batch norm like this, where the padding is ignored?
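Below is a self-contained sketch of the pattern described above. It assumes mask is True for real rows, so the real rows are selected with x[mask]. MaskedBatchNorm1d is a hypothetical name, not a PyTorch class; it only wraps the standard nn.BatchNorm1d.

```python
import torch
import torch.nn as nn


class MaskedBatchNorm1d(nn.Module):
    """Hypothetical wrapper: normalizes only the rows selected by the mask."""

    def __init__(self, num_features: int):
        super().__init__()
        self.norm = nn.BatchNorm1d(num_features)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # x: (batch, embedding); mask: (batch,), True where the row is real data
        out = x.clone()  # leave the caller's tensor untouched
        # x[mask] keeps the feature dimension: (num_real_rows, embedding),
        # so BatchNorm1d computes per-feature statistics over real rows only
        out[mask] = self.norm(x[mask])
        return out  # padding rows are passed through unchanged
```

In training mode nn.BatchNorm1d raises an error when a batch contains only one real row, and statistics estimated from a handful of rows are noisy, so batches with very few real rows are worth checking.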

When the mask has the same shape as the tensor, the masked tensor is flattened to 1D, so I'm unsure how your normalization layer would work with this type of input:

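import torch

x = torch.randn(10, 10)
mask = torch.randint(0, 2, (10, 10)).bool()  # random mask with the same shape as x
print(x[mask].shape)  # 1D; the length varies with the random mask
> torch.Size([49])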
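To make the shape difference concrete for variable-length sequences, here is a small sketch. The (batch, seq_len, embedding) layout and the lengths tensor are assumptions for illustration. An element-wise mask flattens the selection, while a per-token mask keeps the embedding dimension, so the selected values can still be fed to a layer such as nn.BatchNorm1d(embedding).

```python
import torch

batch, seq_len, embed = 4, 5, 8
x = torch.randn(batch, seq_len, embed)

# Element-wise mask with the same shape as x: the selection is flattened to 1D
elem_mask = torch.rand(batch, seq_len, embed) > 0.5
flat = x[elem_mask]
print(flat.dim())  # 1

# Per-token mask built from assumed sequence lengths: True = real token
lengths = torch.tensor([5, 3, 2, 4])
tok_mask = torch.arange(seq_len)[None, :] < lengths[:, None]  # (batch, seq_len)
tokens = x[tok_mask]
print(tokens.shape)  # torch.Size([14, 8]): 5 + 3 + 2 + 4 real tokens
```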

That would explain why it wasn't working!

Is there an easy alternative? I'm trying to do standard normalization; I just want the padding excluded from the batch statistics.

I'm not sure there is a better approach than manually implementing the (batch)norm layer and applying your custom logic to calculate the stats.
For a tensor of shape [batch_size, channels, height, width], you could split the tensor along the channel dimension, apply the mask to the spatial dimensions of each channel separately, and calculate the stats from the remaining values.
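A minimal sketch of that idea, assuming the mask has shape [batch_size, height, width] with True marking real positions. masked_batch_norm2d is a hypothetical helper, not a PyTorch API. Instead of looping over channels, it sums over the batch and spatial dimensions for every channel at once, which yields the same per-channel statistics as splitting the tensor channel by channel. Running statistics for eval mode are omitted here.

```python
import torch
import torch.nn.functional as F


def masked_batch_norm2d(x, mask, weight=None, bias=None, eps=1e-5):
    """Batch norm whose per-channel statistics ignore padded positions.

    x:    (batch_size, channels, height, width)
    mask: (batch_size, height, width), True where the data is real (assumed layout)
    """
    mask4 = mask.unsqueeze(1)  # (B, 1, H, W), broadcasts over the channel dimension
    # Number of real positions (the same for every channel); avoid division by zero
    count = mask4.to(x.dtype).sum(dim=(0, 2, 3), keepdim=True).clamp(min=1)

    # Per-channel mean over real positions only (padding is zeroed before summing)
    mean = x.masked_fill(~mask4, 0.0).sum(dim=(0, 2, 3), keepdim=True) / count
    # Biased per-channel variance over real positions, as batch norm uses in training
    centered = (x - mean).masked_fill(~mask4, 0.0)
    var = (centered ** 2).sum(dim=(0, 2, 3), keepdim=True) / count

    y = (x - mean) / torch.sqrt(var + eps)
    if weight is not None:
        y = y * weight.view(1, -1, 1, 1)
    if bias is not None:
        y = y + bias.view(1, -1, 1, 1)
    # Keep padded positions unchanged
    return torch.where(mask4, y, x)


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(4, 3, 5, 5)
    full = torch.ones(4, 5, 5, dtype=torch.bool)
    # With nothing masked, the result should match standard training-mode batch norm
    ref = F.batch_norm(x, None, None, training=True, eps=1e-5)
    print(torch.allclose(masked_batch_norm2d(x, full), ref, atol=1e-5))
```

Zeroing the padded values before each sum (rather than multiplying by the mask) keeps the statistics unaffected even if the padding holds very large values.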