GAN PyTorch profiler

Hi all

I was looking at the new PyTorch profiler, and I am trying to learn GAN models at the same time.
A GAN contains two models: a Discriminator and a Generator.
How can I profile the GAN with the PyTorch profiler? Do I have to profile each model separately, or can I profile the whole GAN?

Can someone please show me how to profile the GAN model?

Thanks a lot for your help.

The profiler wraps the operations you would like to profile, as seen in the code examples in this blog post, so you can profile either the models separately or the entire training iteration.
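As a rough illustration of both options, here is a minimal CPU-only sketch. It assumes small hypothetical `nn.Linear` models in place of the DCGAN networks and random tensors in place of a dataloader; the labels passed to `torch.profiler.record_function` (`generator_forward`, `discriminator_update`, `generator_update`) are arbitrary names chosen for illustration. The first `profile` block covers only the generator, while the second covers whole training iterations and uses the labels to keep the two models distinguishable in a single profile.

```python
import torch
import torch.nn as nn
from torch.profiler import ProfilerActivity, profile, record_function

# Hypothetical toy models standing in for the DCGAN Generator/Discriminator
nz, data_dim, batch_size = 16, 32, 8
netG = nn.Sequential(nn.Linear(nz, 64), nn.ReLU(), nn.Linear(64, data_dim))
netD = nn.Sequential(nn.Linear(data_dim, 64), nn.LeakyReLU(0.2), nn.Linear(64, 1))
criterion = nn.BCEWithLogitsLoss()
optimizerD = torch.optim.Adam(netD.parameters(), lr=2e-4)
optimizerG = torch.optim.Adam(netG.parameters(), lr=2e-4)


def train_step(real):
    # record_function labels each phase so it gets its own row in the summary
    with record_function("generator_forward"):
        fake = netG(torch.randn(real.size(0), nz))
    with record_function("discriminator_update"):
        netD.zero_grad()
        out_real = netD(real)
        out_fake = netD(fake.detach())
        errD = (criterion(out_real, torch.ones_like(out_real))
                + criterion(out_fake, torch.zeros_like(out_fake)))
        errD.backward()
        optimizerD.step()
    with record_function("generator_update"):
        netG.zero_grad()
        out = netD(fake)
        errG = criterion(out, torch.ones_like(out))
        errG.backward()
        optimizerG.step()


# Option 1: profile one model on its own (here only the generator's forward pass)
with profile(activities=[ProfilerActivity.CPU]) as prof_g:
    with torch.no_grad():
        netG(torch.randn(batch_size, nz))

# Option 2: profile entire training iterations, both models included
with profile(activities=[ProfilerActivity.CPU]) as prof_all:
    for _ in range(3):
        train_step(torch.randn(batch_size, data_dim))

print(prof_g.key_averages().table(sort_by="cpu_time_total", row_limit=5))
print(prof_all.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```

Labelling the phases this way lets a single profile of the whole iteration still show how the time splits between the two models.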

Hi @ptrblck

First of all, thanks a lot for your answer.
Could you please show me a simple example of how to use the PyTorch profiler for the entire training iteration of a GAN? I don't know how to apply it to a GAN, since this is the first time I am using both the PyTorch profiler and GANs.

Thanks a lot for your help.

I think you could reuse the provided code snippets from the blog post, e.g. something like this should work:

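import torch

# netD, netG, criterion, optimizerD, optimizerG, trainloader, device,
# nz, real_label and fake_label are defined as in the DCGAN example
with torch.profiler.profile(
    schedule=torch.profiler.schedule(
        wait=2,
        warmup=2,
        active=6,
        repeat=1),
    # tensorboard_trace_handler takes the output directory and returns the handler
    on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/dcgan'),
    with_stack=True
) as profiler: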
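    for step, data in enumerate(trainloader, 0):
        print("step:{}".format(step))
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        netD.zero_grad()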
        real_cpu = data[0].to(device)
        batch_size = real_cpu.size(0)
        label = torch.full((batch_size,), real_label,
                           dtype=real_cpu.dtype, device=device)

        output = netD(real_cpu)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        # train with fake
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()
        profiler.step()  # advance the profiler schedule once per training iteration

Note that I’ve reused the profiler code from the blog post and the GAN training from the DCGAN example.
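To see what the schedule above does without TensorBoard, here is a minimal CPU-only sketch. It assumes hypothetical toy `nn.Linear` models and random data instead of the DCGAN networks and `trainloader`, and `print_summary` is a hypothetical callback that replaces `tensorboard_trace_handler` so the results are printed rather than written to disk. With `wait=2, warmup=2, active=6, repeat=1`, the profiler skips steps 0-1, warms up on steps 2-3, records steps 4-9, and then calls `on_trace_ready` once.

```python
import torch
import torch.nn as nn
from torch.profiler import ProfilerActivity, profile, record_function, schedule

# Hypothetical toy models and random data instead of the DCGAN networks/dataloader
nz, data_dim, batch_size = 16, 32, 8
netG = nn.Sequential(nn.Linear(nz, 64), nn.ReLU(), nn.Linear(64, data_dim))
netD = nn.Sequential(nn.Linear(data_dim, 64), nn.LeakyReLU(0.2), nn.Linear(64, 1))
criterion = nn.BCEWithLogitsLoss()
optimizerD = torch.optim.SGD(netD.parameters(), lr=0.01)
optimizerG = torch.optim.SGD(netG.parameters(), lr=0.01)

summaries = []


def print_summary(prof):
    # Hypothetical replacement for tensorboard_trace_handler: keep and print a
    # table of the recorded steps instead of writing trace files to disk
    averages = prof.key_averages()
    summaries.append(averages)
    print(averages.table(sort_by="cpu_time_total", row_limit=10))


with profile(
    activities=[ProfilerActivity.CPU],
    # steps 0-1: idle, 2-3: warmup (discarded), 4-9: recorded, then the handler runs
    schedule=schedule(wait=2, warmup=2, active=6, repeat=1),
    on_trace_ready=print_summary,
) as profiler:
    for step in range(12):
        real = torch.randn(batch_size, data_dim)
        with record_function("gan_iteration"):
            # (1) Update D
            netD.zero_grad()
            fake = netG(torch.randn(batch_size, nz))
            out_real, out_fake = netD(real), netD(fake.detach())
            errD = (criterion(out_real, torch.ones_like(out_real))
                    + criterion(out_fake, torch.zeros_like(out_fake)))
            errD.backward()
            optimizerD.step()
            # (2) Update G
            netG.zero_grad()
            out = netD(fake)
            errG = criterion(out, torch.ones_like(out))
            errG.backward()
            optimizerG.step()
        profiler.step()  # called once per iteration to advance the schedule
```

Since `repeat=1`, steps 10 and 11 run without being profiled; a larger `repeat` would produce one summary per cycle.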


Hi @ptrblck

Thanks a lot for this example. Now I understand how to use the PyTorch profiler correctly.