Model.train and requires_grad

Hi all,

I don’t understand how model.train(False) relates to setting param.requires_grad = False.

It seems to me that if I execute the following:

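import torch.nn as nn

class Test(nn.Module):
    def __init__(self):
        super(Test, self).__init__()
        self.fc = nn.Linear(4, 2)  # hypothetical layer so the model has parameters

model = Test()
model.train(False)
for param in model.parameters():
    print(param.requires_grad)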


I expected param.requires_grad to be False, but I get True.

So what’s the difference between model.train(False) and:

for param in self.parameters():
    param.requires_grad = False


What is the best practice for initializing a model and freezing its weights?



It’s a bit confusing, but model.train(False) doesn’t change param.requires_grad. It only switches modules such as nn.Dropout and the BatchNorm layers (and possibly a few others) to their inference-mode behavior; model.eval() is equivalent to model.train(False). In that mode Dropout stops dropping activations, and BatchNorm normalizes with its running_mean/running_var instead of the current batch statistics (and stops updating those running estimates).
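A minimal sketch of this, assuming a small hypothetical nn.Sequential with one Dropout and one BatchNorm1d layer (sizes are arbitrary): eval mode flips the training flag on every submodule but leaves requires_grad alone, makes repeated forward passes identical, and freezes BatchNorm's running_mean. Switching back to train mode resumes the running-stat updates.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical model containing both mode-sensitive layer types.
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5), nn.BatchNorm1d(4))
x = torch.randn(8, 4)

model.train(False)  # same as model.eval()

# The training flag is cleared on every submodule...
eval_flags = [m.training for m in model.modules()]
# ...but requires_grad is untouched.
grads_still_on = all(p.requires_grad for p in model.parameters())

# Eval mode: Dropout is a no-op and BatchNorm uses its running stats,
# so two forward passes match and running_mean does not move.
bn = model[2]
mean_before = bn.running_mean.clone()
with torch.no_grad():
    same_output = torch.equal(model(x), model(x))
eval_stats_unchanged = torch.equal(bn.running_mean, mean_before)

# Train mode: BatchNorm uses batch statistics and updates running_mean.
model.train(True)
with torch.no_grad():
    model(x)
train_stats_updated = not torch.equal(bn.running_mean, mean_before)

print(eval_flags, grads_still_on, same_output, eval_stats_unchanged, train_stats_updated)
```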

If you want to freeze model weights, you should use the code snippet you wrote above:

for param in model.parameters():
    param.requires_grad = False
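To see the effect, here is a sketch with a hypothetical backbone/head split (layer sizes are arbitrary). After freezing the backbone and calling backward(), autograd leaves the frozen parameters' .grad as None, while the head still receives gradients. nn.Module also provides requires_grad_(False), which applies the same change to all of a module's parameters in one call.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical two-part model: a backbone to freeze and a head to train.
backbone = nn.Linear(4, 8)
head = nn.Linear(8, 2)
model = nn.Sequential(backbone, head)

# Freeze the backbone exactly as in the snippet above.
for param in backbone.parameters():
    param.requires_grad = False

loss = model(torch.randn(5, 4)).sum()
loss.backward()

# Autograd skips frozen parameters, so their .grad stays None.
backbone_grads = [p.grad for p in backbone.parameters()]
head_grads = [p.grad for p in head.parameters()]
print([g is None for g in backbone_grads], [g is None for g in head_grads])
```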

Depending on how you want BatchNorm to behave, you may want to call model.train(False) or call it on some sub-module of your model.
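For example, one way to freeze a feature extractor that contains BatchNorm while still training a classifier on top is to put the whole model in training mode and then switch only the frozen sub-module back to eval mode. The sketch below assumes a hypothetical features/classifier split and a hypothetical set_train_mode helper. The order matters: model.train() recursively resets every sub-module, so the frozen part must be switched back afterwards.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

# Hypothetical model: frozen feature extractor with BatchNorm + trainable classifier.
model = nn.Sequential(OrderedDict([
    ("features", nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8))),
    ("classifier", nn.Linear(8, 2)),
]))

# Freeze the feature extractor's weights.
for param in model.features.parameters():
    param.requires_grad = False

def set_train_mode(model):
    # model.train() puts every submodule into training mode,
    # so the frozen block is put back into eval mode afterwards.
    model.train(True)
    model.features.train(False)

set_train_mode(model)
bn = model.features[1]
mean_before = bn.running_mean.clone()
model(torch.randn(16, 4)).sum().backward()

# The frozen BatchNorm normalized with its running stats and did not update them.
features_stats_fixed = torch.equal(bn.running_mean, mean_before)
print(model.training, model.features.training, model.classifier.training, features_stats_fixed)
```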


Yeah, I ended up asking a friend. Maybe the documentation needs to be clearer on this. And just to follow up: with param.requires_grad = False, the optimizer will throw an error unless you explicitly pass it only the parameters that require gradients (at least in the version I’m using; newer PyTorch releases accept frozen parameters and simply skip any parameter whose .grad is None).

I was wondering whether we need to do anything so that the optimizer doesn’t change the parameters of the frozen layers. Something like:

torch.optim.Adam(filter(lambda p: p.requires_grad, self.netG.parameters()),
                 betas=(opt.beta1, 0.999))
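Filtering on requires_grad does work. A self-contained sketch, assuming a hypothetical three-layer net in place of self.netG and betas=(0.5, 0.999) in place of opt.beta1; the list comprehension is equivalent to the filter(lambda ...) call. Only the two parameters of the last layer reach the optimizer, and after one step the frozen layer is unchanged while the last layer has moved.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for self.netG; its first layer is frozen.
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
for param in net[0].parameters():
    param.requires_grad = False

# Hand the optimizer only the parameters that still require gradients.
trainable = [p for p in net.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-3, betas=(0.5, 0.999))  # assumed betas

frozen_before = net[0].weight.clone()
head_before = net[2].weight.detach().clone()

optimizer.zero_grad()
net(torch.randn(5, 4)).pow(2).mean().backward()
optimizer.step()

n_optimized = sum(len(g["params"]) for g in optimizer.param_groups)
frozen_unchanged = torch.equal(net[0].weight, frozen_before)
head_changed = not torch.equal(net[2].weight, head_before)
print(n_optimized, frozen_unchanged, head_changed)
```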


I just tried both requires_grad and param_groups in the optimizer.

Setting requires_grad = False stops gradients from being computed for the affected module, so their .grad stays None. Leaving parameters out of the optimizer keeps opt.step() from updating them, but their gradients are still computed.
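The two behaviors can be compared side by side. In this sketch (a hypothetical run() helper and two small Linear layers), only layer_b is ever passed to the optimizer; the difference is whether layer_a also has requires_grad turned off.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def run(freeze_with_requires_grad):
    layer_a, layer_b = nn.Linear(3, 3), nn.Linear(3, 1)
    if freeze_with_requires_grad:
        # Option 1: turn off autograd for layer_a's parameters.
        for param in layer_a.parameters():
            param.requires_grad = False
    # Both options: only layer_b's parameters go to the optimizer.
    optimizer = torch.optim.SGD(layer_b.parameters(), lr=0.1)
    weight_before = layer_a.weight.detach().clone()
    loss = layer_b(layer_a(torch.randn(4, 3))).sum()
    loss.backward()
    optimizer.step()
    return layer_a.weight.grad, torch.equal(layer_a.weight, weight_before)

grad_frozen, unchanged_frozen = run(True)
grad_excluded, unchanged_excluded = run(False)
print(grad_frozen is None, unchanged_frozen)
print(grad_excluded is None, unchanged_excluded)
```

Either way layer_a's weights stay fixed. In the second case, though, its .grad is still computed on every backward pass, and because optimizer.zero_grad() only clears the parameters it was given, those gradients keep accumulating unless cleared separately.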