How to extend BatchNorm2d?

I want to write a custom BatchNorm2d layer. Is it enough to inherit from nn.BatchNorm2d and reimplement the forward() function? Something like the following:

import torch.nn as nn
class CustomBatchNorm2d(nn.BatchNorm2d):
    def forward(self, x):
        ...  # custom forward logic goes here
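For concreteness, a minimal sketch of such an override might reimplement the standard computation by hand. It relies only on the attributes that nn.BatchNorm2d's constructor already creates (weight, bias, running_mean, running_var, num_batches_tracked, eps, momentum). The class name ManualBatchNorm2d is hypothetical. The sketch assumes the defaults affine=True and track_running_stats=True and a non-None momentum.

```python
import torch
import torch.nn as nn


class ManualBatchNorm2d(nn.BatchNorm2d):
    """Hypothetical subclass: reuses the parameters and buffers created by
    nn.BatchNorm2d.__init__ and only overrides forward()."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumes affine=True, track_running_stats=True, momentum not None.
        if self.training:
            # Per-channel statistics over the batch and spatial dims (N, H, W).
            mean = x.mean(dim=(0, 2, 3))
            var = ((x - mean[None, :, None, None]) ** 2).mean(dim=(0, 2, 3))
            n = x.numel() / x.size(1)  # elements per channel
            with torch.no_grad():
                # Exponential moving averages; the running variance stores
                # the unbiased estimate, while normalization uses the biased one.
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
                self.running_var.mul_(1 - self.momentum).add_(
                    self.momentum * var * n / (n - 1)
                )
                self.num_batches_tracked += 1
        else:
            # Evaluation mode: use the stored running statistics.
            mean = self.running_mean
            var = self.running_var
        # Normalize, then apply the learnable per-channel scale and shift.
        x_hat = (x - mean[None, :, None, None]) / torch.sqrt(
            var[None, :, None, None] + self.eps
        )
        return x_hat * self.weight[None, :, None, None] + self.bias[None, :, None, None]
```

Because the subclass keeps the parent's parameters and buffers, its state_dict has the same keys as a plain nn.BatchNorm2d, and .train()/.eval() switch between batch and running statistics as usual. Note that this override skips the input-dimension check the built-in forward performs.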