How can I disable gradients for all layers except the last layer in PyTorch?

Hello All, I’m trying to fine-tune a resnet18 model.

I want to freeze all layers except the last one. I did

import torch.nn as nn
from torchvision import models

resnet18 = models.resnet18(pretrained=True)
resnet18.fc = nn.Linear(512, 10)
for param in resnet18.parameters():
    param.requires_grad = False

However, doing

for param in resnet18.fc.parameters():
    param.requires_grad = True

fails. How can I set requires_grad to True for the parameters of one specific layer?

Thank you all in advance

Note:

I specifically don’t want to swap the order of assigning the new layer and setting requires_grad to False for all parameters.

I want to learn how this specific thing can be done.

Your code works.
After running your code snippets, you can print the requires_grad attributes:

for name, param in resnet18.named_parameters():
    print(name, param.requires_grad)

which shows that fc.weight and fc.bias both require gradients.
You will also get valid gradients in these layers:

resnet18(torch.randn(1, 3, 224, 224)).mean().backward()
for name, param in resnet18.named_parameters():
    print(name, param.grad)
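
As a self-contained check of the same idea, here is a minimal sketch that avoids downloading pretrained weights. TinyNet is a hypothetical stand-in for resnet18 (it is not part of torchvision); only its last layer is named fc, as in the original model. The sketch freezes everything, re-enables fc, and then records which parameters actually receive a gradient.

```python
import torch
import torch.nn as nn


# Hypothetical stand-in for resnet18: a small backbone plus a final `fc` layer.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(4, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv(x)))
        return self.fc(torch.flatten(x, 1))


model = TinyNet()

# Freeze every parameter, then re-enable only the last layer.
for param in model.parameters():
    param.requires_grad = False
for param in model.fc.parameters():
    param.requires_grad = True

# Names of the parameters that are still trainable.
trainable = [name for name, p in model.named_parameters() if p.requires_grad]

# After a backward pass, only the trainable parameters hold a gradient.
model(torch.randn(2, 3, 8, 8)).mean().backward()
with_grad = [name for name, p in model.named_parameters() if p.grad is not None]

print(trainable)
print(with_grad)
```

Both lists should contain only the fc parameters; the frozen conv parameters keep param.grad as None.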

Thank you very much, but the code I gave produces an error. It says:
fc doesn’t have any attribute named parameters()
So instead I did:

for _, param in resnet18.fc._parameters.items(): 
    print(param.requires_grad)
    param.requires_grad = True

and interestingly, for this to work I have to do:

for module in resnet18.modules():
    if module._get_name() != 'Linear':
        print('layer: ',module._get_name())
        for param in module.parameters():
            param.requires_grad_(False)
    elif module._get_name() == 'Linear':
        for param in module.parameters():
            param.requires_grad_(True)

Again, if I just do:

for module in resnet18.modules():
    if module._get_name() != 'Linear':
        print('layer: ',module._get_name())
        for param in module.parameters():
            param.requires_grad_(False)

and print

for param in resnet18.parameters():
    print(param.requires_grad)

all parameters are set to False!
This is really puzzling.

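I would recommend sticking to the named_parameters approach. In your approach, resnet18.modules() also returns the parent modules (the ResNet model itself and the Sequential/BasicBlock containers), whose names are not 'Linear'. Calling module.parameters() on the root model recursively includes fc.weight and fc.bias, so they are frozen as well.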
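
To make the modules() behaviour concrete, here is a small sketch on a hypothetical two-layer nn.Sequential (the names conv and fc are chosen for illustration; this is not the resnet18 architecture). It shows that modules() yields the root container first, that the root's parameters already include those of fc, and that the loop without the elif branch therefore freezes fc too.

```python
import torch.nn as nn

# Hypothetical two-layer model; `fc` plays the role of resnet18.fc.
model = nn.Sequential()
model.add_module("conv", nn.Conv2d(3, 4, kernel_size=3))
model.add_module("fc", nn.Linear(4, 10))

# modules() yields the root container first, then its children.
module_types = [type(m).__name__ for m in model.modules()]

# parameters() on the root is recursive, so it already covers fc's parameters.
root = next(iter(model.modules()))
root_param_names = [name for name, _ in root.named_parameters()]

# Same loop as in the thread, without the elif branch.
for module in model.modules():
    if type(module).__name__ != "Linear":
        for param in module.parameters():
            param.requires_grad_(False)

# fc is frozen even though the Linear module itself was skipped.
fc_flags = [p.requires_grad for p in model.fc.parameters()]

print(module_types)
print(root_param_names)
print(fc_flags)
```

In the version with the elif branch, the Linear module is visited after the root, so its parameters are switched back on, which is why that variant appeared to work.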

Does this code raise this error:

for name, param in resnet18.fc.named_parameters():
    print(name, param.requires_grad)

If so, could you post your pytorch and torchvision versions, as I would like to have a look at it?

It’s very weird! Both your code and fc.parameters() are now working just fine!
This had me confused for two days, and now they are just working fine!
I don’t know what could have caused this, or I may well have made a mistake!
By the way, I am running PyTorch 1.0.
Anyway, thanks a gazillion times; that was a tremendous help.
By the way do you mind if I ask you to kindly have a look here as well?

Good to hear, it’s working now!
If you are running a Jupyter notebook, make sure to run all previous cells, as it’s easy to forget about old variables etc. :wink:

Yes, it was on Jupyter.
One more thing: I was experimenting with different ways of doing this and wrote this:

for k, p in resnet18.fc._parameters.items():
    p.requires_grad = True

which works, but then I tried changing it again and wrote it this way:

(p.requires_grad_(True) for k,p in resnet18.fc._parameters.items())

which failed miserably!
I expected this to work as well, since I’m using the in-place method (requires_grad_), but it doesn’t. Do you know why this is not working?

In your second code snippet you are creating a Python generator, which is lazily evaluated: nothing inside it runs until the generator is consumed.
Your code works if you execute the generator or use a list comprehension instead.

resnet18 = models.resnet18()
for param in resnet18.fc.parameters():
    print(param.requires_grad)

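# 1: advance the generator manually (fc has two parameters: weight and bias)
gen = (p.requires_grad_(False) for k,p in resnet18.fc._parameters.items())
next(gen)
next(gen)
# 2: consume the generator by passing it to list()
list((p.requires_grad_(False) for k,p in resnet18.fc._parameters.items()))
# 3: use a list comprehension, which is evaluated eagerly
[p.requires_grad_(False) for k,p in resnet18.fc._parameters.items()]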

for param in resnet18.fc.parameters():
    print(param.requires_grad)
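
The lazy evaluation can also be observed directly. The following is a minimal sketch using a standalone nn.Linear as a hypothetical stand-in for resnet18.fc: it records the requires_grad flags after the generator has been created but before it has been consumed, and again afterwards.

```python
import torch.nn as nn

fc = nn.Linear(4, 2)  # hypothetical stand-in for resnet18.fc

# Creating the generator does not call requires_grad_ on anything yet.
gen = (p.requires_grad_(False) for p in fc.parameters())
before = [p.requires_grad for p in fc.parameters()]

# Consuming the generator is what actually executes the calls.
for _ in gen:
    pass
after = [p.requires_grad for p in fc.parameters()]

print(before)
print(after)
```

Since the expression is used only for its side effect, a plain for loop over the parameters is arguably the clearer choice than a generator or list comprehension.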

Thanks a quintillion times, sir :slight_smile:
God bless you and have a fantastic weekend