Is it OK to disable `amp` for BN by wrapping its forward function?

Hi guys,
I want to disable AMP (autocast) for a BatchNorm layer by wrapping its forward function, like this:

```python
keep_forward_float_(self.bn)
```

where `keep_forward_float_` is defined as:

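```python
import functools
from types import MethodType
import torch.nn as nn
from torch.cuda.amp import autocast

def keep_forward_float_(m):
    def float_forward(self, x, forward):
        assert isinstance(self, nn.Module)
        with autocast(enabled=False):  # turn autocast off so BN runs in float32
            return forward(x.float())

    # capture the original bound forward, then attach the wrapper as a method of m
    m.forward = MethodType(functools.partial(float_forward, forward=m.forward), m)
```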
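For reference, here is a torch-free sketch of the binding pattern used on the last line of `keep_forward_float_`, to show what the wrapper actually receives when it is called. `Layer` and `wrap_forward_` are hypothetical stand-ins for the BN module and `keep_forward_float_`; `float(x)` stands in for `x.float()`, and the autocast context is left out.

```python
import functools
from types import MethodType


class Layer:
    """Hypothetical stand-in for a module such as BatchNorm."""

    def forward(self, x):
        return ("orig", x)


def wrap_forward_(m, calls):
    """Same MethodType + functools.partial pattern as keep_forward_float_.

    `calls` is a hypothetical list used only to record what the wrapper sees.
    """

    def wrapped_forward(self, x, forward):
        # `self` is the instance bound by MethodType; `forward` is the
        # original bound method captured by functools.partial.
        calls.append((self is m, forward.__func__ is Layer.forward))
        return forward(float(x))  # stands in for x.float()

    # m.forward is read (original method) before it is reassigned,
    # so the wrapper calls the original forward, not itself.
    m.forward = MethodType(functools.partial(wrapped_forward, forward=m.forward), m)


layer = Layer()
calls = []
wrap_forward_(layer, calls)
print(layer.forward(3))  # ('orig', 3.0)
print(calls)             # [(True, True)]
```

So a call `layer.forward(x)` becomes `wrapped_forward(layer, x, forward=<original bound forward>)`, and there is no recursion because the original method is captured before the attribute is overwritten.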

Is this implementation correct, and is it a reasonable way to do this?

Any answers or guidance would be appreciated!