Error when trying Federated Learning with Opacus

Hi,

I have a simple federated learning setup in which I train a DenseNet-121 on subsets of the CIFAR-10 dataset. A boolean flag (True or False) controls whether training uses Opacus (differential privacy) or not.
I implemented the aggregation (averaging) of the models myself.

However, when I enable differential privacy, I get the following error, and it doesn't tell me much about what the actual problem is or how to fix it:

```
Exception has occurred: AttributeError
'Parameter' object has no attribute '_forward_counter'
  File "/home/user/diffp1/FederatedLearningClient.py", line 62, in train
    outputs = model(imgs)
  File "/home/user/diffp1/FederatedLearningCoalition.py", line 33, in run_communication_round
    client.train(self.global_model, self.device)
  File "/home/user/diffp1/main5.py", line 9, in <module>
    fc.run_communication_round()
```

My code consists of three files: one main file and two files that each define a class.
main5.py, where I can change True to False when I do not want to use differential privacy (in that case the code runs without any problem):

```python
from FederatedLearningCoalition import *

fc = FederatedLearningCoalition(True)

if __name__ == "__main__":

    for i in range(20):
        print("communication round", i, "starting")
        fc.run_communication_round()
        print("communication round", i, "finished")

    print("entirely finished")
```

FederatedLearningCoalition.py

```python
from torchvision import models
from torch import nn
import copy
from shutil import copyfile
from os import listdir
from os.path import isfile, join
from opacus.validators import ModuleValidator
import torch
from FederatedLearningClient import *

class FederatedLearningCoalition:

    def __init__(self, use_differential_privacy=True):
        a = 0
        self.use_differential_privacy = use_differential_privacy

        self.client_list = []
        self.client_list.append(FederatedLearningClient(use_differential_privacy, 0, 300, 400))
        self.client_list.append(FederatedLearningClient(use_differential_privacy, 400, 700, 800))
        self.client_list.append(FederatedLearningClient(use_differential_privacy, 800, 1100, 1200))

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.global_model = models.densenet121(pretrained=True)
        num_ftrs_model = self.global_model.classifier.in_features
        self.global_model.classifier = nn.Sequential(nn.Linear(num_ftrs_model, 10), nn.Sigmoid())

        if(self.use_differential_privacy):
            self.global_model = ModuleValidator.fix(self.global_model)

    def run_communication_round(self):
        for client in self.client_list:
            client.train(self.global_model, self.device)
            print("training for a client finished")

        client_model_list = [item.model for item in self.client_list]
        self.global_model = self.aggregate(client_model_list)

        for client in self.client_list:
            client.validate(self.global_model, self.device)
            print("validation for a client finished")

    def aggregate(self, list_of_local_models):

        global_dict = list_of_local_models[0].state_dict()

        for k in global_dict.keys():
            global_dict[k] = torch.stack([list_of_local_models[i].state_dict()[k].float() for i in range(len(list_of_local_models))], 0).mean(0)

        return_model = copy.deepcopy(list_of_local_models[0])
        return_model.load_state_dict(global_dict)
        return return_model
```

FederatedLearningClient.py

```python
import numpy as np
from torch import nn
import torchvision.transforms as transforms
import copy
from shutil import copyfile
from datetime import date
from os import listdir
from os.path import isfile, join
from opacus.validators import ModuleValidator
from opacus import PrivacyEngine
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
import torch
import torch.nn.functional as F

class FederatedLearningClient:

    def __init__(self, differential_privacy=True, cifar_start_index=0, cifar_train_end_index=20, cifar_val_end_index=24):

        print("differential_privacy", str(differential_privacy))
        self.differential_privacy = differential_privacy
        if(differential_privacy):
            self.privacy_engine = PrivacyEngine()
        dataset = CIFAR10(root='data/', download=True, transform=transforms.ToTensor())
        self.dataset_train = torch.utils.data.Subset(dataset, range(cifar_start_index, cifar_train_end_index))
        self.dataset_val = torch.utils.data.Subset(dataset, range(cifar_train_end_index, cifar_val_end_index))

    def train(self, model_passed, device):
        model = copy.deepcopy(model_passed)
        train_loader = torch.utils.data.DataLoader(self.dataset_train,batch_size=8, shuffle=True, num_workers=4, pin_memory=True)

        if(self.differential_privacy):
            model = ModuleValidator.fix(model)

        optimizer = torch.optim.SGD(model.parameters(),lr=0.01,momentum=0,weight_decay=0)

        if(self.differential_privacy):
            model, optimizer, train_loader = self.privacy_engine.make_private(module=model,optimizer=optimizer,data_loader=train_loader,noise_multiplier=0.1,max_grad_norm=1.0)

        model = model.to(device)

        print_freq = 2000
        running_loss = 0.0

        for i, data in enumerate(train_loader):
            imgs, labels = data

            batch_size = imgs.shape[0]
            imgs = imgs.to(device)
            labels = labels.to(device)

            try:
                labels = F.one_hot(labels.to(torch.int64), num_classes=10)
            except Exception as e:
                print("labels", labels)
                print("e", e)
            labels = labels.to(torch.float32)

            if(batch_size>0):
                optimizer.zero_grad()
                model.train()
                outputs = model(imgs)

                criterion = nn.BCELoss().to(device)
                loss = criterion(outputs, labels)

                loss.backward()
                optimizer.step()  # update weights

                running_loss += loss.item() * batch_size  # .item() so the autograd graph is not kept alive
                if (i % print_freq == 0):
                    print(str(i * batch_size))
            else:
                print("batch size not larger than 0")

        epoch_loss_train = running_loss / len(self.dataset_train)

        if(self.differential_privacy):
            epsilon, best_alpha = self.privacy_engine.accountant.get_privacy_spent(
                delta=1/(10*len(self.dataset_train))
            )
            print("finished local training, train loss", epoch_loss_train, "privacy report: epsilon", epsilon, "best alpha", best_alpha)

        self.model = model
        print("finished")

    def validate(self, model_passed, device):
        model = copy.deepcopy(model_passed)
        model.eval()
        val_loader = torch.utils.data.DataLoader(self.dataset_val,batch_size=8, shuffle=True, num_workers=4, pin_memory=True)
        model = model.to(device)

        running_loss = 0.0

        criterion = nn.BCELoss().to(device)

        for i, data in enumerate(val_loader):
            imgs, labels = data

            batch_size = imgs.shape[0]
            imgs = imgs.to(device)
            labels = labels.to(device)

            try:
                labels = F.one_hot(labels.to(torch.int64), num_classes=10)
            except Exception as e:
                print("labels", labels)
                print("e", e)
            labels = labels.to(torch.float32)

            if(batch_size>0):

                with torch.no_grad():
                    outputs = model(imgs)

                loss = criterion(outputs, labels)

                running_loss += loss.item() * batch_size
            else:
                print("batch size not larger than 0")

        epoch_loss_val = running_loss / len(self.dataset_val)

        self.model = model
        print("finished validating, loss", epoch_loss_val)
```

Do you know what causes the error and how to fix it? After downloading the CIFAR-10 dataset, the code should not take long to run, since it only trains on a small subset of the dataset.
Thank you in advance!

Hi @general, and thanks for your question.
The reason for the issue you're running into is that you directly manipulate model weights during the model aggregation stage, which GradSampleModule doesn't expect: it relies on certain private attributes attached to parameters (such as _forward_counter from your traceback) and doesn't know what to do when they're missing.

I can suggest one way to work around this: only use GradSampleModule during training, and store and aggregate the original, unwrapped model.

Since you call privacy_engine.make_private() in every round, this shouldn't be a problem for you. You can simply do self.model = model.to_standard_module() at the end of the client training loop (instead of the self.model = model you have now); that should do the trick. This way you'll store the original model, and GradSampleModule will initialize all the required attributes at the beginning of each round as necessary.

Hope this helps

Hi, thanks for the tips. to_standard_module() works perfectly. I have another question: I call privacy_engine.make_private_with_epsilon() only once, at the very beginning of FL training (because the target epsilon is fixed for the whole training, I cannot call it in each round). The problem is that the model is a GradSampleModule for the first round and is then converted to a standard module for parameter aggregation. Once the client's model parameters are updated, how can I convert it back to a GradSampleModule for the second round of training? Is there something like model = GradSampleModule(model) that I can call for the following rounds?

You can indeed do model = GradSampleModule(model). Another possibility is to update the parameters of the existing model directly instead of replacing the model object.

Thanks for the tips! I finally arrived at code like this:
1. The client calls privacy_engine.make_private_with_epsilon() before training starts and immediately converts the model to a standard one using model.to_standard_module().
2. The client loads the server's global model parameters into the standard model.
3. When training starts, the model is converted back to a GradSampleModule.
4. When training finishes, the model is converted to a standard one again.
Will these conversions affect the FL training process itself? I hope not, since GradSampleModule is just a wrapper around an nn.Module.
