How to switch mixed-precision mode in training

Is it possible to dynamically change the amp_mode of some layers of a network during training?
I use the following method to set an amp_mode attribute on all Linear and BatchNorm layers:

import torch
from torch.cuda.amp import autocast
from spconv.pytorch import SparseConvTensor  # spconv v2

class MyDecorator:
    def __init__(self, func):
        self.func = func
        self.amp_mode = True  # toggled from the training loop

    def _to_float(self, x):
        # Cast possibly-FP16 inputs back to FP32 before the wrapped forward.
        if isinstance(x, SparseConvTensor):
            x = x.replace_feature(x.features.float())
        elif isinstance(x, torch.Tensor):
            x = x.float()
        return x

    def __call__(self, *args, **kwargs):
        if self.amp_mode:
            with autocast(enabled=True, dtype=torch.float16):
                return self.func(*args, **kwargs)
        else:
            with autocast(enabled=False):
                args = [self._to_float(x) for x in args]
                kwargs = {k: self._to_float(v) for k, v in kwargs.items()}
                return self.func(*args, **kwargs)

def create_forward_hook(name):
    def hook(module, input, output):
        print(f'{name}, is_amp?: {module.forward.amp_mode}, '
              f'input_dtype: {input[0].dtype}, output_dtype: {output.dtype}')
    return hook

for name, module in model.named_modules():
    if isinstance(module, (nn.Linear, nn.BatchNorm1d)):
        module.forward = MyDecorator(module.forward)
        module.register_forward_hook(create_forward_hook(name))  # dtype logging only
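The wrap-and-toggle pattern itself can be checked without GPUs. Below is a minimal sketch of the same idea, using `torch.autocast` with `device_type="cpu"` and bfloat16 in place of the CUDA/FP16 setup above (the class name `PrecisionToggle` is just for illustration):

```python
import torch
import torch.nn as nn

class PrecisionToggle:
    """Wraps a module's forward; runs it under autocast or in FP32."""
    def __init__(self, func):
        self.func = func
        self.amp_mode = True

    def __call__(self, *args, **kwargs):
        if self.amp_mode:
            with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                return self.func(*args, **kwargs)
        else:
            with torch.autocast(device_type="cpu", enabled=False):
                # Cast any reduced-precision inputs back up to FP32.
                args = [a.float() if isinstance(a, torch.Tensor) else a
                        for a in args]
                return self.func(*args, **kwargs)

layer = nn.Linear(4, 4)
layer.forward = PrecisionToggle(layer.forward)  # instance attr shadows the method

x = torch.randn(2, 4)
print(layer(x).dtype)           # bfloat16 under autocast
layer.forward.amp_mode = False
print(layer(x).dtype)           # float32 after toggling
```

Assigning to `layer.forward` works because `nn.Module.__call__` looks up `self.forward`, and an instance attribute shadows the class method.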

When training begins, all of these layers use AMP. During training, when a layer produces NaN values, I set forward.amp_mode = False for it. But it does not seem to actually compute in FP32, because that layer produces NaNs again!
I am training my network on 8 A100 GPUs. What else should I do besides these settings? Can you help me? Thanks!


I don’t think using forward hooks will work, as the output is already computed by the time the hook runs. You might want to disable amp manually in the forward method of your model instead.
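A sketch of that suggestion: a nested `autocast(enabled=False)` region inside `forward` forces everything in it back to the default dtype, and an input arriving from the surrounding FP16 region still needs an explicit `.float()`. The module name `MyBlock` and flag `run_in_fp32` are made up here, and `device_type="cpu"` is used so the snippet runs anywhere; use `"cuda"` in the actual setup:

```python
import torch
import torch.nn as nn

class MyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)
        self.run_in_fp32 = False  # flip this when NaNs show up

    def forward(self, x):
        if self.run_in_fp32:
            # Locally disable autocast and cast the (possibly FP16) input up.
            with torch.autocast(device_type="cpu", enabled=False):
                return self.fc(x.float())
        return self.fc(x)
```

Inside an enclosing autocast region the block then yields reduced precision normally, and full FP32 once the flag is set.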

Oh, the forward hook is only used to print the dtypes of the inputs and outputs; it is a small part of my project. What actually switches precision is wrapping each module’s forward function in MyDecorator.
I set amp_mode to False when a layer outputs a non-finite value, but it does not seem to work, because that layer outputs NaN again.
So how can I dynamically set the AMP mode during training? Do I need to do something else for multi-GPU training?
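One thing worth checking in the 8-GPU case (an illustrative sketch, not something from this thread): if each DDP rank flips amp_mode based only on its own activations, the ranks end up running different precision paths for the same layer. A small helper (the name `sync_amp_decision` is hypothetical) can max-reduce the "saw a NaN" flag so every rank makes the same switch; without an initialized process group it just falls back to the local decision:

```python
import torch
import torch.distributed as dist

def sync_amp_decision(found_nan: bool) -> bool:
    # If any rank saw a non-finite value, every rank should move that
    # layer to FP32, so the flag is max-reduced across all ranks.
    flag = torch.tensor([1.0 if found_nan else 0.0])
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return bool(flag.item())

# Per wrapped layer, after its output is available:
#   found_nan = not torch.isfinite(output).all()
#   layer.forward.amp_mode = not sync_amp_decision(found_nan)
```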