I’m trying to train a StyleGAN2 model using the BasicSR library.
My setup is PyTorch 1.8 with four NVIDIA GeForce RTX 3090 GPUs.
Each iteration takes about 0.5 s (including data loading, forward propagation, and backward propagation). However, when it comes to the generator regularization step, which runs every 4 iterations, backward propagation becomes extremely slow: it takes 25 s on average.
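Since CUDA kernels are launched asynchronously, wall-clock numbers like the ones above are only reliable if the GPU is synchronized before and after the timed call. Here is a minimal sketch of how such a measurement could be taken; the helper name `timed` and the toy tensor are hypothetical and not part of BasicSR:

```python
import time

import torch


def timed(fn, *args, **kwargs):
    """Run fn and return (result, elapsed_seconds).

    Synchronizes CUDA before and after the call (when a GPU is present)
    so that pending kernels are not attributed to the wrong step.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return result, time.perf_counter() - start


# Toy usage: time a single backward pass on a small tensor.
x = torch.randn(64, 64, requires_grad=True)
loss = (x @ x).sum()
_, seconds = timed(loss.backward)
print(f"backward took {seconds:.4f}s")
```

In the training loop, the same wrapper could be put around `l_g_path.backward()` to check whether the time is really spent in that call.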
Below is the code for generator regularization:
if current_iter % self.net_g_reg_every == 0:
    path_batch_size = max(1, batch // self.opt['train']['path_batch_shrink'])
    noise = self.mixing_noise(path_batch_size, self.mixing_prob)
    fake_img, latents = self.net_g(noise, return_latents=True)
    l_g_path, path_lengths, self.mean_path_length = g_path_regularize(fake_img, latents, self.mean_path_length)
    l_g_path = (self.path_reg_weight * self.net_g_reg_every * l_g_path + 0 * fake_img[0, 0, 0, 0])
    # TODO: why do we need to add 0 * fake_img[0, 0, 0, 0]
    l_g_path.backward()
    loss_dict['l_g_path'] = l_g_path.detach().mean()
    loss_dict['path_length'] = path_lengths
The function that computes “l_g_path”:
import math
import torch
from torch import autograd

def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    # Random image-space direction, scaled by 1 / sqrt(H * W)
    noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
    # create_graph=True keeps this gradient differentiable for the later backward()
    grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0]
    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
    path_penalty = (path_lengths - path_mean).pow(2).mean()
    return path_penalty, path_lengths.detach().mean(), path_mean.detach()
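To isolate the pattern, here is a minimal CPU-only sketch of the same computation on a toy model. The `ToyGenerator` class is a hypothetical stand-in, not the StyleGAN2 generator, and the running mean is replaced by the batch mean for brevity. It shows that calling `backward()` on the penalty differentiates through the gradient returned by `autograd.grad` (a double backward), which is a different code path from an ordinary backward pass:

```python
import math

import torch
from torch import autograd, nn

torch.manual_seed(0)


class ToyGenerator(nn.Module):
    """Hypothetical stand-in: maps latents (batch, n_latent, dim) to an image."""

    def __init__(self, n_latent=4, dim=8, size=16):
        super().__init__()
        self.size = size
        self.fc = nn.Linear(n_latent * dim, 3 * size * size)

    def forward(self, latents):
        x = torch.tanh(self.fc(latents.flatten(1)))
        return x.view(-1, 3, self.size, self.size)


net_g = ToyGenerator()
latents = torch.randn(2, 4, 8, requires_grad=True)
fake_img = net_g(latents)

noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])

# First-order gradient w.r.t. the latents, kept in the graph via create_graph=True.
(grad,) = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)

path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
# Simplification: batch mean instead of the exponential moving average.
path_penalty = (path_lengths - path_lengths.mean().detach()).pow(2).mean()

# This backward pass goes through `grad` itself, i.e. a second-order computation.
path_penalty.backward()
print(path_lengths.shape, net_g.fc.weight.grad is not None)
```

Timing this pattern layer by layer on the real generator might help narrow down which operation dominates the regularization step.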