Hi guys,
I want to disable AMP (autocast) for a BatchNorm layer by wrapping its forward function, like this:
```python
keep_forward_float_(self.bn)
```
where `keep_forward_float_` is defined as:
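```python
import functools
from types import MethodType
import torch.nn as nn
from torch.cuda.amp import autocast

def keep_forward_float_(m):
    def float_forward(self, x, forward):
        assert isinstance(self, nn.Module)
        # Disable autocast and upcast the (possibly fp16) input to float32
        with autocast(enabled=False):
            return forward(x.float())
    # `forward` is m's original bound method; the wrapper replaces it on this instance only
    m.forward = MethodType(functools.partial(float_forward, forward=m.forward), m)
```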
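To sanity-check the patching mechanics without a GPU, the same `MethodType` + `functools.partial` trick can be replayed on a plain Python class. This is a minimal sketch under assumptions: `Layer` is a hypothetical stand-in for the BN module, `float(x)` stands in for `x.float()`, and autocast is left out entirely. It suggests the patch does route calls through the wrapper, but it also shows one pitfall worth checking: after `copy.deepcopy` (e.g. for an EMA copy of the model), the copy's patched `forward` still calls the *original* instance's bound `forward`. That happens because `deepcopy` rebinds the method's `self` but does not copy the `partial` that holds the old bound method. `keep_forward_float_v2_` is a hypothetical variant that avoids this by resolving `type(self).forward` at call time.

```python
import copy
import functools
from types import MethodType


class Layer:
    """Hypothetical stand-in for an nn.Module such as nn.BatchNorm2d."""

    def __init__(self, scale):
        self.scale = scale

    def forward(self, x):
        return x * self.scale

    def __call__(self, x):
        # Like nn.Module.__call__, dispatch through `self.forward`, so an
        # instance attribute named `forward` shadows the class method.
        return self.forward(x)


def keep_forward_float_(m):
    """The patch from the question; float(x) stands in for x.float()."""

    def float_forward(self, x, forward):
        # `forward` is the ORIGINAL bound method of `m`, captured at patch time.
        return forward(float(x))

    m.forward = MethodType(functools.partial(float_forward, forward=m.forward), m)
    return m


def keep_forward_float_v2_(m):
    """Hypothetical variant: look up the class's forward at call time."""

    def float_forward(self, x):
        # Resolving via type(self) means a copied module calls its own forward.
        return type(self).forward(self, float(x))

    m.forward = MethodType(float_forward, m)
    return m


for patch in (keep_forward_float_, keep_forward_float_v2_):
    layer = patch(Layer(2))
    clone = copy.deepcopy(layer)
    clone.scale = 10  # change only the copy
    print(patch.__name__, layer(3), clone(3))
```

With the original patch, `clone(3)` still returns `6.0` (computed with the original layer's `scale`), while the v2 variant returns `30.0`. The trade-off is that v2 always calls the class's `forward`, so it would bypass any other per-instance patch applied earlier; wrapping the BN layer in a small `nn.Module` that disables autocast is another way to sidestep both issues.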
Is this implementation correct, or is there a better way to do it?
Any answers or guidance would be appreciated!