How to calculate privacy budget for multiple optimizers in a single model?

Hi,

I’m trying to figure out how I can calculate a combined (epsilon, delta) privacy budget for multiple optimizers, where each optimizer is given a subset of the model's layer parameters. I tried hooking the same accountant to the different optimizers, but I’m not sure whether it returns a combined privacy budget for the whole model.

I’m using torch 1.9.0 and opacus 1.0.1

Here is my code:

import torch
from torchvision import datasets, transforms
import numpy as np
from opacus import PrivacyEngine
from tqdm import tqdm
from opacus import GradSampleModule
from opacus.accountants import RDPAccountant
from opacus.optimizers import DPOptimizer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_batch_size = 64
test_batch_size = 128
noise_multiplier = 1.0

train_loader = torch.utils.data.DataLoader(datasets.MNIST('../mnist',
               train=True, download=True,
               transform=transforms.Compose([transforms.ToTensor(),
               transforms.Normalize((0.1307,), (0.3081,)),]),),
               batch_size=train_batch_size, shuffle=True, num_workers=1,
               pin_memory=True)

test_loader = torch.utils.data.DataLoader(datasets.MNIST('../mnist',
              train=False,
              transform=transforms.Compose([transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,)),]),),
              batch_size=test_batch_size, shuffle=True, num_workers=1,
              pin_memory=True)

model = torch.nn.Sequential(torch.nn.Conv2d(1, 16, 8, 2, padding=3),
                            torch.nn.ReLU(),
                            torch.nn.MaxPool2d(2, 1), 
                            torch.nn.Conv2d(16, 32, 4, 2), 
                            torch.nn.ReLU(), 
                            torch.nn.MaxPool2d(2, 1), 
                            torch.nn.Flatten(), 
                            torch.nn.Linear(32 * 4 * 4, 32), 
                            torch.nn.ReLU(), 
                            torch.nn.Linear(32, 10))

accountant = RDPAccountant()

model = GradSampleModule(model)
model.to(device)

parameters = list(model.parameters())
params1 = parameters[:3]
params2 = parameters[3:]

optimizer1 = torch.optim.SGD(params1, lr=0.05)
optimizer2 = torch.optim.SGD(params2, lr=0.05)

optimizer1 = DPOptimizer(
    optimizer=optimizer1,
    noise_multiplier=noise_multiplier,
    max_grad_norm=1.0,
    expected_batch_size=train_batch_size
)

optimizer2 = DPOptimizer(
    optimizer=optimizer2,
    noise_multiplier=noise_multiplier,
    max_grad_norm=1.0,
    expected_batch_size=train_batch_size
)

optimizer1.attach_step_hook(
    accountant.get_optimizer_hook_fn(
        sample_rate=train_batch_size/len(train_loader.dataset)
    )
)

optimizer2.attach_step_hook(
    accountant.get_optimizer_hook_fn(
        sample_rate=train_batch_size/len(train_loader.dataset)
    )
)

def train(model, train_loader, optimizer1, optimizer2, epoch, device, delta):
    model.train()
    criterion = torch.nn.CrossEntropyLoss()
    losses = []
    for _batch_idx, (data, target) in enumerate(tqdm(train_loader)):
        data, target = data.to(device), target.to(device)

        optimizer1.zero_grad()
        optimizer2.zero_grad()
        
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        
        optimizer1.step()
        optimizer2.step()
        
        losses.append(loss.item())
    
    epsilon, best_alpha = accountant.get_privacy_spent(delta=delta)
        
    print(
        f"Train Epoch: {epoch} \t"
        f"Loss: {np.mean(losses):.6f} "
        f"(ε = {epsilon:.2f}, δ = {delta}) for α = {best_alpha}"
        )


def test(model, test_loader, device):
    model.eval()
    with torch.no_grad():
        n_correct = 0
        for _batch_idx, (data, target) in enumerate(tqdm(test_loader)):
            data, target = data.to(device), target.to(device)
            output = model(data)
            predicted = output.argmax(dim=1, keepdim=True)
            n_correct += predicted.eq(target.view_as(predicted)).sum().item()

    print(f'Test Acc: {100.0 * n_correct/len(test_loader.dataset)}')

for epoch in range(1, 11):
    train(model, train_loader, optimizer1, optimizer2, epoch, device=device, delta=1e-5)

test(model, test_loader, device=device)

Is the privacy budget returned only for the 2nd optimizer, or for both optimizers combined?

Also, AFAIK (correct me if I’m wrong), the paper “Deep Learning with Differential Privacy” defines DP-SGD and the moments accountant only for a single optimizer. How is Opacus able to calculate a single (epsilon, delta) for multiple optimizers, if it can at all? Which paper is the accountant implementation based on?

Hi @titan-chan and thanks for your question.

I think what you’re doing is correct: the accountant tracks the combined privacy budget of both optimizers.
The only things the accountant needs to know for each step are the noise multiplier and the sampling rate. It doesn’t matter which weights are being updated, or whether the same set of weights is updated throughout training. DP-SGD works under a threat model where every gradient update is released independently, and it doesn’t rely on anything that happens to those weights afterwards.

Therefore your setup is equivalent to doing twice the number of optimization steps, which is a valid privacy setup.

As for the theory, we're using Rényi Differential Privacy (RDP) accounting: it outperforms the moments accountant described in Abadi et al. and also benefits from privacy amplification by subsampling (see https://arxiv.org/pdf/1808.00087.pdf).
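The per-step accounting and the "twice the number of steps" argument can be illustrated with a small self-contained sketch of RDP accounting for the Poisson-subsampled Gaussian mechanism. It does not use Opacus: the function names are hypothetical, it only evaluates integer orders α from 2 to 128, and it uses the classic conversion ε = RDP + log(1/δ)/(α - 1), so its numbers will not match what RDPAccountant reports exactly. What it shows is that one step's RDP depends only on the sample rate q and the noise multiplier σ, that RDP adds up across steps, and that hooking two optimizers to one accountant simply doubles the step count.

```python
import math


def rdp_sampled_gaussian(q: float, sigma: float, alpha: int) -> float:
    """RDP at integer order alpha >= 2 of one Poisson-subsampled Gaussian step.

    Closed form for integer orders: log(A_alpha) / (alpha - 1), where
    A_alpha = sum_k C(alpha, k) * (1 - q)**(alpha - k) * q**k * exp((k*k - k) / (2 * sigma**2)).
    """
    if q >= 1.0:
        # No subsampling: the plain Gaussian mechanism with L2 sensitivity 1.
        return alpha / (2 * sigma**2)
    # Work in log space so that large orders do not overflow.
    log_terms = [
        math.log(math.comb(alpha, k))
        + (alpha - k) * math.log(1 - q)
        + k * math.log(q)
        + (k * k - k) / (2 * sigma**2)
        for k in range(alpha + 1)
    ]
    peak = max(log_terms)
    log_a = peak + math.log(sum(math.exp(t - peak) for t in log_terms))
    return log_a / (alpha - 1)


def dp_epsilon(q: float, sigma: float, steps: int, delta: float) -> float:
    """Compose `steps` identical steps (RDP adds up) and convert to (epsilon, delta)."""
    return min(
        steps * rdp_sampled_gaussian(q, sigma, alpha) + math.log(1 / delta) / (alpha - 1)
        for alpha in range(2, 129)
    )


# Settings from the thread: batch size 64 on the 60,000-image MNIST training set.
q = 64 / 60000
sigma = 1.0
delta = 1e-5
batches = 10 * math.ceil(60000 / 64)  # 10 epochs

eps_one_hook = dp_epsilon(q, sigma, batches, delta)        # one optimizer: one accountant step per batch
eps_two_hooks = dp_epsilon(q, sigma, 2 * batches, delta)   # two optimizers: two accountant steps per batch
eps_no_sampling = dp_epsilon(1.0, sigma, batches, delta)   # same steps without subsampling (q = 1)

print(f"one hook:  eps = {eps_one_hook:.2f}")
print(f"two hooks: eps = {eps_two_hooks:.2f}")
print(f"q = 1:     eps = {eps_no_sampling:.2f}")
```

The accountant never looks at which parameters a step touched; it only sees (σ, q) per recorded step, so two hooks per batch give the same result as one hook over twice as many batches. The q = 1 line shows how much the subsampling amplification contributes.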


Hi @ffuuugor. Thanks for the answer!

> Therefore your setup is equivalent to doing twice the number of optimization steps, which is a valid privacy setup.

Does this mean that, with the same number of epochs, this setup would reach the same accuracy as a single optimizer over the whole model with the accountant hooked to it, but with a larger epsilon, given that the noise multiplier is exactly the same in every optimizer in both approaches?

It sounds a bit counterintuitive. Shouldn’t epsilon be the same if the model is being trained in exactly the same way in both approaches?

Hi,

I agree that it’s a bit counterintuitive. Let’s take a closer look at the two setups:

  1. One optimizer covering all parameters
  2. Two optimizers, each covering half of the parameters

To begin with, the accuracy won’t be the same for a fixed noise_multiplier and max_grad_norm: scenario 2 is likely to reach better accuracy.
The reason is how clipping works. When we clip, we compute the L2 norm of each per-sample gradient vector (over all parameters the optimizer manages) and scale it down so that its L2 norm is no larger than max_grad_norm. After that, we sum the clipped per-sample gradients and add Gaussian noise with standard deviation noise_multiplier * max_grad_norm to each component. This means the relative amount of noise is lower for shorter gradient vectors: a shorter vector has a smaller L2 norm, so clipping is less aggressive, i.e. each individual component is scaled down less than in scenario 1.
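A toy, single-sample version of the clipping argument, in plain Python rather than Opacus: the gradient values and names are made up for illustration, and the clip factor min(1, max_grad_norm / norm) is a simplified form of what the optimizer applies. It clips the same per-sample gradient once as a whole (scenario 1) and as two separately clipped halves (scenario 2), then compares each clipped component with the noise standard deviation noise_multiplier * max_grad_norm, which is the same in both scenarios.

```python
import math


def clip_factor(vec, max_grad_norm):
    """Scale applied to one per-sample gradient so its L2 norm is at most max_grad_norm."""
    norm = math.sqrt(sum(x * x for x in vec))
    return min(1.0, max_grad_norm / norm) if norm > 0 else 1.0


# Hypothetical per-sample gradient over four parameters (one sample only).
grad = [3.0, 4.0, 6.0, 8.0]
max_grad_norm = 1.0
noise_multiplier = 1.0
noise_std = noise_multiplier * max_grad_norm  # per-component noise, identical in both scenarios

# Scenario 1: one optimizer clips the whole vector (norm = sqrt(125), about 11.18).
whole = clip_factor(grad, max_grad_norm)
clipped_whole = [whole * g for g in grad]

# Scenario 2: two optimizers each clip their own half (norms 5 and 10).
first, second = grad[:2], grad[2:]
clipped_split = [clip_factor(first, max_grad_norm) * g for g in first] + [
    clip_factor(second, max_grad_norm) * g for g in second
]

norm_whole = math.sqrt(sum(x * x for x in clipped_whole))  # at most max_grad_norm
norm_split = math.sqrt(sum(x * x for x in clipped_split))  # up to sqrt(2) * max_grad_norm

for i, (a, b) in enumerate(zip(clipped_whole, clipped_split)):
    print(f"component {i}: signal/noise_std whole={a / noise_std:.3f} split={b / noise_std:.3f}")
print(f"clipped norm: whole={norm_whole:.3f} split={norm_split:.3f}")
```

Every component keeps at least as much signal in the split version, because each half has an L2 norm no larger than the whole vector. Put differently, with two groups each clipped to max_grad_norm, the concatenated per-sample gradient can have norm up to √2 · max_grad_norm, while the noise added to each component is unchanged.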

It is possible to get the same overall epsilon while using two optimizers: you’d need to pick a target epsilon beforehand and then call the get_noise_multiplier utility function:

from opacus.accountants.utils import get_noise_multiplier

noise_multiplier = get_noise_multiplier(
    target_epsilon=target_epsilon / 2,
    ...
)
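Opacus's get_noise_multiplier performs a numerical search for the noise multiplier; the sketch below is a simplified, self-contained stand-in for that idea, not its actual code. The helper names and the target epsilon of 3.0 are hypothetical, it reuses the integer-order RDP approximation (orders 2 to 128, classic RDP-to-(ε, δ) conversion), and it bisects on σ until the composed epsilon meets the target. It computes σ for the halved budget used in the answer, and, because the shared accountant records two steps per batch, also σ for the full budget over twice as many steps.

```python
import math


def rdp_sampled_gaussian(q, sigma, alpha):
    """RDP at integer order alpha >= 2 of one Poisson-subsampled Gaussian step (0 < q < 1)."""
    log_terms = [
        math.log(math.comb(alpha, k)) + (alpha - k) * math.log(1 - q)
        + k * math.log(q) + (k * k - k) / (2 * sigma**2)
        for k in range(alpha + 1)
    ]
    peak = max(log_terms)
    return (peak + math.log(sum(math.exp(t - peak) for t in log_terms))) / (alpha - 1)


def dp_epsilon(q, sigma, steps, delta):
    """Compose `steps` identical steps and convert RDP to (epsilon, delta)."""
    return min(
        steps * rdp_sampled_gaussian(q, sigma, a) + math.log(1 / delta) / (a - 1)
        for a in range(2, 129)
    )


def find_noise_multiplier(target_epsilon, delta, q, steps, tol=0.01):
    """Bisection for the smallest sigma (within tol) with dp_epsilon <= target_epsilon."""
    lo, hi = 0.0, 1.0
    while dp_epsilon(q, hi, steps, delta) > target_epsilon:
        lo, hi = hi, 2 * hi  # grow the bracket until hi meets the target
    while hi - lo > tol:
        mid = (lo + hi) / 2
        if dp_epsilon(q, mid, steps, delta) > target_epsilon:
            lo = mid
        else:
            hi = mid
    return hi  # invariant: dp_epsilon(q, hi, steps, delta) <= target_epsilon


q = 64 / 60000
batches = 10 * math.ceil(60000 / 64)  # 10 epochs of batch size 64 on MNIST
target_epsilon, delta = 3.0, 1e-5     # hypothetical target

# As in the answer: each optimizer gets half of the epsilon budget.
sigma_half = find_noise_multiplier(target_epsilon / 2, delta, q, batches)
# Alternative: the shared accountant records two steps per batch, so search over 2 * batches.
sigma_double = find_noise_multiplier(target_epsilon, delta, q, 2 * batches)

print(f"sigma (eps/2 per optimizer): {sigma_half:.3f}")
print(f"sigma (2x steps, full eps):  {sigma_double:.3f}")
print(f"total eps with sigma_half:   {dp_epsilon(q, sigma_half, 2 * batches, delta):.2f}")
```

In this sketch the composed epsilon over 2T steps is never more than twice the epsilon over T steps, so the halved budget is a conservative choice that still meets the overall target, and the 2x-steps search can only return an equal or smaller σ (up to tol).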

Got it. Looks way clearer now. Thanks!
