Improved WGAN implementation slower than TensorFlow

Hello, I’ve implemented Improved WGAN myself in PyTorch, looking at code from,
but it is about twice as slow as my own TensorFlow implementation of the same model.

Both versions use DCGAN + Improved WGAN,
and 100 training iterations take about 10 seconds in TensorFlow and 20 seconds in PyTorch.

I’ve tested a few things: after changing the loss function to LSGAN,
training time dropped to about 9 seconds for both versions,
so I’m guessing the problem is in the WGAN part.
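
For anyone who wants to reproduce this kind of comparison, here is a minimal timing sketch. It is not taken from either code base: `time_iterations`, `step_fn`, and `sync_fn` are hypothetical names, and `step_fn` stands in for one full training iteration.

```python
import time


def time_iterations(step_fn, num_iters=100, sync_fn=None):
    """Return wall-clock seconds taken by num_iters calls of step_fn.

    step_fn is a hypothetical zero-argument callable that runs one
    training iteration. sync_fn is an optional callable (for example
    torch.cuda.synchronize) that waits for queued device work to finish
    before the clock is read.
    """
    if sync_fn is not None:
        sync_fn()
    start = time.perf_counter()
    for _ in range(num_iters):
        step_fn()
    if sync_fn is not None:
        sync_fn()
    return time.perf_counter() - start


if __name__ == "__main__":
    # Dummy step standing in for a real training iteration.
    elapsed = time_iterations(lambda: sum(range(10_000)), num_iters=100)
    print(f"100 iterations took {elapsed:.3f} s")
```

On a GPU, CUDA kernels are launched asynchronously, so passing `torch.cuda.synchronize` as `sync_fn` is needed for the numbers to reflect the actual compute time. Timing the critic step and the generator step separately with the same helper would show which part accounts for the gap.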

Here’s my code. I can’t find the problem myself,
so I would appreciate it if anyone could help me track it down.

                import torch
                from torch import autograd, optim

                lr = 0.0002

                disc = Discriminator()
                gen = Generator()

                opt_gen = optim.Adam(gen.parameters(), lr)
                opt_disc = optim.Adam(disc.parameters(), lr)

                # Discriminator train (3 updates per generator update)
                for _ in range(3):

                    # Real & fake x
                    batch_data = torch.tensor(dataloader.get_batch(num_batch), dtype=torch.float32)

                    z_val, cat_input = generate_z_val(num_batch, num_z, num_cat)
                    x_gen = gen(z_val)

                    # Disc
                    disc_real, _ = disc(batch_data)
                    disc_fake, cat_output = disc(x_gen)

                    disc_real = disc_real.mean()
                    disc_fake = disc_fake.mean()

                    # Improved WGAN
                    eps = torch.rand((num_batch, 1)).expand(batch_data.size())
                    scale_fn = 10

                    x_pn = eps * batch_data + (1 - eps) * x_gen
                    disc_pn, _ = disc(x_pn)

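                    # autograd.grad returns a tuple; [0] is the gradient w.r.t. x_pn.
                    # The call in the post is cut off after create_graph=True,
                    # so any further arguments are unknown.
                    grad = autograd.grad(disc_pn, x_pn, grad_outputs=wgan_grad_output,
                                         create_graph=True)[0]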
                    grad = grad.norm(dim=1)

                    ddx = scale_fn * (grad - 1) ** 2
                    ddx = ddx.mean()

                    loss_real = disc_real - disc_fake + ddx



                # Generator train
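                for param in disc.parameters():
                    # The loop body is missing from the post; freezing the
                    # discriminator for the generator update is assumed here.
                    param.requires_grad = False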

                z_val, cat_input = generate_z_val(num_batch, num_z, num_cat)
                x_gen = gen(z_val)
                disc_fake, cat_output = disc(x_gen)

                disc_fake = disc_fake.mean()

                loss_fake = disc_fake
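
The gradient penalty in the discriminator step above is the part that differs from LSGAN, so here is a standalone sketch of it that can be run and timed in isolation. It is an illustration under assumptions, not my training code: `gradient_penalty` and the single-output `critic` are hypothetical, inputs are assumed to be flat `(batch, features)` tensors, and the generated batch is detached, which the code above does not do.

```python
import torch
from torch import autograd, nn


def gradient_penalty(critic, real, fake, scale=10.0):
    """Return scale * mean((||d critic(x_hat) / d x_hat||_2 - 1) ** 2).

    Assumes real and fake are flat (batch, features) tensors and that
    critic returns one score per sample (hypothetical single-output critic).
    """
    # One interpolation coefficient per sample, broadcast over features.
    eps = torch.rand(real.size(0), 1, device=real.device)
    # Detach so x_hat is a leaf tensor; no graph back into the generator.
    x_hat = (eps * real + (1 - eps) * fake).detach().requires_grad_(True)
    scores = critic(x_hat)
    # create_graph=True keeps the graph of this gradient computation so the
    # penalty can itself be backpropagated into the critic parameters.
    (grad,) = autograd.grad(
        outputs=scores,
        inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
    )
    grad_norm = grad.norm(2, dim=1)
    return scale * ((grad_norm - 1) ** 2).mean()


if __name__ == "__main__":
    torch.manual_seed(0)
    critic = nn.Sequential(nn.Linear(8, 16), nn.LeakyReLU(0.2), nn.Linear(16, 1))
    real, fake = torch.randn(4, 8), torch.randn(4, 8)
    penalty = gradient_penalty(critic, real, fake)
    penalty.backward()  # second-order pass into the critic weights
    print(penalty.item())
```

Because `create_graph=True` makes the penalty require a second backward pass through the critic, this term is expected to cost more than a plain LSGAN loss in any framework. Detaching the generated batch during the critic update also avoids backpropagating through the generator; whether either point explains the 2x gap here is untested.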



If you would like to see the full code, both versions are here.

PyTorch version:

TensorFlow version:
