I’m trying to train a StyleGAN2 model using the BasicSR library.
My setup is PyTorch 1.8 with 4 NVIDIA GeForce RTX 3090 GPUs.
Each iteration takes about 0.5 s (including data loading, forward propagation, and backward propagation). However, when it comes to the generator regularization, which happens every 4 iterations, the backward propagation becomes extremely slow: it takes 25 s on average.
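(A note on the timings: CUDA kernels run asynchronously, so a wall-clock measurement around `backward()` is only meaningful if the device is synchronized before and after. A minimal sketch of how such per-step timings can be taken; the `timed_backward` helper below is hypothetical, not part of BasicSR.)

```python
import time
import torch

def timed_backward(loss):
    """Measure the wall-clock time of one backward pass.
    torch.cuda.synchronize() forces pending kernels to finish,
    so the interval reflects the actual GPU work."""
    torch.cuda.synchronize()
    start = time.time()
    loss.backward()
    torch.cuda.synchronize()
    return time.time() - start
```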
Below is the code for generator regularization:
```python
if current_iter % self.net_g_reg_every == 0:
    path_batch_size = max(1, batch // self.opt['train']['path_batch_shrink'])
    noise = self.mixing_noise(path_batch_size, self.mixing_prob)
    fake_img, latents = self.net_g(noise, return_latents=True)
    l_g_path, path_lengths, self.mean_path_length = g_path_regularize(
        fake_img, latents, self.mean_path_length)
    l_g_path = (self.path_reg_weight * self.net_g_reg_every * l_g_path +
                0 * fake_img[0, 0, 0, 0])
    # TODO: why do we need to add 0 * fake_img[0, 0, 0, 0]
    l_g_path.backward()
    loss_dict['l_g_path'] = l_g_path.detach().mean()
    loss_dict['path_length'] = path_lengths
```
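As an aside on the TODO: a common explanation is that the zero-weighted term keeps `fake_img` in the autograd graph so that every generator parameter participates in the backward pass (something `DistributedDataParallel` requires), while multiplying by zero leaves the loss value unchanged. A minimal sketch of the underlying effect:

```python
import torch

a = torch.ones(1, requires_grad=True)
b = torch.ones(1, requires_grad=True)

loss = a * 2           # b plays no role in the graph
loss.backward()
print(a.grad, b.grad)  # tensor([2.]) None -- b gets no gradient

a.grad = None
loss = a * 2 + 0 * b   # zero-weighted term ties b into the graph
loss.backward()
print(a.grad, b.grad)  # tensor([2.]) tensor([0.])
```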
The loss function that computes `l_g_path`:
```python
import math

import torch
from torch import autograd

def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    # Random image-space direction, scaled by 1 / sqrt(H * W)
    noise = torch.randn_like(fake_img) / math.sqrt(
        fake_img.shape[2] * fake_img.shape[3])
    # create_graph=True so the penalty below can itself be backpropagated
    grad = autograd.grad(
        outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0]
    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
    # Exponential moving average of the path length
    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
    path_penalty = (path_lengths - path_mean).pow(2).mean()
    return path_penalty, path_lengths.detach().mean(), path_mean.detach()
```
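The `create_graph=True` in `autograd.grad` is what makes this backward different from the regular one: calling `l_g_path.backward()` differentiates through the gradient computation itself (a second-order, "double backward" pass), which is substantially more expensive than a plain backward. A minimal, self-contained sketch of the pattern, independent of StyleGAN2:

```python
import torch
from torch import autograd

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 8, requires_grad=True)
out = (x @ w).tanh().sum()

# First-order gradient; create_graph=True records this computation
# so it can itself be differentiated later.
(grad_x,) = autograd.grad(out, x, create_graph=True)

# Penalizing the gradient norm: backward() here is a double backward
# and propagates second-order gradients into w.
penalty = grad_x.pow(2).sum()
penalty.backward()
print(w.grad)
```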