Hello everyone,
I am trying to learn how to use differentially private training for a GAN. As a start, I implemented a slightly modified version of a PyTorch GAN so that I can train the generator either with or without DP. Without DP the model works perfectly fine; however, as soon as I use the PrivacyEngine, I get the following warning:
UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.
Could one of you point me in the right direction as to what I am doing wrong? While debugging I found that the warning is triggered by the line gen_imgs = self.gen(z) in the generator training step, but I don't understand why.
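In case it helps, this is roughly how I narrowed it down: I escalated that specific warning to an exception with Python's standard warnings module, so the traceback points at the call that emits it. Take this as a sketch rather than part of the example (gan and z are just the objects from the script below):

import warnings

# Escalate only this deprecation warning to an exception so that the
# resulting traceback shows which forward call emits it.
warnings.filterwarnings(
    "error",
    message="Using a non-full backward hook",
    category=UserWarning,
)

gen_imgs = gan.gen(z)  # raises UserWarning here instead of only printing it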
Here is the full example code. DP training can be switched on and off by setting dp=True/False when initialising the Combined model (see the __main__ block at the bottom; a dp=False variant of that block is shown again after the code).
"""Example for non-full backward hook warning"""
import logging
import os
import warnings
import numpy as np
import torch
from torchvision import datasets
from opacus import PrivacyEngine
from torch import nn
import torchvision.transforms as transforms
GPU = 0
channels = 1
img_size = 28
latent_dim = 100
img_shape = (channels, img_size, img_size)
delta = 1e-5
log = logging.getLogger()
class Generator(nn.Module):
    def __init__(self, dp: bool = False):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize and not dp:
                # Normalization not supported for dp models
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img
class Discriminator(nn.Module):
    def __init__(self, dp: bool = False):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity
class Combined(nn.Module):
    def __init__(self, gpu: int, dp: bool = False):
        # Initialize generator and discriminator
        super(Combined, self).__init__()
        self.gen = Generator(dp=dp)
        self.dis = Discriminator(dp=dp)
        self.dp = dp
        # Loss function
        self.adversarial_loss = torch.nn.BCELoss()
        self.device = f'cuda:{gpu}' if torch.cuda.is_available() else 'cpu'
        log.info(f"Using GPU {gpu}")
        if torch.cuda.is_available():
            self.gen.to(device=self.device)
            self.dis.cuda(device=self.device)
            self.adversarial_loss.cuda(device=self.device)
        if dp:
            # Validate if the model can be trained with DP-SGD
            from opacus.validators import ModuleValidator
            errors = ModuleValidator.validate(self, strict=False)
            if len(errors) == 0:
                log.info("No errors found - model can be trained with DP.")
            else:
                raise RuntimeError("DP training not possible.")
        # Optimizers
        self.opt_g = torch.optim.Adam(self.gen.parameters(), lr=0.0002)
        self.opt_d = torch.optim.Adam(self.dis.parameters(), lr=0.0002)
    def training_loop(self, dataloader, n_epochs):
        # DP initialization
        if self.dp:
            warnings.simplefilter("ignore")
            self.privacy_engine = PrivacyEngine()
            self.gen, self.opt_g, dataloader = self.privacy_engine.make_private_with_epsilon(
                module=self.gen,
                optimizer=self.opt_g,
                data_loader=dataloader,
                epochs=n_epochs,
                target_epsilon=50.0,
                target_delta=delta,
                max_grad_norm=1.2,
            )
            warnings.simplefilter("default")

        for epoch in range(1, n_epochs + 1):
            for i, (imgs, _) in enumerate(dataloader):
                valid = torch.ones((imgs.size(0), 1), device=self.device)
                fake = torch.zeros((imgs.size(0), 1), device=self.device)
                # Configure input
                real_imgs = imgs.to(device=self.device)

                # -----------------
                #  Train Generator
                # -----------------
                self.gen.train()
                self.opt_g.zero_grad()
                # Sample noise as generator input
                z = torch.randn(size=(imgs.shape[0], latent_dim), device=self.device)
                # Generate a batch of images
                gen_imgs = self.gen(z)
                # Loss measures generator's ability to fool the discriminator
                g_loss = self.adversarial_loss(self.dis(gen_imgs), valid)
                g_loss.backward()
                self.opt_g.step()

                # ---------------------
                #  Train Discriminator
                # ---------------------
                self.dis.train()
                self.opt_d.zero_grad()
                # Measure discriminator's ability to classify real from generated samples
                real_loss = self.adversarial_loss(self.dis(real_imgs), valid)
                fake_loss = self.adversarial_loss(self.dis(gen_imgs.detach()), fake)
                d_loss = (real_loss + fake_loss) / 2
                d_loss.backward()
                self.opt_d.step()

                if i % 100 == 0:
                    msg = (
                        f"[Epoch {epoch:03d}/{n_epochs}] [Batch {i:03d}/{len(dataloader)}] "
                        f"[D loss: {d_loss.item():.5f}] [G loss: {g_loss.item():.5f}] "
                    )
                    if self.dp and (epoch > 1 or i > 0):
                        eps = self.privacy_engine.get_epsilon(delta)
                        msg += f"(ε = {eps:.2f}, δ = {delta}) "
                    print(msg)
def load_data(batch_size: int):
    os.makedirs("tmp/data/mnist", exist_ok=True)
    dataloader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "tmp/data/mnist",
            train=True,
            download=True,
            transform=transforms.Compose(
                [transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
            ),
        ),
        batch_size=batch_size,
        shuffle=True,
    )
    return dataloader
if __name__ == '__main__':
    gan = Combined(GPU, dp=True)
    dataloader = load_data(256)
    gan.training_loop(dataloader, 10)
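For completeness, switching DP off only requires flipping the flag in the constructor call; with dp=False the very same script trains without the warning (this is just the __main__ block from above with the flag changed):

if __name__ == '__main__':
    gan = Combined(GPU, dp=False)  # no PrivacyEngine is attached, no warning appears
    dataloader = load_data(256)
    gan.training_loop(dataloader, 10)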