Batchnorm for iterative or recurrent network

I am trying to design a network which contains a recurrent sub-network. The input of this sub-network changes over iterations (because its current input depends on its last output). However, I find the network works much worse in evaluation mode than in training mode. Through debugging, I found it is a problem with the batchnorm statistics, i.e. running_mean and running_var. As stated in the recent paper "Recurrent Batch Normalization" (https://arxiv.org/pdf/1603.09025.pdf), the batchnorm statistics should be computed separately for each iteration. I have also seen one implementation of this paper, https://github.com/jihunchoi/recurrent-batch-normalization-pytorch/blob/master/bnlstm.py, but I don't think it works, since it doesn't implement the important part: separate batch statistics for different iterations. Does someone know how to implement this, or other tricks that have an equivalent effect?
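For reference, here is a minimal toy sketch of what I mean (shapes and sizes are made up for illustration): if a single BatchNorm2d is reused at every iteration, its running statistics end up blending all iterations together, which is exactly what breaks evaluation mode for me.

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(8)
bn.train()
for step in range(3):
    # Toy inputs whose scale grows with the iteration index, mimicking a
    # recurrent sub-network whose current input depends on its last output.
    x = torch.randn(4, 8, 16, 16) * (step + 1)
    bn(x)

# A single running_var is accumulated over all three iterations, so in
# eval mode every iteration gets normalized with the same blended statistics.
print(bn.running_var.mean())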

Some people may suggest creating multiple BN modules and using the first one for the first iteration, the second one for the second iteration, and so on. The problem is that, as pointed out in the paper mentioned above, the learnable parameters gamma and beta should be shared between these BN modules; only the running_mean and running_var should differ between iterations. If we do it this way, how do we force the modules to share the learnable weights gamma and beta?

You can create the modules and then manually set the gamma and beta of all the modules to the ones from just one of them (in PyTorch's BatchNorm the learnable gamma and beta are the weight and bias attributes).
For example:

a = nn.BatchNorm2d(...)
b = nn.BatchNorm2d(...)
# weight is gamma and bias is beta; assigning the same Parameter ties them.
a.weight = b.weight
a.bias = b.bias
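If you want to wrap this up, here is a rough sketch of one way to do it (the module name PerStepBatchNorm2d and the max_steps argument are made up for illustration): it keeps one BatchNorm2d per iteration, so each has its own running_mean and running_var, and ties the weight (gamma) and bias (beta) of all of them to the first, so the learnable affine parameters are shared.

import torch
import torch.nn as nn

class PerStepBatchNorm2d(nn.Module):
    """One BatchNorm2d per iteration: separate running stats, shared gamma/beta."""

    def __init__(self, num_features, max_steps):
        super().__init__()
        self.bns = nn.ModuleList(
            nn.BatchNorm2d(num_features) for _ in range(max_steps)
        )
        # Tie the learnable affine parameters: weight is gamma, bias is beta.
        for bn in self.bns[1:]:
            bn.weight = self.bns[0].weight
            bn.bias = self.bns[0].bias

    def forward(self, x, step):
        # Pick the BN whose running_mean/running_var belong to this iteration.
        # If the network runs for more steps than max_steps, reuse the last one.
        step = min(step, len(self.bns) - 1)
        return self.bns[step](x)

# Usage inside a recurrent loop (shapes are just an example):
bn = PerStepBatchNorm2d(num_features=8, max_steps=5)
x = torch.randn(4, 8, 16, 16)
for step in range(5):
    x = bn(x, step)

Note that parameters() deduplicates shared Parameter objects, so the tied gamma and beta will only show up once in the optimizer.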