For my variable-length input I have attempted to use a mask to keep padding from skewing the batch statistics.

x[~mask] = self.norm(x[~mask])

In this toy example, x has shape (batch, embedding) and the mask has shape (batch,); it is True where there is real data and False where there is padding. I have implemented this strategy and it seems extremely unstable. Is there something wrong with the way I'm going about this? Is there a better way to do a "partial" batch norm like this, where padding is ignored?

I'm not sure there is a better approach than manually implementing the (batch) norm layer and applying your own logic to compute the statistics.
For a tensor of shape [batch_size, channels, height, width], you could compute the mean and variance per channel using only the unmasked batch and spatial positions, then normalize with those stats.
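For the simpler (batch, embedding) case from the question, a manual masked batch norm might look like the sketch below. This is a hypothetical implementation, not a built-in layer: it computes the mean and variance only from the rows where the mask is True, keeps running statistics like `nn.BatchNorm1d`, and normalizes out-of-place rather than with an in-place masked assignment (in-place writes into the input can interact badly with autograd and may be part of the instability you're seeing). Padded rows still get normalized with the same stats, but their values are meaningless and should be masked out downstream anyway.

```python
import torch
import torch.nn as nn


class MaskedBatchNorm1d(nn.Module):
    """Sketch of a batch norm over (batch, embedding) inputs whose
    statistics are computed only from positions where mask is True."""

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x, mask):
        # x: (batch, embedding); mask: (batch,), True where real data is.
        if self.training:
            valid = x[mask]                          # (n_valid, embedding)
            mean = valid.mean(dim=0)
            var = valid.var(dim=0, unbiased=False)
            with torch.no_grad():
                # Update running stats from the valid rows only.
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
                self.running_var.mul_(1 - self.momentum).add_(self.momentum * var)
        else:
            mean, var = self.running_mean, self.running_var
        # Normalize every row out-of-place with the masked statistics.
        return (x - mean) / torch.sqrt(var + self.eps) * self.weight + self.bias
```

After a forward pass in training mode, the valid rows of the output have (approximately) zero mean and unit variance per feature, regardless of what the padded rows contain.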