For my variable-length input, I have tried using a mask to keep the padding from skewing the batch statistics:
x[mask] = self.norm(x[mask])  # mask is True on real rows, so index with mask rather than ~mask
In this toy example, x has shape (batch, embedding) and the mask has shape (batch), True where there is real data and False where there is padding. I have implemented this strategy and it seems extremely unstable. Is there something wrong with the way I’m going about this? Is there a better way to do a “partial” batch norm like this, where padding is ignored?
Indexing a tensor with a boolean mask of the same shape returns a flattened (1-D) tensor, so I’m unsure how your normalization layer would work with this type of input:
import torch

x = torch.randn(10, 10)
mask = torch.randint(0, 2, (10, 10)).bool()  # random element-wise mask, same shape as x
print(x[mask].shape)  # 1-D: one entry per True element of mask
> torch.Size([49])
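The flattening above comes from an element-wise mask with the same shape as x. A row mask of shape (batch,), as described in the question, selects whole rows and keeps the embedding dimension, so the selection can be passed to nn.BatchNorm1d. A minimal sketch with made-up sizes (batch=6, embedding=4), updating a copy of x out of place so the padding rows stay untouched:

```python
import torch
import torch.nn as nn

batch, embedding = 6, 4
x = torch.randn(batch, embedding)
# Row mask: True for real rows, False for padding rows (as in the question)
mask = torch.tensor([True, True, False, True, False, True])

print(x[mask].shape)  # torch.Size([4, 4]): the embedding dimension is kept

norm = nn.BatchNorm1d(embedding)
# Normalize only the real rows; padding rows in `out` keep their original values
out = x.clone()
out[mask] = norm(x[mask])
```

In training mode the batch statistics then come only from the real rows, so with few real rows per batch these estimates can be noisy.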
That would explain why it wasn’t working!
Is there an easy alternative? I’m trying to do standard normalization; I just want the padding excluded from the batch statistics.
I’m not sure there is a better approach than manually implementing the batchnorm layer and applying your custom logic to calculate the stats.
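A minimal sketch of such a manually implemented layer, assuming the features are the last dimension of x and the mask has shape x.shape[:-1] (True on real data), which covers both the (batch, embedding) case and a (batch, seq_len, embedding) case. MaskedBatchNorm1d is a hypothetical name, not a PyTorch class; it mirrors the nn.BatchNorm1d defaults (eps, momentum, affine weight/bias, running variance updated with the unbiased estimate):

```python
import torch
import torch.nn as nn


class MaskedBatchNorm1d(nn.Module):
    """Hypothetical batch norm over the last (feature) dimension that
    computes batch statistics only from positions where mask is True."""

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x, mask):
        # x: (..., num_features); mask: x.shape[:-1], True on real data
        pad = ~mask.unsqueeze(-1)                      # (..., 1), True on padding
        m = mask.unsqueeze(-1).to(x.dtype)             # (..., 1), 1.0 on real data
        if self.training:
            dims = tuple(range(x.dim() - 1))           # every dim except features
            count = m.sum(dim=dims).clamp(min=1.0)     # number of real positions
            xm = x.masked_fill(pad, 0.0)               # padding contributes nothing
            mean = xm.sum(dim=dims) / count
            var = (((xm - mean) ** 2) * m).sum(dim=dims) / count  # biased variance
            with torch.no_grad():
                # Running variance uses the unbiased estimate, like nn.BatchNorm1d
                unbiased = var * count / (count - 1).clamp(min=1.0)
                self.running_mean.lerp_(mean, self.momentum)
                self.running_var.lerp_(unbiased, self.momentum)
        else:
            mean, var = self.running_mean, self.running_var
        y = (x - mean) / torch.sqrt(var + self.eps) * self.weight + self.bias
        # Design choice: set padded outputs to zero
        return y.masked_fill(pad, 0.0)
```

Because the statistics are computed only over real positions, changing the padding values does not change the output at real positions. Zeroing the padded outputs is one choice; passing them through unchanged would also work.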
For a tensor of shape [batch_size, channels, height, width], you could split the tensor along the channel dimension, mask the spatial dimensions separately, and calculate the stats from it.
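A minimal sketch of that per-channel approach, assuming the mask has shape [batch_size, height, width] (True on real data) and omitting the affine parameters and running statistics for brevity. The function names masked_channel_stats and masked_batch_norm2d are hypothetical:

```python
import torch


def masked_channel_stats(x, mask):
    """Per-channel mean and biased variance of x, using only masked positions.

    x:    (batch_size, channels, height, width)
    mask: (batch_size, height, width), True where the data is real
    """
    means, variances = [], []
    for c in range(x.shape[1]):                 # split along the channel dim
        values = x[:, c][mask]                  # 1-D: real positions of channel c
        mean = values.mean()
        means.append(mean)
        variances.append(((values - mean) ** 2).mean())
    return torch.stack(means), torch.stack(variances)


def masked_batch_norm2d(x, mask, eps=1e-5):
    mean, var = masked_channel_stats(x, mask)
    # Reshape to (1, C, 1, 1) so the stats broadcast over batch and space
    x_hat = (x - mean.view(1, -1, 1, 1)) / torch.sqrt(var.view(1, -1, 1, 1) + eps)
    return x_hat.masked_fill(~mask.unsqueeze(1), 0.0)  # zero out padded positions
```

The Python loop follows the split-by-channel description literally; the same statistics could also be computed in one vectorized pass with mask-weighted sums. With an all-True mask the result should match nn.BatchNorm2d in training mode with its default affine parameters.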