How to implement a split batchnorm in PyTorch?

I read the following discussion on updating parameters manually: https://discuss.pytorch.org/t/how-does-one-make-sure-that-the-parameters-are-update-manually-in-pytorch-using-modules/6076

Can we implement split batchnorm without replacing the nn.Sequential() container itself, i.e. by only swapping out the batchnorm layers inside it?
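
For concreteness, here is the kind of thing I have in mind — a minimal sketch, assuming "split batchnorm" means computing batchnorm statistics over `num_splits` sub-batches independently during training (sometimes called ghost batch norm). The names `SplitBatchNorm2d` and `swap_batchnorms` are mine, just for illustration; I'm not sure this is the idiomatic approach:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SplitBatchNorm2d(nn.BatchNorm2d):
    """BatchNorm2d that computes statistics over `num_splits` sub-batches
    independently during training. Assumes the default affine=True and
    track_running_stats=True."""

    def __init__(self, num_features, num_splits=2, **kwargs):
        super().__init__(num_features, **kwargs)
        self.num_splits = num_splits

    def forward(self, x):
        n, c, h, w = x.shape
        if self.training and n % self.num_splits == 0:
            # Fold the splits into the channel dimension so a single
            # batch_norm call computes per-split statistics.
            mean = self.running_mean.repeat(self.num_splits)
            var = self.running_var.repeat(self.num_splits)
            out = F.batch_norm(
                x.view(n // self.num_splits, c * self.num_splits, h, w),
                mean, var,
                self.weight.repeat(self.num_splits),
                self.bias.repeat(self.num_splits),
                True, self.momentum, self.eps,
            ).view(n, c, h, w)
            # Average the per-split running statistics back into the
            # shared buffers.
            with torch.no_grad():
                self.running_mean.copy_(mean.view(self.num_splits, c).mean(0))
                self.running_var.copy_(var.view(self.num_splits, c).mean(0))
            return out
        # Eval mode (or an indivisible batch): plain batch norm using the
        # aggregated running statistics.
        return F.batch_norm(
            x, self.running_mean, self.running_var,
            self.weight, self.bias,
            self.training, self.momentum, self.eps,
        )


def swap_batchnorms(module, num_splits=2):
    """Replace every nn.BatchNorm2d child with SplitBatchNorm2d in place,
    leaving containers such as nn.Sequential untouched."""
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d) and not isinstance(child, SplitBatchNorm2d):
            bn = SplitBatchNorm2d(child.num_features, num_splits,
                                  eps=child.eps, momentum=child.momentum)
            bn.load_state_dict(child.state_dict())  # keep trained params/stats
            setattr(module, name, bn)
        else:
            swap_batchnorms(child, num_splits)
    return module


model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())
swap_batchnorms(model, num_splits=2)  # the Sequential itself is untouched
```

The idea is that `setattr` only rebinds the batchnorm children in `module._modules`, so the surrounding nn.Sequential (and its ordering) is preserved. Is this in-place swap the recommended way, or is there a cleaner built-in mechanism?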