GAN raises "UserWarning: Using a non-full backward hook when the forward contains multiple..."

Hello everyone,

I am trying to learn how to use differentially private training for a GAN. As a start, I implemented a slightly modified version of the PyTorch GAN example so that I can train the generator with or without DP. Without DP the model works perfectly fine; however, as soon as I use the PrivacyEngine, I get the following warning:

UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.

Could one of you point me in the right direction as to what I am doing wrong? During debugging I found that the warning is triggered by the line gen_imgs = self.gen(z) (line 138), but I don’t understand why.

Here is the full example code. DP training can be switched on and off by setting dp=True/False when initialising the Combined model (see line 190).

"""Example for non-full backward hook warning"""
import logging
import os
import warnings

import numpy as np
import torch
from torchvision import datasets
from opacus import PrivacyEngine
from torch import nn
import torchvision.transforms as transforms

# CUDA device index handed to Combined; it only matters when CUDA is available.
GPU = 0
# Image geometry: single-channel images of img_size x img_size pixels.
channels = 1
img_size = 28
# Length of the noise vector fed to the generator.
latent_dim = 100
# (C, H, W) shape the generator output is reshaped to.
img_shape = (channels, img_size, img_size)
# Target delta for the privacy engine; also used when querying the spent epsilon.
delta = 1e-5
# No name is passed, so this is the root logger.
log = logging.getLogger()


class Generator(nn.Module):
    """Fully connected generator mapping noise vectors to images.

    The network is a stack of ``Linear`` -> (``BatchNorm1d``) -> ``LeakyReLU``
    stages of growing width, followed by a final ``Linear`` + ``Tanh`` whose
    output is reshaped to ``img_shape``.

    Args:
        dp: When true, all batch normalization layers are left out.
    """

    def __init__(self, dp: bool = False):
        super().__init__()

        # Normalization not supported for dp models
        use_norm = not dp
        widths = (latent_dim, 128, 256, 512, 1024)

        stack = []
        for stage, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            stack.append(nn.Linear(n_in, n_out))
            # The first stage is never normalized.
            if stage > 0 and use_norm:
                # NOTE(review): the second positional argument of BatchNorm1d
                # is `eps`, so the original `0.8` sets eps, not momentum.
                # Kept as-is to preserve behavior - confirm the intent.
                stack.append(nn.BatchNorm1d(n_out, eps=0.8))
            stack.append(nn.LeakyReLU(0.2, inplace=True))

        stack.append(nn.Linear(1024, int(np.prod(img_shape))))
        stack.append(nn.Tanh())
        self.model = nn.Sequential(*stack)

    def forward(self, z):
        """Generate one image per row of ``z``, shaped ``(batch, *img_shape)``."""
        flat = self.model(z)
        return flat.view(flat.size(0), *img_shape)


class Discriminator(nn.Module):
    """Fully connected discriminator scoring images with a sigmoid output.

    Args:
        dp: Accepted for symmetry with ``Generator``; not used by this class.
    """

    def __init__(self, dp: bool = False):
        super().__init__()

        n_pixels = int(np.prod(img_shape))
        widths = (n_pixels, 512, 256)

        stack = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(n_in, n_out))
            stack.append(nn.LeakyReLU(0.2, inplace=True))
        stack.append(nn.Linear(256, 1))
        stack.append(nn.Sigmoid())

        self.model = nn.Sequential(*stack)

    def forward(self, img):
        """Flatten each image and return one validity score per sample."""
        batch_size = img.size(0)
        return self.model(img.view(batch_size, -1))


class Combined(nn.Module):
    """GAN wrapper bundling generator, discriminator, loss and training loop.

    When ``dp`` is true, the generator, its optimizer and the data loader are
    wrapped by an Opacus ``PrivacyEngine`` at the start of ``training_loop``.
    The discriminator is always trained with its plain optimizer.
    """

    def __init__(self, gpu: int, dp: bool = False):
        """Build both networks, the loss and the two Adam optimizers.

        Args:
            gpu: Index of the CUDA device to use when CUDA is available.
            dp: Whether the generator is to be trained with the PrivacyEngine.

        Raises:
            RuntimeError: If ``dp`` is true and Opacus' ``ModuleValidator``
                reports at least one error for this module.
        """
        super().__init__()
        self.gen = Generator(dp=dp)
        self.dis = Discriminator(dp=dp)
        self.dp = dp

        # Loss function
        self.adversarial_loss = torch.nn.BCELoss()

        if torch.cuda.is_available():
            self.device = f'cuda:{gpu}'
            log.info(f"Using GPU {gpu}")
            # Use the same placement call for every sub-module.
            self.gen.to(device=self.device)
            self.dis.to(device=self.device)
            self.adversarial_loss.to(device=self.device)
        else:
            # Do not claim GPU usage when falling back to the CPU.
            self.device = 'cpu'
            log.info("CUDA not available - using CPU")

        if dp:
            # Validate if Model can use DP-SGD
            from opacus.validators import ModuleValidator
            errors = ModuleValidator.validate(self, strict=False)
            if errors:
                raise RuntimeError("DP training not possible.")
            log.info("No errors found - model can be trained with DP.")

        # Optimizers
        self.opt_g = torch.optim.Adam(self.gen.parameters(), lr=0.0002)
        self.opt_d = torch.optim.Adam(self.dis.parameters(), lr=0.0002)

    def training_loop(self, dataloader, n_epochs):
        """Alternate generator and discriminator updates for ``n_epochs``.

        Args:
            dataloader: Iterable yielding ``(images, labels)`` batches; the
                labels are ignored.
            n_epochs: Number of passes over ``dataloader``. Also passed to
                the privacy engine as ``epochs`` when DP is enabled.

        A status line is printed every 100 batches; with DP enabled it also
        reports the epsilon returned by the privacy engine (except for the
        very first batch).
        """
        # DP initialization
        if self.dp:
            self.privacy_engine = PrivacyEngine()

            # Silence warnings only while the engine wraps the generator.
            # catch_warnings restores the previous filters on exit instead of
            # overwriting the process-wide configuration.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                self.gen, self.opt_g, dataloader = self.privacy_engine.make_private_with_epsilon(
                    module=self.gen,
                    optimizer=self.opt_g,
                    data_loader=dataloader,
                    epochs=n_epochs,
                    target_epsilon=50.0,
                    target_delta=delta,
                    max_grad_norm=1.2,
                )

        for epoch in range(1, n_epochs + 1):
            for i, (imgs, _) in enumerate(dataloader):
                batch_size = imgs.size(0)
                # Target labels: 1 for "real", 0 for "generated".
                valid = torch.ones((batch_size, 1), device=self.device)
                fake = torch.zeros((batch_size, 1), device=self.device)

                # Configure input
                real_imgs = imgs.to(device=self.device)

                # -----------------
                #  Train Generator
                # -----------------
                self.gen.train()
                self.opt_g.zero_grad()

                # Sample noise as generator input
                z = torch.randn(size=(batch_size, latent_dim), device=self.device)

                # Generate a batch of images
                gen_imgs = self.gen(z)

                # Loss measures generator's ability to fool the discriminator
                g_loss = self.adversarial_loss(self.dis(gen_imgs), valid)

                g_loss.backward()
                self.opt_g.step()

                # ---------------------
                #  Train Discriminator
                # ---------------------
                self.dis.train()
                # Also clears the discriminator gradients accumulated by
                # g_loss.backward() above.
                self.opt_d.zero_grad()

                # Measure discriminator's ability to classify real from generated samples
                real_loss = self.adversarial_loss(self.dis(real_imgs), valid)
                # detach(): the discriminator update must not reach the generator.
                fake_loss = self.adversarial_loss(self.dis(gen_imgs.detach()), fake)
                d_loss = (real_loss + fake_loss) / 2

                d_loss.backward()
                self.opt_d.step()

                if i % 100 == 0:
                    msg = (
                        f"[Epoch {epoch:03d}/{n_epochs}] [Batch {i:03d}/{len(dataloader)}] "
                        f"[D loss: {d_loss.item():.5f}] [G loss: {g_loss.item():.5f}] "
                    )
                    if self.dp and (epoch > 1 or i > 0):
                        eps = self.privacy_engine.get_epsilon(delta)
                        msg += f"(ε = {eps:.2f}, δ = {delta}) "
                    print(msg)


def load_data(batch_size: int):
    """Return a shuffling DataLoader over the MNIST training split.

    The dataset lives in ``tmp/data/mnist`` (created if missing, requested
    with ``download=True``). Images are resized to ``img_size``, converted to
    tensors and normalized with mean 0.5 and std 0.5.

    Args:
        batch_size: Number of samples per batch.
    """
    data_root = "tmp/data/mnist"
    os.makedirs(data_root, exist_ok=True)

    preprocessing = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    mnist_train = datasets.MNIST(
        data_root,
        train=True,
        download=True,
        transform=preprocessing,
    )
    return torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True)


if __name__ == '__main__':
    # Build the GAN with DP enabled, then train for 10 epochs on MNIST
    # batches of 256 images.
    model = Combined(gpu=GPU, dp=True)
    train_loader = load_data(batch_size=256)
    model.training_loop(dataloader=train_loader, n_epochs=10)

I don’t know how I could miss this before asking the question, but there is already an issue on Github addressing this warning: https://github.com/pytorch/pytorch/issues/598.

For future people finding this:

According to one contributor’s comment, the warning can be safely ignored.

Moreover, it can be avoided with the following modification to one’s code (comment by karthikprasad, taken from GitHub):

That said, starting Opacus 1.2.0, we support functorch based per-sample gradient computation (no hooks, no warnings). To use this, simply set grad_sample_mode="functorch" in the call to make_private(). You can find more details about this in the release notes.
Note: as mentioned in the release notes, functorch support is still in beta mode, and it could be slower than hooks.

Update:
I tried both grad_sample_mode="functorch" and grad_sample_mode="ew"; however, in both cases the same warning is still present.

Here is an even shorter example with the generator only:

import logging
import os
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from opacus import PrivacyEngine
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms.functional as F
from tqdm import tqdm

# Image geometry: single-channel images of img_size x img_size pixels.
channels = 1
img_size = 28
# (C, H, W) shape the generator output is reshaped to.
img_shape = (channels, img_size, img_size)

def load_data(batch_size: int) -> DataLoader:
    """Build a shuffling DataLoader over the MNIST training split.

    The data directory ``tmp/data/mnist`` is created if missing and the
    dataset is requested with ``download=True``. Each image is resized to
    ``img_size``, turned into a tensor and normalized with mean 0.5, std 0.5.

    Args:
        batch_size: Number of samples per batch.
    """
    mnist_dir = "tmp/data/mnist"
    os.makedirs(mnist_dir, exist_ok=True)

    pipeline = [
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
    train_set = datasets.MNIST(
        mnist_dir,
        train=True,
        download=True,
        transform=transforms.Compose(pipeline),
    )
    loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
    )
    return loader

class Generator(nn.Module):
    """Minimal generator: one linear layer followed by tanh.

    Args:
        dp: Accepted for interface compatibility; not used by this class.
    """

    def __init__(self, dp: bool = False):
        super().__init__()
        n_pixels = int(np.prod(img_shape))
        # Maps a 100-dimensional noise vector to a flattened image.
        self.model = nn.Linear(100, n_pixels)

    def forward(self, z):
        """Return tanh-activated images shaped ``(batch, *img_shape)``."""
        activated = torch.tanh(self.model(z))
        return activated.view(activated.size(0), *img_shape)


# Minimal reproduction: wrap the generator with the PrivacyEngine and run a
# single forward pass, which is where the warning is reported.

# One device name for the model and every tensor, instead of mixing
# 'cuda' and 'cuda:0'.
device = 'cuda'
# Shared by the privacy accounting and the training loop so both agree.
n_epochs = 10

m = Generator(True)
m = m.to(device)
opt = torch.optim.Adam(m.parameters())
dataloader = load_data(128)
privacy_engine = PrivacyEngine(accountant='rdp')


m, opt, dataloader = privacy_engine.make_private_with_epsilon(
    module=m,
    optimizer=opt,
    data_loader=dataloader,
    epochs=n_epochs,
    target_epsilon=10,
    target_delta=1e-5,
    max_grad_norm=1.0,
    poisson_sampling=True,
    grad_sample_mode="functorch"  # "ew" was tried as well
)

loss = torch.nn.BCELoss()
for epoch in tqdm(range(1, n_epochs + 1), leave=False, ncols=80):
    for i, (imgs, _) in enumerate(dataloader):
        # Leftovers from the full training loop; not used by this reproduction.
        valid = torch.ones((imgs.size(0), 1), device=device)
        fake = torch.zeros((imgs.size(0), 1), device=device)

        # Configure input
        real_imgs = imgs.to(device)

        # -----------------
        #  Train Generator
        # -----------------
        m.train()
        opt.zero_grad()

        # Sample noise as generator input
        z = torch.randn(size=(imgs.shape[0], 100), device=device)

        # Generate a batch of images
        gen_imgs = m(z)
        # One forward pass is enough, so stop after the first batch.
        break
    break

I see a similar warning with my DPGAN implementation and came here to ask the same question!

1 Like