Optimizer.step() occasionally takes a very long time

Hi, I would like to understand this problem: I am training a custom version of StyleGAN2-ADA and, during the discriminator training loop, optimizer.step() takes a very long time every 4 steps. Here is a snippet of my code for the discriminator training step:

        self.discriminator_optimizer.zero_grad()

        # Accumulate gradients for `gradient_accumulate_steps`
        for i in range(self.args.gradient_accumulate_steps):
            
            if idx % self.args.p_update_interval == 0:
                self.p = self.p_sched.get_p() 
                self.augmenter.update_p(self.p)
                self.logger.add_scalar("[TRAIN] ADA PROB (p)", self.p, global_step=idx)

            # Sample images from generator
            generated_images, _ = self.generate_images(self.args.batch_size)
            # Apply ADA
            generated_images, _ = self.augmenter(generated_images, args=self.args)
            # Discriminator classification for generated images
            fake_output = self.discriminator(generated_images.detach())
            self.avg_log["avg_out_D_fake"] += torch.mean(fake_output).item()

            # Get real images from the data loader
            real_images, patches = next(self.loader)
            real_images = real_images.to(self.args.device)

            # Apply ADA
            real_images, _ = self.augmenter(real_images, debug=False, args=self.args)

            # We need to calculate gradients w.r.t. real images for gradient penalty
            if (idx + 1) % self.args.lazy_gradient_penalty_interval == 0:
                real_images.requires_grad_()
            # Discriminator classification for real images
            real_output = self.discriminator(real_images)
            self.avg_log["avg_out_D_real"] += torch.mean(real_output).item()

            self.p_sched.step(real_output)  # update the p scheduler

            # Get discriminator loss
            real_loss, fake_loss = self.discriminator_loss(real_output, fake_output)
            disc_loss = real_loss + fake_loss

            # Add gradient penalty
            if (idx + 1) % self.args.lazy_gradient_penalty_interval == 0:
                # Calculate and log gradient penalty
                gp = self.gradient_penalty(real_images, real_output)
                # Multiply by coefficient and add gradient penalty
                disc_loss = disc_loss + 0.5 * self.args.gradient_penalty_coefficient * gp * self.args.lazy_gradient_penalty_interval


            # Log discriminator loss
            self.logger.add_scalar("[TRAIN][DISC] LOSS", disc_loss.item(), global_step=idx)
            
            self.avg_log["avg_loss_D"] += disc_loss.item()

            self.logger.add_scalars("[TRAIN][DISC] OUTPUT", {'REAL': torch.mean(real_output).item(), 'FAKE': torch.mean(fake_output).item()}, global_step=idx)

            if (idx + 1) % self.args.log_generated_interval == 0:
                # Log the average discriminator loss occasionally
                self.logger.add_scalar("[TRAIN][DISC] LOSS AVG", self.avg_log["avg_loss_D"] / self.args.log_generated_interval, global_step=idx)
                self.avg_log["avg_loss_D"] = 0

                ######### VALIDATION #######
                ############################
                self.discriminator.eval()
                self.avg_log["avg_out_D_val"] = 0

                # Validation does not need gradients
                with torch.no_grad():
                    for real_images_val, patches_val in self.dataloader_val:
                        real_images_val = real_images_val.to(self.args.device)

                        val_output = self.discriminator(real_images_val)
                        self.avg_log["avg_out_D_val"] += torch.mean(val_output).item()

                self.discriminator.train()

                self.logger.add_scalars("[VAL] DISC OUT AVG", {'VAL': self.avg_log["avg_out_D_val"]/self.args.log_generated_interval, 'REAL': self.avg_log["avg_out_D_real"]/self.args.log_generated_interval,  'FAKE': self.avg_log["avg_out_D_fake"]/self.args.log_generated_interval}, global_step=idx)
                self.avg_log["avg_out_D_val"] = 0
                self.avg_log["avg_out_D_fake"] = 0
                self.avg_log["avg_out_D_real"] = 0


            # Compute gradients
            disc_loss.backward()


            # Clip gradients for stabilization
            torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), max_norm=1.0)
            
            print("init")
            start = time.time()
            
            # Take optimizer step
            self.discriminator_optimizer.step()

            end = time.time()
            print("end: ", end-start)

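For context, gradient_penalty is the R1 penalty from StyleGAN2; roughly (a simplified sketch assuming the standard formulation, not my exact code):

def gradient_penalty(real_images, real_output):
    # Sketch of the standard StyleGAN2 R1 penalty (assumption, not my exact code).
    # Gradients of the real logits w.r.t. the (augmented) real images;
    # create_graph=True because the penalty term is backpropagated through
    grad, = torch.autograd.grad(outputs=real_output.sum(),
                                inputs=real_images,
                                create_graph=True)
    # Squared L2 norm per sample, averaged over the batch
    return grad.pow(2).reshape(grad.shape[0], -1).sum(dim=1).mean()

The extra * self.args.lazy_gradient_penalty_interval factor in the loss compensates for only applying the penalty every lazy_gradient_penalty_interval-th step (lazy regularization).
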
The output is:

init
end:  0.01297616958618164
  0%|                                                        | 1/150000 [00:04<172:42:58,  4.15s/it]init
end:  0.0026178359985351562
  0%|                                                         | 2/150000 [00:05<94:45:09,  2.27s/it]init
end:  0.005979061126708984
  0%|                                                         | 3/150000 [00:06<68:58:23,  1.66s/it]init
end:  12.416852474212646
  0%|                                                        | 4/150000 [00:19<268:31:27,  6.44s/it]init
end:  0.00793766975402832
  0%|                                                        | 5/150000 [00:20<184:55:54,  4.44s/it]init
end:  0.007531881332397461
  0%|                                                        | 6/150000 [00:21<134:06:41,  3.22s/it]init
end:  0.003564119338989258
  0%|                                                        | 7/150000 [00:24<127:06:31,  3.05s/it]init
end:  12.370309114456177
  0%|                                                        | 8/150000 [00:37<268:12:15,  6.44s/it]init
end:  0.007948160171508789
  0%|                                                        | 9/150000 [00:38<196:22:54,  4.71s/it]init
end:  0.005285978317260742
  0%|                                                       | 10/150000 [00:39<147:46:37,  3.55s/it]init
end:  0.008272171020507812
  0%|                                                       | 11/150000 [00:40<113:58:37,  2.74s/it]init
end:  12.416625261306763
  0%|                                                       | 12/150000 [00:54<253:27:49,  6.08s/it]
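
One thing I am not sure about: CUDA kernels are launched asynchronously, so time.time() around step() may just measure wherever the CPU happens to block on previously queued GPU work, not the optimizer itself. A synchronized measurement (a sketch, assuming training runs on a CUDA device) would look like:

            torch.cuda.synchronize()  # wait for all previously queued GPU work
            start = time.time()

            self.discriminator_optimizer.step()

            torch.cuda.synchronize()  # make sure step() has actually finished
            end = time.time()
            print("step time: ", end - start)

If the 12 seconds then moves somewhere else, the cost presumably comes from earlier asynchronous work (e.g. the extra backward pass through the gradient penalty every lazy_gradient_penalty_interval steps) rather than from the optimizer step itself.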

This is really annoying since it slows down training a lot. I don't see any comparable slowdown in the generator training step, which seems strange to me. Does anyone have an idea of what could cause this or how to fix it?
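
If it helps with the diagnosis, I can also capture a profile of one slow step with torch.profiler, e.g.:

            from torch.profiler import profile, ProfilerActivity

            # Hypothetical diagnostic: profile a single optimizer step
            with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
                self.discriminator_optimizer.step()
            print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))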

Thanks a lot in advance!