amp.GradScaler().step(optimizer) is taking too much time!

So I created a class, TrainEpoch, to train a single epoch, in which I defined two methods for the forward and backward passes as follows:

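```python
class TrainEpoch:
    def Forward(self, x, y):
        self.pre_logits = self.model(x)
        # send logits to device (_s is my own helper, defined elsewhere)
        self.pre_logits = _s(self.pre_logits, self.device)
        self.curr_loss = self.loss_func(self.pre_logits, y)

    def Backward(self):
        # assumed standard full-precision backward pass (body not shown originally)
        self.optimizer.zero_grad()
        self.curr_loss.backward()
        self.optimizer.step()
```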

This class works fine. I then created a subclass, TrainEpochAMP, to use automatic mixed precision (AMP) training, and defined its forward and backward passes as follows:

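```python
import torch.cuda.amp as AMP

scaler = AMP.GradScaler()

class TrainEpochAMP(TrainEpoch):
    # `self` is not defined in a class body, so store the shared scaler as a
    # class attribute (it resolves to the module-level GradScaler above);
    # methods can then reach it as self.scaler
    scaler = scaler

    def Forward(self, x, y):
        # the forward pass and the loss must run under autocast,
        # otherwise no mixed precision is actually used
        with AMP.autocast():
            super().Forward(x, y)

    def Backward(self):
        # assumed standard GradScaler pattern (body not shown originally)
        self.optimizer.zero_grad()
        self.scaler.scale(self.curr_loss).backward()
        self.scaler.step(self.optimizer)  # the call that takes ~17 s
        self.scaler.update()
```
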
When I use AMP training, one epoch takes 22 hours, compared with only 30 minutes without AMP.
After some inspection, I found that self.scaler.step(self.optimizer) alone needs ~17 seconds to execute!
What could be the problem?
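One thing worth ruling out first (a hypothesis, not a confirmed diagnosis): CUDA kernels are launched asynchronously, and current PyTorch implementations of GradScaler.step() read the inf/NaN check result back to the host with .item(), which blocks until all previously queued GPU work has finished. A plain wall-clock timer around scaler.step() can therefore include the GPU time of the whole forward and backward pass. Below is a sketch of a synchronized timer; the `timed` helper is hypothetical, not a PyTorch API.

```python
import time

import torch


def timed(fn, *args, **kwargs):
    """Hypothetical helper: call fn(*args, **kwargs) and return
    (result, elapsed_seconds). CUDA is synchronized before and after the
    call so that only the GPU work launched by fn is counted."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # finish work queued by earlier calls
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for the kernels launched by fn
    return result, time.perf_counter() - start
```

Timing the forward pass, the backward pass, and the step separately (e.g. `_, t = timed(self.scaler.step, self.optimizer)`) should show where the time is really spent; torch.profiler would give a more detailed per-kernel breakdown.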

The GradScaler shouldn’t add this massive overhead: it only unscales the gradients, checks them for invalid values (Infs/NaNs), and skips optimizer.step() if any are found. Could you post a minimal code snippet that reproduces this issue, so that we can debug it, please?
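A minimal reproduction along these lines could look like the sketch below. The MLP, the batch size, and the `benchmark` function are placeholder assumptions, not the original poster's setup. It times the same number of training steps with and without AMP, synchronizing CUDA around the timed loop. Without a GPU it falls back to full precision on the CPU, so the comparison is only meaningful on CUDA.

```python
import time

import torch
from torch import nn


def benchmark(use_amp, steps=20, batch=64, features=512, classes=10):
    """Hypothetical repro: average seconds per training step of a small MLP,
    with or without AMP. Returns (seconds_per_step, last_loss)."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    amp_on = use_amp and device == "cuda"  # float16 autocast + GradScaler need CUDA here
    torch.manual_seed(0)
    model = nn.Sequential(
        nn.Linear(features, 1024), nn.ReLU(), nn.Linear(1024, classes)
    ).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = nn.CrossEntropyLoss()
    # a disabled scaler makes scale()/step()/update() plain pass-throughs
    scaler = torch.cuda.amp.GradScaler(enabled=amp_on)
    x = torch.randn(batch, features, device=device)
    y = torch.randint(0, classes, (batch,), device=device)

    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(steps):
        optimizer.zero_grad(set_to_none=True)
        with torch.autocast(device_type=device, enabled=amp_on):
            loss = loss_fn(model(x), y)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    if device == "cuda":
        torch.cuda.synchronize()  # count all queued GPU work before stopping the clock
    return (time.perf_counter() - start) / steps, loss.item()


if __name__ == "__main__":
    for use_amp in (False, True):
        sec, last_loss = benchmark(use_amp)
        print(f"AMP={use_amp}: {sec * 1000:.2f} ms/step, loss={last_loss:.4f}")
```

If AMP is much slower even in a script like this, that points at the environment (driver, PyTorch build, or GPU) rather than at the training class.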