BatchNorm in RNN Step


I pre-trained an RNN with BatchNorm on collected data. During training, the input had the shape {batchSize, seqLen, others}, so the num_features of the BatchNorm layer was seqLen (a fixed number). Now I want to transfer the RNN into A3C and train it online. In that case, the input for real-time action generation has the shape {1, 1, others}, so the num_features of the BatchNorm would have to be 1. The BatchNorm(num_features=seqLen) could still be used for offline network updates, though.

How should I deal with this? Could I just skip the BatchNorm layer in the real-time forward pass? Do I have to train the net step by step?

Thank you very much!

If you are dealing with a different number of features for the batchnorm layer, you could add a condition to your model, which would then pick the right batchnorm layer for the current input.
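A minimal sketch of that idea (the module, layer sizes, and dispatch rule are made up for illustration; it assumes inputs of shape {batchSize, seqLen, others} as in the question, with BatchNorm1d normalizing over the seqLen dimension):

```python
import torch
import torch.nn as nn

class ConditionalBNModel(nn.Module):
    """Hypothetical sketch: pick a BatchNorm layer based on the input's second dim."""
    def __init__(self, seq_len, hidden=8):
        super().__init__()
        self.bn_offline = nn.BatchNorm1d(seq_len)  # for offline updates: {batchSize, seqLen, others}
        self.bn_online = nn.BatchNorm1d(1)         # for real-time inputs: {1, 1, others}
        self.rnn = nn.GRU(input_size=hidden, hidden_size=hidden, batch_first=True)

    def forward(self, x):
        # condition: choose the batchnorm whose num_features matches the current input
        bn = self.bn_offline if x.size(1) == self.bn_offline.num_features else self.bn_online
        x = bn(x)
        out, _ = self.rnn(x)
        return out

model = ConditionalBNModel(seq_len=5)
model.eval()  # use running stats for the real-time single-sample forward
offline = model(torch.randn(4, 5, 8))  # picks bn_offline
online = model(torch.randn(1, 1, 8))   # picks bn_online
```

Note that the two batchnorm layers would collect separate running statistics, which may or may not be what you want.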

I’m not familiar with your use case, but maybe it would also make sense to slice and copy the batchnorm layer’s parameters and buffers from the seqLen use case to the single-feature use case?

@ptrblck “slice and copy batchnorm layer parameters and buffers” works.

Is there any official or widely-used way to “slice and copy the batchnorm layer” instead of manipulating the weight/bias tensors directly?

Thank you very much!

You could use something like:

small_bn = nn.BatchNorm2d(1)
with torch.no_grad():
    small_bn.weight = nn.Parameter(large_bn.weight[0:1]) # or take the mean, slice at another pos?
    small_bn.bias = nn.Parameter(large_bn.bias[0:1])
    small_bn.running_mean.copy_(large_bn.running_mean[0:1])
    small_bn.running_var.copy_(large_bn.running_var[0:1])

I don’t know if that would make sense at all for your use case, but you might consider it. :slight_smile:
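For reference, a self-contained version of that slice-and-copy (layer sizes and the pretraining stand-in are made up), checking in eval mode that the single-feature layer reproduces the first channel of the large one:

```python
import torch
import torch.nn as nn

large_bn = nn.BatchNorm2d(5)
large_bn(torch.randn(16, 5, 4, 4))  # one train-mode forward to populate running stats (stand-in for pretraining)
large_bn.eval()

small_bn = nn.BatchNorm2d(1)
with torch.no_grad():
    small_bn.weight = nn.Parameter(large_bn.weight[0:1])
    small_bn.bias = nn.Parameter(large_bn.bias[0:1])
    small_bn.running_mean.copy_(large_bn.running_mean[0:1])
    small_bn.running_var.copy_(large_bn.running_var[0:1])
small_bn.eval()

x = torch.randn(1, 5, 4, 4)
full = large_bn(x)
single = small_bn(x[:, 0:1])
print(torch.allclose(full[:, 0:1], single, atol=1e-6))  # channel 0 is normalized identically
```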