How do I apply Batch Normalization on a sequential data?

I am trying to train an NLP model that takes in the entire sequence at once instead of passing each time step individually, as this approach is faster afaik. So my data is of shape (seq_size, batch_size, length). This format works fine for all other layers, but I am facing a problem with BatchNorm1d. Normally, when I am using non-sequential data of shape (batch_size, length), I do the following:

self.linear = nn.Linear(length1, length2)
self.bn = nn.BatchNorm1d(length2)

and then in the forward function:

x = self.linear(x)
x = self.bn(x)

and this works fine. But when I use the same method for sequential data, it takes the second dimension, i.e. batch_size, as the feature dimension and then raises a RuntimeError. E.g., my batch size is 32 and the output length of the previous Linear layer is 64, so I have defined my batch norm as nn.BatchNorm1d(64). But this gives me the error RuntimeError: running_mean should contain 32 elements not 64 on the forward pass. How do I fix this? What is the correct method for this kind of problem?
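A minimal sketch that reproduces both situations from the question, assuming a sequence length of 10 and input size length1 = 16 (both made up for illustration). The 2D input works; the 3D (seq_size, batch_size, length2) input fails because BatchNorm1d reads a 3D tensor as (N, C, L) and therefore treats dim1 (the batch size, 32) as the number of features.

```python
import torch
import torch.nn as nn

# Illustrative sizes; seq_size and length1 are assumptions, not from the question.
seq_size, batch_size, length1, length2 = 10, 32, 16, 64
linear = nn.Linear(length1, length2)
bn = nn.BatchNorm1d(length2)

# Non-sequential case: (batch_size, length1) -> (batch_size, length2); this works.
out_2d = bn(linear(torch.randn(batch_size, length1)))
print(out_2d.shape)  # torch.Size([32, 64])

# Sequential case: (seq_size, batch_size, length1) -> (seq_size, batch_size, length2).
x = linear(torch.randn(seq_size, batch_size, length1))
raised = False
try:
    # A 3D input is read as (N, C, L), so C = batch_size = 32 instead of length2 = 64.
    bn(x)
except RuntimeError as e:
    raised = True
    print("RuntimeError:", e)
```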

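Try to permute your input so that the batch dimension is in dim0 and rerun the code.
nn.Linear and nn.BatchNorm*d expect the batch dimension in dim0, which might explain the error you are seeing. Note that they differ in where they expect the features: nn.Linear is applied to the last dimension, so [batch_size, seq_len, features] works for it, while nn.BatchNorm1d reads a 3D input as [batch_size, features, seq_len] and normalizes over dim1.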
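For BatchNorm1d specifically, a minimal sketch of the permute approach, assuming the (seq_size, batch_size, length) layout from the question with made-up sizes: move the features to dim1 before the layer, then permute back so the following layers still see the original layout.

```python
import torch
import torch.nn as nn

seq_size, batch_size, length = 10, 32, 64  # illustrative sizes
x = torch.randn(seq_size, batch_size, length)  # (seq, batch, features)
bn = nn.BatchNorm1d(length)

# BatchNorm1d reads 3D input as (N, C, L): move the features to dim1.
x_bcl = x.permute(1, 2, 0)        # (batch, features, seq)
y = bn(x_bcl).permute(2, 0, 1)    # back to (seq, batch, features)
print(y.shape)  # torch.Size([10, 32, 64])
```

In training mode each of the 64 features is normalized over all batch_size * seq_size positions.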


Will it give the same result as passing each time step individually and then combining? I.e., will permuting change any of the calculations that would have taken place if I passed the sequence one step at a time?

The linear layer will be applied to the additional input dimensions (the * in [batch_size, *, in_features]) as if you passed each slice along those dimensions through it separately.
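A quick check of this claim, as a sketch with illustrative sizes: one call on the whole (seq, batch, in_features) tensor matches applying the same nn.Linear to each time step and stacking the results, up to floating-point tolerance.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_size, batch_size, in_features, out_features = 10, 32, 16, 64  # illustrative sizes
x = torch.randn(seq_size, batch_size, in_features)
linear = nn.Linear(in_features, out_features)

# One call over the whole sequence: the layer acts on the last dim only.
y_all = linear(x)  # (seq, batch, out_features)

# Same layer applied one time step at a time, then stacked back along dim0.
y_steps = torch.stack([linear(x[t]) for t in range(seq_size)], dim=0)

print(torch.allclose(y_all, y_steps, atol=1e-5))  # True
```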

The batchnorm layer will work differently, since you are currently using the batch size as the feature dimension, which sounds wrong.

I think the simplest solution is to treat the sequence and the batch dimensions equally. So you could do:

x = self.bn(x.reshape(seq_size*batch_size, length)).reshape(x.shape)

From the perspective of self.bn, the batch size is seq_size*batch_size with this method.
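A sketch comparing the reshape trick with the permute approach and with normalizing each time step separately, using made-up sizes and training mode. The reshape and permute versions both compute per-feature statistics over all seq_size * batch_size positions, so they should agree; normalizing one time step at a time uses only batch_size samples per statistic, so in training mode it generally gives a different result.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_size, batch_size, length = 10, 32, 64  # illustrative sizes
x = torch.randn(seq_size, batch_size, length)

bn_reshape = nn.BatchNorm1d(length)
bn_permute = nn.BatchNorm1d(length)
bn_step = nn.BatchNorm1d(length)

# Reshape trick: fold the sequence dimension into the batch dimension.
y_reshape = bn_reshape(x.reshape(seq_size * batch_size, length)).reshape(x.shape)

# Permute approach: (seq, batch, feat) -> (batch, feat, seq) -> back.
y_permute = bn_permute(x.permute(1, 2, 0)).permute(2, 0, 1)

# Per-time-step normalization: statistics over batch_size samples only.
y_steps = torch.stack([bn_step(x[t]) for t in range(seq_size)], dim=0)

print(torch.allclose(y_reshape, y_permute, atol=1e-5))  # True
print(torch.allclose(y_reshape, y_steps, atol=1e-5))    # False
```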