How to load a pre-trained model with set_requires_grad?

I am training a GAN. I call set_requires_grad(netD, False) on the discriminator so that no gradients are computed for it while the generator is updated, and set it back to True when updating the discriminator. This saves some time and memory.
But when I load the pre-trained discriminator, this error occurs:
loaded state dict contains a parameter group that doesn't match the size of optimizer's group.
If set_requires_grad is always True, this error does not occur.
So how do I load a pre-trained model with set_requires_grad?

Could you post a code snippet to reproduce this issue?
Setting the requires_grad attribute should not change the shape of any parameters, which seems to be the error here.

A code snippet is as follows:

    # --- update D ---
    set_requires_grad(netD, True)  # enable backprop for D
    optimizer_D.zero_grad()        # set D's gradients to zero
    # Fake; stop backprop to the generator by detaching fake_B
    fake_AB =, fake_B), 1)
    pred_fake = netD(fake_AB.detach())
    loss_D_fake = criterionGAN(pred_fake, False)
    # Real
    real_AB =, real_B), 1)
    pred_real = netD(real_AB)
    loss_D_real = criterionGAN(pred_real, True)
    # combine loss and calculate gradients
    loss_D = (loss_D_fake + loss_D_real) * 0.5
    optimizer_D.step()

    # --- update G ---
    """Calculate GAN and L1 loss for the generator"""
    # First, G(A) should fake the discriminator
    set_requires_grad(netD, False)  # D requires no gradients when optimizing G
    fake_AB =, fake_B), 1)
    pred_fake = netD(fake_AB)
    loss_G_GAN = criterionGAN(pred_fake, True) * 1e-2
    loss = loss_l1 + loss_G_GAN
The model save and load functions are as follows:
def save_ckpt(ckpt_name, models, optimizers, n_iter):
    ckpt_dict = {'n_iter': n_iter}
    for prefix, model in models:
        ckpt_dict[prefix] = get_state_dict_on_cpu(model)

    for prefix, optimizer in optimizers:
        ckpt_dict[prefix] = optimizer.state_dict(), ckpt_name)

def load_ckpt(ckpt_name, models, optimizers=None):
    ckpt_dict = torch.load(ckpt_name)
    for prefix, model in models:
        assert isinstance(model, nn.Module)
        model.load_state_dict(ckpt_dict[prefix], strict=False)
    if optimizers is not None:
        for prefix, optimizer in optimizers:
            optimizer.load_state_dict(ckpt_dict[prefix])
    return ckpt_dict['n_iter']

The error occurs when I load the pre-trained model:
load_ckpt(args.resumeD, [('model', netD)], [('optimizer', optimizer_D)])

Thanks for the code.
Unfortunately, set_requires_grad, criterionGAN, the model, etc. are all undefined, so debugging is not possible.
Could you post a (small) executable code snippet to reproduce this error?

set_requires_grad and criterionGAN are defined as:

criterionGAN = GANLoss(args.gan_mode).to(device)
where GANLoss is:
class GANLoss(nn.Module):
    """Define different GAN objectives.

    The GANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """Initialize the GANLoss class.

            gan_mode (str)            -- the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
            target_real_label (float) -- label for a real image
            target_fake_label (float) -- label for a fake image
        """
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode in ['wgangp']:
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

def set_requires_grad(nets, requires_grad=False):
    """Set requires_grad=False for all the networks to avoid unnecessary computations.

        nets (network list)   -- a list of networks
        requires_grad (bool)  -- whether the networks require gradients or not
    """
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad

The full code is a little long, but the core is mainly the above.
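To double-check the point raised above, here is a quick self-contained sketch (not from the original code) confirming that toggling requires_grad only flips a flag on each parameter and leaves both the parameter shapes and the state_dict untouched:

```python
import torch
import torch.nn as nn

def set_requires_grad(nets, requires_grad=False):
    # same helper as above
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad

net = nn.Linear(3, 2)  # tiny stand-in network
shapes_before = [p.shape for p in net.parameters()]
keys_before = list(net.state_dict().keys())

set_requires_grad(net, False)
set_requires_grad(net, True)

shapes_after = [p.shape for p in net.parameters()]
keys_after = list(net.state_dict().keys())
print(shapes_before == shapes_after, keys_before == keys_after)  # True True
```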

I cannot reproduce the error using the DCGAN Discriminator and your set_requires_grad method:

nc = 3
ndf = 4
netD = Discriminator(1)

state_dict = copy.deepcopy(netD.state_dict())

set_requires_grad(netD, True)  # enable backprop for D
netD.load_state_dict(state_dict)
> <All keys matched successfully>

set_requires_grad(netD, False)
netD.load_state_dict(state_dict)
> <All keys matched successfully>
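One way such a mismatch can arise (an assumption, since the optimizer construction is not shown in the thread) is building the optimizer over only the parameters that are trainable at construction time, e.g. with a requires_grad filter, while the discriminator is frozen. The saved optimizer state then has a smaller param group than the optimizer created on resume. A minimal sketch with a hypothetical stand-in model:

```python
import torch
import torch.nn as nn

netD = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 1))  # tiny stand-in for D

# Freeze the second layer, then build the optimizer over trainable params only.
for p in netD[1].parameters():
    p.requires_grad = False
opt_partial = torch.optim.Adam(p for p in netD.parameters() if p.requires_grad)
state = opt_partial.state_dict()  # its single param group holds only 2 tensors

# On resume, requires_grad is True everywhere, so the new optimizer sees 4 params.
for p in netD.parameters():
    p.requires_grad = True
opt_full = torch.optim.Adam(netD.parameters())
try:
    opt_full.load_state_dict(state)
except ValueError as e:
    print(e)  # prints the "parameter group ... size of optimizer's group" error
```

If that is the cause, the fix is to construct the optimizer over all of netD's parameters (requires_grad toggling alone does not change the param groups), or to rebuild the optimizer the same way before loading its state.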