Using Opacus in federated learning

Hi, I’m looking for how to combine federated learning with Opacus. The tutorials they provide (Opacus · Train PyTorch models with Differential Privacy) are no longer available. Can someone share an example?
Also, I’m wondering how to determine the privacy budget in FL. Do we need to specify some amount of epsilon as a parameter of privacy_engine.make_private_with_epsilon for each client in each federated learning round? Or do we just use privacy_engine.make_private and calculate the epsilon as epsilon = privacy_engine.get_epsilon(delta=self.DELTA) instead?
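To make the two options concrete, this is roughly what I mean (just a sketch; model, optimizer and train_loader are placeholders for one client's objects):

from opacus import PrivacyEngine

# Option 1: specify the budget up front
privacy_engine = PrivacyEngine()
model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
    module=model,
    optimizer=optimizer,
    data_loader=train_loader,
    target_epsilon=0.5,
    target_delta=1e-5,
    epochs=10,
    max_grad_norm=1.0,
)

# Option 2: fix the noise level and read epsilon off afterwards
privacy_engine = PrivacyEngine()
model, optimizer, train_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=train_loader,
    noise_multiplier=1.1,
    max_grad_norm=1.0,
)
# ... training ...
epsilon = privacy_engine.get_epsilon(delta=1e-5)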

Another question is when to build the engine: every time a client is selected in the current round, do we build a new privacy_engine for that client?

Thanks for the help.

Hi Zark,

In FL it depends on whether you want user-level privacy or sample-level privacy.

Regarding the budget, the epsilon of the privacy engine accounts for all training steps. So typically you should use the same privacy_engine throughout all rounds.
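Schematically, something like this (a minimal sketch with placeholder names; your FL loop will look different):

from opacus import PrivacyEngine

privacy_engine = PrivacyEngine()          # created once, reused for every round
model, optimizer, train_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=train_loader,
    noise_multiplier=1.1,
    max_grad_norm=1.0,
)

for fl_round in range(num_rounds):
    local_train(model, optimizer, train_loader)   # your local DP-SGD steps
    # the epsilon reported here covers *all* steps taken since make_private
    eps = privacy_engine.get_epsilon(delta=1e-5)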

Hi Alex,

Thanks for your information. I tried to use the same privacy_engine throughout all rounds but somehow got an error after the first global training round in FL: "Trying to add hooks twice to the same model". So currently, I’m building a new privacy_engine for the clients that participate in each global training round.

I’m not sure I understand the use of Opacus right; could you help me check the following two statements:

privacy_engine.make_private_with_epsilon with epochs = n as an argument will consume (approximately) the given target_epsilon within n epochs. Say I set target_epsilon = 0.5 and epochs = 10; after 10 epochs, privacy_engine.get_epsilon(delta=self.DELTA) returns a value of around 0.5.

On the other hand, privacy_engine.make_private with noise_multiplier as the argument will inject a fixed level of Gaussian noise (by default, I guess?) into the model gradients to achieve DP. In this case, privacy_engine.get_epsilon(delta=self.DELTA) will calculate the privacy cost for each epoch. By recording this returned value, we can get a list of privacy costs (ε_1, ε_2, …, ε_n) for n epochs, and the total privacy cost is the sum of the list (ε_1, ε_2, …, ε_n).
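To make the second statement concrete, this is roughly what I have in mind (a sketch; train_one_epoch, n_epochs and DELTA are placeholders, and I may well be misreading what get_epsilon returns):

eps_per_epoch = []
for epoch in range(n_epochs):
    train_one_epoch(model, optimizer, train_loader)      # DP-SGD steps
    eps_per_epoch.append(privacy_engine.get_epsilon(delta=DELTA))
# my assumption (the one I'd like checked): total cost = sum(eps_per_epoch)
total_eps = sum(eps_per_epoch)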

privacy_engine.make_private_with_epsilon with epochs = n as an argument will consume (approximately) the given target_epsilon within n epochs. Say I set target_epsilon = 0.5 and epochs = 10; after 10 epochs, privacy_engine.get_epsilon(delta=self.DELTA) returns a value of around 0.5.

That’s correct

we can get a list of privacy costs (ε_1, ε_2, …, ε_n) for n epochs, and the total privacy cost is the sum of the list (ε_1, ε_2, …, ε_n)

It is possible to do something along these lines, but the formula for the global epsilon is not the sum of the epsilons; you have to use the composition theorem. But the goal of RDP accounting is to do exactly that for you.
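Concretely, with a single engine you can read the running total directly; something like this (sketch, placeholder names):

eps_history = []
for epoch in range(n_epochs):
    train_one_epoch(model, optimizer, train_loader)
    # cumulative epsilon for *all* steps so far, composed internally by the
    # RDP accountant -- no manual summing of per-epoch epsilons needed
    eps_history.append(privacy_engine.get_epsilon(delta=DELTA))
total_eps = eps_history[-1]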

If you can share a Minimal Reproducible Example, we can help you with the error "Trying to add hooks twice to the same model". This error should not happen for your problem.

In the meantime, a way to solve your accounting problem might be to use the make_private() function at every round and choose the noise multiplier using the get_noise_multiplier() function from https://github.com/pytorch/opacus/blob/main/opacus/accountants/utils.py#L23.
This way, the noise multiplier will be computed to ensure that the total privacy epsilon over all epochs matches your target_epsilon.
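For example, something along these lines (a sketch; num_rounds, local_epochs, batch_size, target_epsilon and target_delta are placeholders for your own values):

from opacus.accountants.utils import get_noise_multiplier

# total number of local epochs one client runs over the whole FL training
total_local_epochs = num_rounds * local_epochs
sample_rate = batch_size / len(train_loader.dataset)

sigma = get_noise_multiplier(
    target_epsilon=target_epsilon,
    target_delta=target_delta,
    sample_rate=sample_rate,
    epochs=total_local_epochs,
)

# pass sigma as noise_multiplier to make_private() at every round, while
# keeping the same PrivacyEngine so the accountant composes all rounds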

Hi Alex,

I think I have figured out how to use Opacus in federated learning. Thanks for your helpful information.

I created a client class and attached a privacy_engine to each client. At each training round of FL, I just load the model state from the global model, keeping the privacy_engine untouched:

import copy

import numpy as np
import torch
import torch.nn.functional as F
from opacus import PrivacyEngine
from opacus.utils.batch_memory_manager import BatchMemoryManager


class client(object):
    def __init__(self, cid, args, global_model, train_loader, test_loader):
        self.cid = cid
        self.args = args
        self.test_loader = test_loader
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.eps = {}
        # delta is set below the inverse of the local dataset size
        self.delta = 1 / (1.1 * len(train_loader.dataset))
        # copy model parameters from global_model
        model = copy.deepcopy(global_model).to(self.device)
        # create a per-client privacy_engine that lives for the whole FL training
        self.privacy_engine = PrivacyEngine()
        self.model, self.optimizer, self.train_loader = self.privacy_engine.make_private(
            module=model,
            optimizer=torch.optim.SGD(model.parameters(), lr=self.args.lr, momentum=0.5),
            data_loader=train_loader,
            noise_multiplier=self.args.noise_multiplier,
            max_grad_norm=self.args.max_grad_norm)


    def update_model(self, global_round, model_weights):
        self.model.load_state_dict(model_weights)
        self.model.train()
        with BatchMemoryManager(data_loader=self.train_loader,
                                max_physical_batch_size=self.args.max_physical_batch_size,
                                optimizer=self.optimizer) as memory_safe_data_loader:            
            train_results = self.train(self.model, self.optimizer, memory_safe_data_loader)
        # record the privacy cost spent so far (cumulative over all rounds)
        self.eps[global_round] = self.privacy_engine.get_epsilon(delta=self.delta)
        return train_results

    def train(self, model, optimizer, dataloader):
        for epoch in range(self.args.local_ep):
            for batch_idx, (images, target) in enumerate(dataloader):
                optimizer.zero_grad()
                images, target = images.to(self.device), target.to(self.device)
                output = model(images)
                # use the functional loss (nn.NLLLoss is a module, not a function)
                loss = F.nll_loss(output, target)
                loss.backward()
                optimizer.step()
        return model.state_dict()

In the main function, I first instantiate the clients, so each client has its own privacy_engine:

if __name__ == '__main__':
    # <... some other code: args, data loaders, global_model, average_weights ...>

    # instantiate clients, so each one keeps its own privacy_engine
    client_lst = []
    for cid in range(args.num_clients):
        client_lst.append(client(cid, args, global_model, train_loader, test_loader))

    # initial global weights, taken from a (wrapped) client model so the
    # state_dict keys match what update_model() expects
    global_weights = client_lst[0].model.state_dict()

    # the server selects clients to collaboratively train the model
    for epoch in range(args.epochs):
        local_weights_lst = []
        m = max(int(args.frac * args.num_clients), 1)
        selected_clients = np.random.choice(range(args.num_clients), m, replace=False)

        for cid in selected_clients:
            local_weights = client_lst[cid].update_model(epoch, global_weights)
            local_weights_lst.append(copy.deepcopy(local_weights))

        # aggregate local model weights to get global model weights
        global_weights = average_weights(local_weights_lst)

And the code runs well with no errors. Regarding your suggestion of get_noise_multiplier(), I think it is worth trying. The original question is solved; thanks again for your help.
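If I do try it, I imagine it would plug into my client class roughly like this (hypothetical sketch; args.target_epsilon and args.batch_size are new arguments I would have to add):

from opacus.accountants.utils import get_noise_multiplier

# inside client.__init__, instead of reading noise_multiplier from args:
sample_rate = self.args.batch_size / len(train_loader.dataset)
noise_multiplier = get_noise_multiplier(
    target_epsilon=self.args.target_epsilon,
    target_delta=self.delta,
    sample_rate=sample_rate,
    # upper bound: a client is not necessarily selected in every round
    epochs=self.args.epochs * self.args.local_ep,
)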

Actually, another question has surfaced. I saw another thread, How to store the state and resume the state of the PrivacyEngine? - #2 by ffuuugor, which is about storing and resuming the state of the privacy_engine, and I’m wondering about the following:

If I first set noise_multiplier = 1.1, run the training with DP-SGD, and store the state of the privacy_engine, then resume the state of the privacy_engine, set noise_multiplier = 1.2 (a value different from the previous choice), and continue training with DP-SGD, will Opacus be able to calculate the privacy cost spent so far?
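Concretely, I mean something like this (a sketch with placeholder names; I am not sure saving the accountant with torch.save is the intended way to persist it):

import torch
from opacus import PrivacyEngine

# phase 1: train with noise_multiplier = 1.1, then save the accountant
privacy_engine = PrivacyEngine()
model, optimizer, train_loader = privacy_engine.make_private(
    module=model, optimizer=optimizer, data_loader=train_loader,
    noise_multiplier=1.1, max_grad_norm=1.0)
# ... DP-SGD training ...
torch.save(privacy_engine.accountant, "accountant.pt")

# phase 2 (a later run): restore the accountant, continue with 1.2
privacy_engine = PrivacyEngine()
privacy_engine.accountant = torch.load("accountant.pt")
model, optimizer, train_loader = privacy_engine.make_private(
    module=model, optimizer=optimizer, data_loader=train_loader,
    noise_multiplier=1.2, max_grad_norm=1.0)
# ... more DP-SGD training ...
# does this epsilon account for both phases correctly?
eps_total = privacy_engine.get_epsilon(delta=DELTA)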

From Algorithm 1 of the original paper, Deep Learning with Differential Privacy, I think they fix the level of the noise_multiplier, so I assume the answer is no?