Solution for slow backward propagation

I’m trying to train a StyleGAN2 model using the BasicSR library.
My setup is PyTorch 1.8 with four NVIDIA GeForce RTX 3090 GPUs.

Each iteration takes about 0.5 s, including data loading, the forward pass and the backward pass. However, the generator regularization step, which runs every 4 iterations, has an extremely slow backward pass: it takes 25 s on average.
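One thing worth ruling out first is a measurement artifact. CUDA kernels are launched asynchronously, so a host-side timer can charge work queued earlier (for example, forward kernels) to whichever call happens to block next, often `backward()`. A minimal helper that synchronizes the device before reading the clock might look like this. `timed` is a hypothetical name for illustration and is not part of BasicSR:

```python
import time

import torch


def timed(fn, *args, **kwargs):
    """Run fn(*args, **kwargs) and return (result, elapsed_seconds).

    The device is synchronized before and after the call, so the measured
    time covers only GPU work issued by fn itself (on CPU this is a no-op).
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for previously queued kernels
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for kernels launched by fn
    return result, time.perf_counter() - start


if __name__ == "__main__":
    model = torch.nn.Linear(8, 1)
    loss = model(torch.randn(4, 8)).sum()
    _, seconds = timed(loss.backward)
    print(f"backward took {seconds:.6f} s")
```

In the regularization branch, `_, t = timed(l_g_path.backward)` would report the synchronized duration of that single backward call.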

Below is the code for generator regularization:

```python
if current_iter % self.net_g_reg_every == 0:
    path_batch_size = max(1, batch // self.opt['train']['path_batch_shrink'])
    noise = self.mixing_noise(path_batch_size, self.mixing_prob)
    fake_img, latents = self.net_g(noise, return_latents=True)
    l_g_path, path_lengths, self.mean_path_length = g_path_regularize(fake_img, latents, self.mean_path_length)

    l_g_path = (self.path_reg_weight * self.net_g_reg_every * l_g_path + 0 * fake_img[0, 0, 0, 0])
    # TODO: why do we need to add 0 * fake_img[0, 0, 0, 0]
    l_g_path.backward()
    loss_dict['l_g_path'] = l_g_path.detach().mean()
    loss_dict['path_length'] = path_lengths
```
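About the TODO, here is a likely reason. This is an assumption based on how `autograd.grad` works and is not confirmed in this thread. The path penalty only depends on parameters that affect d(fake_img)/d(latents). Parameters that merely shift the output, such as output biases, therefore receive no gradient from `l_g_path.backward()`. If the generator is wrapped in `DistributedDataParallel`, as is typical for multi-GPU training, DDP by default expects every parameter to receive a gradient in each backward pass. Adding `0 * fake_img[0, 0, 0, 0]` attaches all of those parameters to the graph with an exactly-zero gradient. The toy sketch below uses a single `nn.Linear` as a stand-in generator. The names `gen` and `path_style_backward` are hypothetical:

```python
import torch
from torch import autograd, nn

torch.manual_seed(0)

# Hypothetical toy "generator": one linear layer mapping latents to outputs.
# Its bias shifts the output but does not change d(output)/d(latents).
gen = nn.Linear(4, 3)


def path_style_backward(model, use_zero_term):
    """Backpropagate a gradient-norm penalty; return (weight.grad, bias.grad)."""
    model.zero_grad(set_to_none=True)
    latents = torch.randn(2, 4, requires_grad=True)
    out = model(latents)
    noise = torch.randn_like(out)
    # Same pattern as g_path_regularize: gradient of <out, noise> w.r.t. latents
    (grad,) = autograd.grad((out * noise).sum(), latents, create_graph=True)
    loss = grad.pow(2).sum()
    if use_zero_term:
        # Adds 0 to the loss but puts every parameter that produced `out` in the graph
        loss = loss + 0 * out[0, 0]
    loss.backward()
    return model.weight.grad, model.bias.grad


if __name__ == "__main__":
    w_grad, b_grad = path_style_backward(gen, use_zero_term=False)
    print(w_grad is not None, b_grad)  # True None
    w_grad, b_grad = path_style_backward(gen, use_zero_term=True)
    print(w_grad is not None, b_grad)  # True, and a tensor of zeros
```

Without the extra term the bias gradient stays `None`. With it, the bias gradient is a zero tensor, so every parameter takes part in the gradient reduction.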

The loss `l_g_path` is computed by `g_path_regularize`:

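```python
import math

import torch
from torch import autograd


def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    # Random image-space direction, scaled by 1/sqrt(H*W)
    noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
    # Gradient of <fake_img, noise> w.r.t. latents; create_graph=True lets it be backpropagated again
    grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0]
    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))

    # Exponential moving average of the path length
    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)

    path_penalty = (path_lengths - path_mean).pow(2).mean()

    return path_penalty, path_lengths.detach().mean(), path_mean.detach()
```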
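For reference, `g_path_regularize` computes the quantities below. The notation is introduced here for illustration only. $G(\mathbf{w})$ is `fake_img` with spatial size $H \times W$, and $\mathbf{w}_{b,l}$ is the $l$-th of $L$ latent vectors of sample $b$ in a batch of size $B$. $a$ is `mean_path_length` and $\beta$ is `decay`. The per-sample length matches `grad.pow(2).sum(2).mean(1)` followed by `sqrt`:

$$
y = \frac{\varepsilon}{\sqrt{HW}},\quad \varepsilon \sim \mathcal{N}(0, I), \qquad
g_{b,l} = \frac{\partial}{\partial \mathbf{w}_{b,l}} \big\langle G(\mathbf{w}),\, y \big\rangle
$$

$$
\ell_b = \sqrt{\frac{1}{L}\sum_{l=1}^{L} \lVert g_{b,l} \rVert_2^2}
$$

$$
a' = a + \beta\left(\frac{1}{B}\sum_{b=1}^{B}\ell_b - a\right), \qquad
\mathcal{L}_{\text{path}} = \frac{1}{B}\sum_{b=1}^{B}\left(\ell_b - a'\right)^2
$$

Each $g_{b,l}$ is already a gradient of the generator output. Backpropagating $\mathcal{L}_{\text{path}}$ into the generator weights is therefore a second-order (double) backward pass through the whole generator, which is why `create_graph=True` is required.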