Improved WGAN implementation slower than TensorFlow

Hello, I’ve implemented Improved WGAN (WGAN-GP) in PyTorch myself, referring to the code at https://github.com/caogang/wgan-gp,
but it is about two times slower than my TensorFlow implementation, which is also my own code.

Both versions use DCGAN + Improved WGAN,
and 100 training iterations take about 10 seconds in TensorFlow versus about 20 seconds in PyTorch.

I’ve tested a few things: after changing the loss function to LSGAN,
training time drops to about 9 seconds for both versions,
so I’m guessing the problem is in the WGAN part.
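
In case it is useful, here is a rough sketch of how the gradient-penalty step alone could be timed in isolation. This is not part of my training code; the helper name, the 2-D (batch, features) inputs, and the critic returning a (score, category) tuple are assumptions made just for illustration.

    import time

    import torch
    from torch import autograd

    def time_gradient_penalty(disc, real, fake, iters=100):
        """Rough per-iteration timing of the gradient-penalty term (illustrative helper)."""
        if real.is_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            # Same interpolation scheme as in the training loop below (2-D inputs assumed)
            eps = torch.rand(real.size(0), 1, device=real.device).expand_as(real)
            x_pn = (eps * real + (1 - eps) * fake).detach().requires_grad_(True)
            disc_pn, _ = disc(x_pn)
            grad = autograd.grad(disc_pn.sum(), x_pn, create_graph=True)[0]
            penalty = 10 * ((grad.norm(dim=1) - 1) ** 2).mean()
            penalty.backward()  # include the double backward, since training pays for it too
        if real.is_cuda:
            torch.cuda.synchronize()
        return (time.perf_counter() - start) / iters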

Here’s my code, and I can’t find the problem,
so I would appreciate if anyone could help me find the problem.

                # imports needed by this snippet
                import torch
                from torch import autograd, optim

                lr = 0.0002

                disc = Discriminator()
                gen = Generator()

                opt_gen = optim.Adam(gen.parameters(), lr)
                opt_disc = optim.Adam(disc.parameters(), lr)

                # Discriminator (critic) training: re-enable its gradients first,
                # since they are frozen below for the generator step
                for param in disc.parameters():
                    param.requires_grad_(True)

                for _ in range(3):
                    disc.zero_grad()

                    # Real & fake x
                    batch_data = torch.tensor(dataloader.get_batch(num_batch), dtype=torch.float32)

                    z_val, cat_input = generate_z_val(num_batch, num_z, num_cat)
                    x_gen = gen(z_val)

                    # Disc
                    disc_real, _ = disc(batch_data)
                    disc_fake, cat_output = disc(x_gen)

                    disc_real = disc_real.mean()
                    disc_fake = disc_fake.mean()

                    # Improved WGAN gradient penalty
                    # (per-sample eps, broadcast across the feature dimension of the 2-D batch)
                    eps = torch.rand((num_batch, 1)).expand(batch_data.size())
                    scale_fn = 10  # gradient penalty weight (lambda)

                    x_pn = eps * batch_data + (1 - eps) * x_gen
                    disc_pn, _ = disc(x_pn)

                    # wgan_grad_output is defined in the full code
                    # (presumably a ones tensor with the same shape as disc_pn)
                    grad = autograd.grad(disc_pn, x_pn, grad_outputs=wgan_grad_output,
                                         create_graph=True, retain_graph=True)[0]
                    grad = grad.norm(dim=1)  # per-sample gradient norm

                    ddx = scale_fn * (grad - 1) ** 2
                    ddx = ddx.mean()

                    # Note: this uses the opposite sign convention from caogang/wgan-gp
                    # (disc_real - disc_fake rather than disc_fake - disc_real), but it is
                    # consistent with the generator loss below, which minimizes disc_fake
                    loss_real = disc_real - disc_fake + ddx

                    loss_real.backward()

                    opt_disc.step()

                # Generator training: freeze the critic so it gets no gradients from this step
                for param in disc.parameters():
                    param.requires_grad_(False)

                gen.zero_grad()

                z_val, cat_input = generate_z_val(num_batch, num_z, num_cat)
                x_gen = gen(z_val)
                disc_fake, cat_output = disc(x_gen)

                disc_fake = disc_fake.mean()

                loss_fake = disc_fake  # matches the flipped critic sign convention noted above

                loss_fake.backward()

                opt_gen.step()
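
For comparison, here is a small self-contained sketch of the gradient-penalty computation written the way the linked caogang/wgan-gp code structures it. The function name is made up, and it assumes 2-D (batch, features) inputs and a critic that returns a (score, category) tuple like mine does.

    import torch
    from torch import autograd

    def gradient_penalty(disc, real, fake, lambda_gp=10.0):
        """WGAN-GP penalty on random interpolates (sketch in the style of caogang/wgan-gp)."""
        eps = torch.rand(real.size(0), 1, device=real.device).expand_as(real)
        interpolates = (eps * real + (1 - eps) * fake).detach().requires_grad_(True)

        disc_interpolates, _ = disc(interpolates)

        grads = autograd.grad(
            outputs=disc_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones_like(disc_interpolates),
            create_graph=True,
            retain_graph=True,
        )[0]

        # Per-sample 2-norm over the flattened feature dimensions
        grads = grads.view(grads.size(0), -1)
        return lambda_gp * ((grads.norm(2, dim=1) - 1) ** 2).mean()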

If you would like to see the full code, both versions are linked below.

PyTorch version:

TensorFlow version:

Thanks.