ModuleValidator.fix does not immediately change the model - how to deal with it?


I would like to use Opacus with a DenseNet121. Since this network is not compatible with Opacus by default, I need to call ModuleValidator.fix(model) before training the model.

Using the debugger, I noticed that ModuleValidator.fix(model) does not immediately change the state_dict of the model; it only seems to change later (probably during training). Is this correct? Is there a way to make the change immediately? (Q1)
Specifically, the keys in the state dict later carry a “_module.” prefix in front of their original names.

Based on this, I have two more questions:

Q2: What is the best way to load a model that was trained with Opacus? Normally, I would create a DenseNet object and then call model.load_state_dict(…). This does not work when the stored state dict comes from a DenseNet121 trained with Opacus, because the key names in the state dict differ. I also cannot just call ModuleValidator.fix(model) before loading the state dict, since it does not immediately change the keys.

Q3: Could I just iterate over the keys key_name of the normal DenseNet121 state dict and replace each value with model_fixed_through_modelvalidator.state_dict()["_module." + key_name], or would this lead to “wrong” values in the loaded model (e.g., when BatchNorm is replaced with GroupNorm)?

Thanks in advance!

Hey general, thanks for reaching out! Below are answers regarding your three questions:

  • Q1: Yes, it’s intended to be used before training starts (or before loading a checkpoint). Do you have a specific use case in mind that would motivate changing the state dict immediately?
  • Q2: I would recommend that, when checkpointing during Opacus training, you save model._module.state_dict(). This way, (1) you can load the checkpoint in a regular training loop as usual, and (2) if you resume Opacus training from this checkpoint, you should call model._module.load_state_dict() after make_private.
  • Q3: See my generic remark below.

In your specific case, am I right to assume that you could replace the batch norm layers with group norm layers before applying Opacus and then proceed as usual?

Hope this helps,

Hi Pierre, thanks for your reply. To illustrate my problem, please have a look at the following code:

```python
import torch
from torchvision import models
from torch import nn
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
import torch.nn.functional as F
from opacus import PrivacyEngine
from opacus.validators import ModuleValidator

dataset = CIFAR10(root='data/', download=True, transform=transforms.ToTensor())
dataset = torch.utils.data.Subset(dataset, list(range(0, 200)))
train_loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True, num_workers=20, pin_memory=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
criterion = nn.BCELoss().to(device)

use_differential_privacy = True

if __name__ == "__main__":

    model = models.densenet121(pretrained=True)
    num_ftrs = model.classifier.in_features
    model.classifier = nn.Sequential(nn.Linear(num_ftrs, 10), nn.Sigmoid())

    if use_differential_privacy:
        print("use differential privacy")
        model = ModuleValidator.fix(model)  # BatchNorm -> GroupNorm
        # RDP accountant, so get_privacy_spent() below returns (epsilon, best_alpha)
        privacy_engine = PrivacyEngine(accountant="rdp")

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0, weight_decay=0)
    if use_differential_privacy:
        # noise_multiplier / max_grad_norm were lost from the post; placeholder values
        model, optimizer, train_loader = privacy_engine.make_private(
            module=model, optimizer=optimizer, data_loader=train_loader,
            noise_multiplier=1.0, max_grad_norm=1.0,
        )

    model = model.to(device)
    print_freq = 1000
    running_loss = 0.0

    for i, data in enumerate(train_loader):
        imgs, labels = data

        batch_size = imgs.shape[0]
        imgs = imgs.to(device)
        labels = labels.to(device)

        labels = F.one_hot(labels, num_classes=10)
        labels = labels.float()  # BCELoss expects float targets

        optimizer.zero_grad()
        outputs = model(imgs)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()  # update weights

        running_loss += loss.item() * batch_size
        if i % print_freq == 0:
            print(str(i * batch_size))

    epoch_loss_train = running_loss / len(dataset)

    if use_differential_privacy:
        delta = 1e-5  # placeholder; the original delta was lost from the post
        epsilon, best_alpha = privacy_engine.accountant.get_privacy_spent(delta=delta)
        print("finished local training, privacy report:", epsilon, best_alpha)

    print("finished training, now storing model")
    torch.save(model.state_dict(), "")  # checkpoint path elided in the post

    model_load = models.densenet121(pretrained=True)
    num_ftrs = model_load.classifier.in_features
    model_load.classifier = nn.Sequential(nn.Linear(num_ftrs, 10), nn.Sigmoid())

    loaded_state = torch.load("")  # same elided path
    print("loading model")

    if use_differential_privacy:
        model_load = ModuleValidator.fix(model_load)
    model_load.load_state_dict(loaded_state)  # raises the RuntimeError below
```


It results in the following error:

```
Exception has occurred: RuntimeError
Error(s) in loading state_dict for DenseNet:
Missing key(s) in state_dict: "features.conv0.weight", "features.norm0.weight", "features.norm0.bias", "features.denseblock1.denselayer1.norm1.weight", "features.denseblock1.denselayer1.norm1.bias", "features.denseblock1.denselayer1.conv1.weight", "features.denseblock1.denselayer1.norm2.weight", "features.denseblock1.denselayer1.norm2.bias", "features.denseblock1.denselayer1.conv2.weight", "features.denseblock1.denselayer2.norm1.weight", ...
Unexpected key(s) in state_dict: "_module.features.conv0.weight", "_module.features.norm0.weight", "_module.features.norm0.bias", "_module.features.denseblock1.denselayer1.norm1.weight", "_module.features.denseblock1.denselayer1.norm1.bias", "_module.features.denseblock1.denselayer1.conv1.weight", "_module.features.denseblock1.denselayer1.norm2.weight", "_module.features.denseblock1.denselayer1.norm2.bias", ...
```

This happens even though I apply ModuleValidator.fix before loading the state dict into the new model. Do you know how I can fix this?

Hi general,

Thanks for providing detailed context. You can simply load the state dict after you call ModuleValidator.fix. Could you please try this?


Hi Pierre,

Thanks for your response. I am not exactly sure what you mean: I call the “load_state_dict” function in the last line, after ModuleValidator.fix.
Or did you mean something different?

Hey general,

My bad, let me rephrase my suggestion.

  • Option 1. You can call load_state_dict() after calling privacy_engine.make_private (which you do not call here).
  • Option 2. Replace torch.save(model.state_dict(), "") by torch.save(model._module.state_dict(), "") in case use_differential_privacy is True.

Does that make sense?

PS: editing my comment to point to a relevant planned enhancement.


Hi Pierre,
thanks, this helps - I especially like option 2.