Hi, I'm trying to understand a problem: I am training a custom version of StyleGAN2-ADA, and during the discriminator training loop, `optimizer.step()` takes a very long time every 4 steps. Here is a snippet of my code for the discriminator training step:
```python
self.discriminator_optimizer.zero_grad()
# Accumulate gradients over `gradient_accumulate_steps` mini-batches
for i in range(self.args.gradient_accumulate_steps):
    # Update the ADA augmentation probability every `p_update_interval` steps
    if idx % self.args.p_update_interval == 0:
        self.p = self.p_sched.get_p()
        self.augmenter.update_p(self.p)
        self.logger.add_scalar("[TRAIN] ADA PROB (p)", self.p, global_step=idx)
    # Sample images from the generator
    generated_images, _ = self.generate_images(self.args.batch_size)
    # Apply ADA to the generated images
    generated_images, _ = self.augmenter(generated_images, args=self.args)
    # Discriminator classification for generated images
    fake_output = self.discriminator(generated_images.detach())
    self.avg_log["avg_out_D_fake"] += torch.mean(fake_output).item()
    # Get real images from the data loader and apply ADA
    real_images, patches = next(self.loader)
    real_images = real_images.to(self.args.device)
    real_images, _ = self.augmenter(real_images, debug=False, args=self.args)
    # We need gradients w.r.t. the real images for the gradient penalty
    if (idx + 1) % self.args.lazy_gradient_penalty_interval == 0:
        real_images.requires_grad_()
    # Discriminator classification for real images
    real_output = self.discriminator(real_images)
    self.avg_log["avg_out_D_real"] += torch.mean(real_output).item()
    # Update the p scheduler
    self.p_sched.step(real_output)
    # Get discriminator loss
    real_loss, fake_loss = self.discriminator_loss(real_output, fake_output)
    disc_loss = real_loss + fake_loss
    # Add the gradient penalty (lazy regularization: it is applied only every
    # `lazy_gradient_penalty_interval` steps, so it is scaled up by that interval)
    if (idx + 1) % self.args.lazy_gradient_penalty_interval == 0:
        gp = self.gradient_penalty(real_images, real_output)
        disc_loss = disc_loss + 0.5 * self.args.gradient_penalty_coefficient * gp * self.args.lazy_gradient_penalty_interval
    # Log discriminator loss and outputs
    self.logger.add_scalar("[TRAIN][DISC] LOSS", disc_loss.item(), global_step=idx)
    self.avg_log["avg_loss_D"] += disc_loss.item()
    self.logger.add_scalars("[TRAIN][DISC] OUTPUT",
                            {'REAL': torch.mean(real_output).item(),
                             'FAKE': torch.mean(fake_output).item()},
                            global_step=idx)
    if (idx + 1) % self.args.log_generated_interval == 0:
        # Log the running average of the discriminator loss occasionally
        self.logger.add_scalar("[TRAIN][DISC] LOSS AVG",
                               self.avg_log["avg_loss_D"] / self.args.log_generated_interval,
                               global_step=idx)
        self.avg_log["avg_loss_D"] = 0
        ######### VALIDATION #######
        self.discriminator.eval()
        self.avg_log["avg_out_D_val"] = 0
        with torch.no_grad():
            for real_images_val, patches_val in self.dataloader_val:
                real_images_val = real_images_val.to(self.args.device)
                val_output = self.discriminator(real_images_val)
                self.avg_log["avg_out_D_val"] += torch.mean(val_output).item()
        self.discriminator.train()
        self.logger.add_scalars("[VAL] DISC OUT AVG",
                                {'VAL': self.avg_log["avg_out_D_val"] / len(self.dataloader_val),
                                 'REAL': self.avg_log["avg_out_D_real"] / self.args.log_generated_interval,
                                 'FAKE': self.avg_log["avg_out_D_fake"] / self.args.log_generated_interval},
                                global_step=idx)
        self.avg_log["avg_out_D_val"] = 0
        self.avg_log["avg_out_D_fake"] = 0
        self.avg_log["avg_out_D_real"] = 0
    # Compute gradients
    disc_loss.backward()

# Clip gradients for stabilization
torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), max_norm=1.0)
# Time the optimizer step
print("init")
start = time.time()
self.discriminator_optimizer.step()
end = time.time()
print("end: ", end - start)
```
The output of the timing prints is:
```
init
end: 0.01297616958618164
0%| | 1/150000 [00:04<172:42:58, 4.15s/it]init
end: 0.0026178359985351562
0%| | 2/150000 [00:05<94:45:09, 2.27s/it]init
end: 0.005979061126708984
0%| | 3/150000 [00:06<68:58:23, 1.66s/it]init
end: 12.416852474212646
0%| | 4/150000 [00:19<268:31:27, 6.44s/it]init
end: 0.00793766975402832
0%| | 5/150000 [00:20<184:55:54, 4.44s/it]init
end: 0.007531881332397461
0%| | 6/150000 [00:21<134:06:41, 3.22s/it]init
end: 0.003564119338989258
0%| | 7/150000 [00:24<127:06:31, 3.05s/it]init
end: 12.370309114456177
0%| | 8/150000 [00:37<268:12:15, 6.44s/it]init
end: 0.007948160171508789
0%| | 9/150000 [00:38<196:22:54, 4.71s/it]init
end: 0.005285978317260742
0%| | 10/150000 [00:39<147:46:37, 3.55s/it]init
end: 0.008272171020507812
0%| | 11/150000 [00:40<113:58:37, 2.74s/it]init
end: 12.416625261306763
0%| | 12/150000 [00:54<253:27:49, 6.08s/it]
```
This slows training down considerably, and I don't see any comparable slowdown in the generator training step. The slow iterations (4, 8, 12, ...) coincide with the steps where the lazy gradient penalty is applied (`(idx + 1) % self.args.lazy_gradient_penalty_interval == 0`), which seems too regular to be a coincidence, but I don't understand why the extra cost would show up inside `optimizer.step()`. Does anyone have an idea what could cause this, or how to fix it?
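One thing I am not sure about: CUDA kernels are launched asynchronously, so `time.time()` around `optimizer.step()` might only be measuring the point where the CPU happens to block on the GPU queue, not the step itself. A synchronized version of the measurement would look roughly like this (a sketch, not what I ran above):

```python
import time
import torch

def timed_step(optimizer: torch.optim.Optimizer) -> float:
    """Time an optimizer step with the GPU queue drained before and after."""
    torch.cuda.synchronize()  # finish kernels queued by backward()/clip_grad_norm_
    start = time.time()
    optimizer.step()
    torch.cuda.synchronize()  # wait until the step's own kernels have finished
    return time.time() - start
```

If the ~12 s then moved to `disc_loss.backward()` on the gradient-penalty steps instead, the bottleneck would not really be in `optimizer.step()` at all.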
Thanks a lot in advance!