requires_grad doesn't propagate to the parameters of the module

Hi!

I am trying to train different layers of the network during specific epochs, so I set requires_grad for those layers to True and then to False. I noticed that:
self.model.conv2[0].requires_grad is True
self.model.conv2[0].weight.requires_grad is False
although I set self.model.conv2[0].requires_grad = False
where self.model.conv2[0] is Conv2d(12, 25, kernel_size=(3, 3), stride=(1, 1))

Shouldn't requires_grad propagate to the parameters of the layers (weights and biases)?

Thank you!

nn.Modules don't have a requires_grad field, so creating one and setting it to True or False won't change anything.
You can call zero_grad() on the module to reset all of its gradients, but that does not freeze anything. To freeze or unfreeze the module, use parameters() to get an iterator over its parameters and set requires_grad to False (or True) on each of them.
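The original goal was to train different layers during specific epochs. Below is a minimal sketch of that idea using the parameters() loop described above. The toy nn.Sequential model, the helper name set_trainable, and the epoch schedule are all hypothetical and only serve as an illustration.

```python
import torch
import torch.nn as nn

# Hypothetical toy model; layer indices 0 and 2 hold parameters, 1 is ReLU
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 2))

def set_trainable(module, trainable):
    # Hypothetical helper: parameters() recurses into submodules,
    # so this covers every weight and bias of the module
    for param in module.parameters():
        param.requires_grad_(trainable)

# Hypothetical schedule: epoch -> indices of the layers to train from then on.
# Train only the last layer first, then unfreeze the first layer at epoch 2.
schedule = {0: [2], 2: [0, 2]}

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(32, 10), torch.randint(0, 2, (32,))

for epoch in range(4):
    if epoch in schedule:
        for idx, layer in enumerate(model):
            set_trainable(layer, idx in schedule[epoch])
    optimizer.zero_grad()
    loss = nn.functional.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()
    print(epoch, [name for name, p in model.named_parameters() if p.requires_grad])
```

Passing all parameters to the optimizer works here because SGD skips parameters whose .grad is None, which is the case for layers that have never received a gradient. If you re-freeze a layer after it has been trained, make sure its gradients are set to None (e.g. optimizer.zero_grad(set_to_none=True)), otherwise optimizers with momentum or weight decay may keep updating it.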

Thank you for the quick reply, that's interesting.
Maybe requires_grad should be removed from nn.Module, since it's a bit misleading.

nn.Modules do not have a requires_grad field. I think you are creating it yourself.
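A short sketch of what happens, using the Conv2d(12, 25, kernel_size=(3, 3)) layer from the original question: assigning module.requires_grad just stores a new, unrelated Python attribute on the module object, while the parameters keep their own flags.

```python
import torch.nn as nn

conv = nn.Conv2d(12, 25, kernel_size=3)
print(hasattr(conv, 'requires_grad'))  # False: nn.Module has no such field

# This only creates a plain Python attribute on the module object
conv.requires_grad = False
print(conv.requires_grad, conv.weight.requires_grad)  # False True

# The flag has to be set on each parameter instead
for param in conv.parameters():
    param.requires_grad = False
print(conv.weight.requires_grad, conv.bias.requires_grad)  # False False
```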

Yes! You're right.
Thank you!

Sorry to revive an old topic, but I have a similar question:

Is there an easy way to set requires_grad simultaneously to all the parameters associated with a module (recursively if said module contains other modules)?

You could write a function that accepts a module, checks for valid parameters (weight and bias), and sets their requires_grad attribute, similar to a weight_init function. Then pass it to module.apply(), which calls it on the module itself and on every submodule:

import torch.nn as nn

def set_requires_grad(m, requires_grad):
    # Only handles parameters named 'weight' and 'bias'
    if hasattr(m, 'weight') and m.weight is not None:
        m.weight.requires_grad_(requires_grad)
    if hasattr(m, 'bias') and m.bias is not None:
        m.bias.requires_grad_(requires_grad)


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.module1 = nn.Sequential(
            nn.Linear(1, 1),
            nn.ReLU(),
            nn.Linear(1,1, bias=False)
        )
        self.module2 = nn.Sequential(
            nn.Linear(1, 1),
            nn.ReLU(),
            nn.Linear(1,1, bias=False)
        )
    

model = MyModel()
model.module1.apply(lambda m: set_requires_grad(m, False))
print(model.module1[0].weight.requires_grad)
> False
print(model.module1[0].bias.requires_grad)
> False
print(model.module2[2].weight.requires_grad)
> True
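A limitation of the helper above is that it only looks at attributes named weight and bias. The following sketch uses a hypothetical module, Scaled, whose only parameter is called scale; it is skipped by the helper and stays trainable.

```python
import torch
import torch.nn as nn

def set_requires_grad(m, requires_grad):
    # Same weight/bias-based helper as above
    if hasattr(m, 'weight') and m.weight is not None:
        m.weight.requires_grad_(requires_grad)
    if hasattr(m, 'bias') and m.bias is not None:
        m.bias.requires_grad_(requires_grad)

class Scaled(nn.Module):
    # Hypothetical module whose parameter is named 'scale'
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, x):
        return x * self.scale

block = nn.Sequential(nn.Linear(1, 1), Scaled())
block.apply(lambda m: set_requires_grad(m, False))

for name, param in block.named_parameters():
    print(name, param.requires_grad)
# 0.weight False
# 0.bias False
# 1.scale True
```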

Is there a more generic way that doesn't assume all parameters are called “weight” or “bias”? Thanks!

def set_requires_grad(m, requires_grad):
    for param in m.parameters():
        param.requires_grad_(requires_grad)

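should work as well. Since parameters() already recurses into submodules, you can also call it directly on the parent module, e.g. set_requires_grad(model.module1, False), without apply().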
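For completeness, nn.Module also provides a built-in requires_grad_() method (note the trailing underscore) that sets requires_grad on all parameters of the module and its submodules in place. A minimal sketch on a small nested model chosen for illustration:

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Sequential(nn.Linear(4, 2)))

# Freeze everything, including parameters of nested submodules
model.requires_grad_(False)
print(all(not p.requires_grad for p in model.parameters()))  # True

# Unfreeze only the nested block
model[2].requires_grad_(True)
print([name for name, p in model.named_parameters() if p.requires_grad])
# ['2.0.weight', '2.0.bias']
```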


model.parameters() and model.named_parameters() will return all parameters recursively from all submodules as seen here:

import torch
import torch.nn as nn

class SubSubModule(nn.Module):
    def __init__(self):
        super(SubSubModule, self).__init__()
        self.sub_sub_param = nn.Parameter(torch.randn(1))

class SubModule(nn.Module):
    def __init__(self):
        super(SubModule, self).__init__()
        self.sub_param = nn.Parameter(torch.randn(1))
        self.sub_sub_module = SubSubModule()


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.param = nn.Parameter(torch.randn(1))
        self.sub_module = SubModule()
        
model = MyModel()
for name, param in model.named_parameters():
    print(name, param)
> param Parameter containing:
tensor([1.8495], requires_grad=True)
sub_module.sub_param Parameter containing:
tensor([1.8534], requires_grad=True)
sub_module.sub_sub_module.sub_sub_param Parameter containing:
tensor([0.0247], requires_grad=True)


for param in model.parameters():
    print(param)
> Parameter containing:
tensor([1.8495], requires_grad=True)
Parameter containing:
tensor([1.8534], requires_grad=True)
Parameter containing:
tensor([0.0247], requires_grad=True)
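Because named_parameters() yields fully qualified names, a name prefix can be used to select every parameter of a submodule, however deeply nested. A sketch assuming a hypothetical ModuleDict with an 'encoder' and a 'head' entry:

```python
import torch.nn as nn

# Hypothetical model; the 'encoder' and 'head' names are illustrative
model = nn.ModuleDict({
    'encoder': nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4)),
    'head': nn.Linear(4, 2),
})

# Freeze everything whose qualified name starts with 'encoder.'
for name, param in model.named_parameters():
    param.requires_grad_(not name.startswith('encoder.'))

for name, param in model.named_parameters():
    print(name, param.requires_grad)
# encoder.0.weight False
# encoder.0.bias False
# encoder.1.weight False
# encoder.1.bias False
# head.weight True
# head.bias True
```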

Could you clarify when this doesn't work?

I'm not able to reproduce it now. Not sure what that was about; perhaps I was confused. Deleting my previous post.


If I am not mistaken, the problem with m.parameters() or m.named_parameters() is that they do not return the affine parameters of batch norm layers.
Could you explain how to set requires_grad=False for batch norm layers?
Thanks!

No, the trainable affine parameters of batchnorm layers are returned (if the layer was created with affine=True), as seen here:

import torch.nn as nn

bn = nn.BatchNorm2d(3, affine=False)
print(dict(bn.named_parameters()))
# {}

bn = nn.BatchNorm2d(3, affine=True)
print(dict(bn.named_parameters()))
# {'weight': Parameter containing:
# tensor([1., 1., 1.], requires_grad=True), 'bias': Parameter containing:
# tensor([0., 0., 0.], requires_grad=True)}

However, named_parameters() will not return the running stats, as these are registered as buffers:

bn = nn.BatchNorm2d(3, affine=True)
print(dict(bn.named_buffers()))
# {'running_mean': tensor([0., 0., 0.]), 'running_var': tensor([1., 1., 1.]), 'num_batches_tracked': tensor(0)}

Set the .requires_grad attribute of the trainable affine parameters to False.
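Note that freezing the affine parameters does not freeze the running stats: these buffers are updated during every forward pass while the layer is in training mode (with the default track_running_stats=True), regardless of requires_grad. Calling .eval() on the layer makes it use and keep the stored stats. A minimal sketch, using a random input shifted by an arbitrary +5.0 so the change is visible:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3, affine=True)

# Freeze the trainable affine parameters (weight and bias)
for param in bn.parameters():
    param.requires_grad_(False)

x = torch.randn(8, 3, 4, 4) + 5.0  # input with a clearly nonzero mean

# Training mode: running stats (buffers) are still updated on forward
bn.train()
bn(x)
print(bn.running_mean)  # no longer all zeros

# Eval mode: the stored running stats are used and left unchanged
before = bn.running_mean.clone()
bn.eval()
bn(x)
print(torch.equal(before, bn.running_mean))  # True
```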
