How can I disable gradients for all layers except the last layer in PyTorch?

Hello All, I’m trying to fine-tune a resnet18 model.

I want to freeze all layers except the last one. I did

resnet18 = models.resnet18(pretrained=True)
resnet18.fc = nn.Linear(512, 10) 
for param in resnet18.parameters():
    param.requires_grad = False 

However, doing

for param in resnet18.fc.parameters():
    param.requires_grad = True

Fails. How can I set a specific layer's parameters to have requires_grad set to True?

Thank you all in advance

Note:

I specifically don't want to swap the order of assigning the new layer and setting all the requires_grad flags to False.

I want to learn how this specific thing can be done.
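
For completeness, the full sequence I'm trying (with the usual imports) looks like this:

import torch.nn as nn
from torchvision import models

# load a pretrained resnet18 and replace the classifier head
resnet18 = models.resnet18(pretrained=True)
resnet18.fc = nn.Linear(512, 10)

# freeze every parameter in the network
for param in resnet18.parameters():
    param.requires_grad = False

# then try to re-enable gradients only for the new fc layer
for param in resnet18.fc.parameters():
    param.requires_grad = True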


Your code works.
After running your code snippets, you can print the requires_grad attributes:

for name, param in resnet18.named_parameters():
    print(name, param.requires_grad)

which shows that fc.weight and fc.bias both require gradients.
You will also get valid gradients for these parameters:

resnet18(torch.randn(1, 3, 224, 224)).mean().backward()
for name, param in resnet18.named_parameters():
    print(name, param.grad)
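
If you then want to train only the unfrozen layer, a common pattern (just a sketch, using SGD as an example) is to pass only the trainable parameters to the optimizer:

import torch.optim as optim

# the optimizer only receives (and updates) parameters with requires_grad=True
optimizer = optim.SGD(
    filter(lambda p: p.requires_grad, resnet18.parameters()),
    lr=1e-3, momentum=0.9
)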

Thank you very much, but the code I gave produces an error. It says:
fc doesn't have any attribute named parameters()
So instead I did:

for _, param in resnet18.fc._parameters.items(): 
    print(param.requires_grad)
    param.requires_grad = True

and interestingly, for this to work I have to do:

for module in resnet18.modules():
    if module._get_name() != 'Linear':
        print('layer: ',module._get_name())
        for param in module.parameters():
            param.requires_grad_(False)
    elif module._get_name() == 'Linear':
        for param in module.parameters():
            param.requires_grad_(True)

Again, if I just do:

for module in resnet18.modules():
    if module._get_name() != 'Linear':
        print('layer: ',module._get_name())
        for param in module.parameters():
            param.requires_grad_(False)

and then print:

for param in resnet18.parameters():
    print(param.requires_grad)

all parameters are set to False!
This is really puzzling.


I would recommend sticking to the named_parameters approach. In your module-based loop, resnet18.modules() also yields the parent containers (the top-level ResNet module and the Sequential blocks), and calling .parameters() on them recursively returns fc.weight and fc.bias as well, so the non-'Linear' branch freezes the classifier too. The elif branch only fixes this because the fc module is visited after its parent, so its parameters are switched back to True at the end.
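
For example, a sketch of the named_parameters approach that freezes everything except the classifier (relying on the fc. name prefix used by torchvision's resnet18) would be:

for name, param in resnet18.named_parameters():
    # only fc.weight and fc.bias keep requires_grad=True
    param.requires_grad = name.startswith('fc.')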

Does this code raise this error:

for name, param in resnet18.fc.named_parameters():
    print(name, param.requires_grad)

If so, could you post your PyTorch and torchvision versions, as I would like to have a look at it?


It's very weird! Both your code and fc.parameters() are now working just fine!
This had me confused for two days, and now they just work!
I don't know what could have caused this, or I may simply have made a mistake.
By the way, I am running PyTorch 1.0.
Anyway, thanks a gazillion times, that was a tremendous help.
Also, do you mind if I ask you to kindly have a look here as well?


Good to hear, it’s working now!
If you are running a Jupyter notebook, make sure to run all previous cells, as it’s easy to forget about old variables etc. :wink:


Yes, it was on Jupyter.
One more thing: I was experimenting with different ways of doing this and wrote this:

for k, p in resnet18.fc._parameters.items():
    p.requires_grad = True

which works, but then I tried to change it again and wrote it this way:

(p.requires_grad_(True) for k,p in resnet18.fc._parameters.items())

which failed miserably!
I expected this to also work since I'm using the in-place method (requires_grad_), but it doesn't. Do you know why this is not working?

In your second code snippet you are creating a Python generator, which will be evaluated lazily.
Your code works if you consume the generator or use a list comprehension instead.

resnet18 = models.resnet18()
for param in resnet18.fc.parameters():
    print(param.requires_grad)

# 1: consume the generator manually (fc has exactly two parameters: weight and bias)
gen = (p.requires_grad_(False) for k, p in resnet18.fc._parameters.items())
next(gen)
next(gen)
# 2: materialize the generator eagerly with list()
list(p.requires_grad_(False) for k, p in resnet18.fc._parameters.items())
# 3: or use a list comprehension, which is evaluated immediately
[p.requires_grad_(False) for k, p in resnet18.fc._parameters.items()]

for param in resnet18.fc.parameters():
    print(param.requires_grad)

Thanks a quintillion times, sir :slight_smile:
God bless you, and have a fantastic weekend.
