Updating the parameters of a few nodes in a pre-trained network during training

Hi,

I am really new to PyTorch and was wondering if there is a way to specify only a subset of neurons (of a particular layer) to update during training and freeze the rest. Say, update only 2500 of the 4096 nodes in AlexNet's FC7. param.requires_grad seems to apply to all the neurons in a layer at once.

Appreciate your inputs.

I've added some simple code. If I understand your question correctly, it should be helpful:

D_parameters = [
    {'params': model.fc1.parameters()},
    {'params': model.fc2.parameters()}
] # select only a subset of the model's parameters

optimizerD = torch.optim.Adam(D_parameters, lr=learning_rate)

for epoch in range(training_epoch):
    ....
    optimizerD.zero_grad() # zeroes the grads of D_parameters only
    # do something
    loss.backward() # gradients are still computed for all parameters
    optimizerD.step() # updates D_parameters only
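
(Side note, not part of the reply above: if the remaining layers should never be updated at all, you can also switch off gradient computation for them entirely, roughly like this, assuming the layer names used in the model:)

for name, param in model.named_parameters():
    if not name.startswith(('fc1', 'fc2')):
        param.requires_grad = False # no gradients are computed for the frozen layers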


Hi,

Appreciate your prompt response, but this is not exactly what I am looking for. I want to update only a subset of the fc1 parameters, for example. If we update D_parameters, as in your example, then the weights of all 4096 nodes of FC1 get updated. What I want is to update only a subset of them, say 2900 (which I have as a list).
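
(For reference, a minimal sketch of one way to do exactly that, not taken from the replies below: register a hook directly on the layer's weight and bias and zero the gradient rows of the frozen neurons. keep_idx stands for the hypothetical list of 2900 trainable indices.)

import torch
import torch.nn as nn

fc = nn.Linear(4096, 4096) # stand-in for the FC7 layer
keep_idx = torch.arange(2900) # hypothetical list of trainable neuron indices
mask = torch.zeros(fc.out_features)
mask[keep_idx] = 1.0 # 1 = update this neuron, 0 = freeze it

# each output neuron owns one row of the weight matrix and one bias entry,
# so zeroing those gradient rows freezes that neuron's incoming weights
fc.weight.register_hook(lambda grad: grad * mask.unsqueeze(1))
fc.bias.register_hook(lambda grad: grad * mask)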

I think you should reconstruct the trained model so that the nodes you want to update sit in their own layer, and then update only those nodes.

import torch
import torch.nn as nn
from torchvision import models

original_model = models.alexnet(pretrained=True)

class AlexNetConv4(nn.Module):
    def __init__(self):
        super(AlexNetConv4, self).__init__()
        self.features = nn.Sequential(
            # stop at conv4
            *list(original_model.features.children())[:-3]
        )
        # split the 4096 features into two new layers: one to train, one to keep fixed
        # (n_output is the number of outputs you need, defined elsewhere)
        self.fc1 = nn.Linear(2900, n_output, bias=True)
        self.fc2 = nn.Linear(4096 - 2900, n_output, bias=True)

    def forward(self, x):
        x = self.features(x)
        x1 = self.fc1(x[:, :2900])
        x2 = self.fc2(x[:, 2900:])
        ....
        return x

model = AlexNetConv4()
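
(If you go this route, you can combine it with the optimizer idea from the first reply and train only the fc1 half, for example; learning_rate is assumed to be defined elsewhere:)

optimizer = torch.optim.Adam(model.fc1.parameters(), lr=learning_rate)

for param in model.fc2.parameters():
    param.requires_grad = False # keep the other half fixed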

You could use a backward hook on the output of fc2 to zero the gradients going through the parts that you want to filter.

For example:

import torch
import torch.nn as nn
from torch.autograd import Variable

m = nn.Linear(1024, 4096)
input = Variable(torch.randn(128, 1024), requires_grad=True)

out = m(input)

def my_hook(grad):
    grad_clone = grad.clone()
    grad_clone[:, 2500:] = 0
    return grad_clone

h = out.register_hook(my_hook) # zeroes the gradients w.r.t. the outputs for everything outside columns 0 to 2500, over all mini-batches

out.backward(grads)

Edit: edited to incorporate @fmassa’s answer below
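
(Side note, not from the reply above: register_hook returns a handle, so if the masking should only apply for part of training you can detach the hook again later.)

h.remove() # stop zeroing the gradients for later backward passes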


Good!! I learned a new practical use of hooks from your answer. Thank you.

I think it’s better to avoid modifying the gradients in place, but instead return a new gradient, as explained in the docs.
So a modified version would be to pass a hook such as:

def my_hook(grad):
    grad_clone = grad.clone()
    grad_clone[:, 2500:] = 0
    return grad_clone

m = nn.Linear(1024, 4096)
input = Variable(torch.randn(128, 1024), requires_grad=True)

def my_hook(grad):
    grad_clone = grad.clone()
    grad_clone[:, 2500:] = 0

out = m(input)
h = out.register_hook(my_hook)

out.backward(grads)

Like this?

You also need to return grad_clone in your hook

For sure.

Thanks,

David

Hey,

Just a small nuance question about
out.backward(grads)
Where do we take the grads variable from? Can we simply write out.backward()?

It depends on your loss.
If your loss is a scalar, it is OK to write loss.backward(). If it is a tensor with more than one element, you need to write loss.backward(grads), where grads has the same shape as the tensor you call backward() on. For example, loss.backward(torch.ones_like(loss.data)).
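
A minimal illustration of the difference, with made-up shapes (and assuming a PyTorch version where tensors can require gradients directly):

import torch

x = torch.randn(3, requires_grad=True)

loss = (x ** 2).sum() # scalar loss
loss.backward() # fine: no gradient argument needed

y = x * 2 # y has 3 elements, so backward() needs a gradient of the same shape
y.backward(torch.ones_like(y)) # same effect as y.sum().backward()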


What if I need the data of the Tensor inside the hooked_fn? As far as I understand, it only has access to the grad of the Tensor on which it is registered. I would like to modify the gradients based on the current data of the Tensor. The end goal is to implement Inverting Gradients as given in the paper “Deep Reinforcement Learning in Parameterized Action Space”.

EDIT:
We have access to variables defined outside the scope of hooked_fn. Hence, we can simply do data = hooked_tensor.detach().clone().numpy() inside hooked_fn, and then new_grad = some_func(data, grad).
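
(A rough sketch of that closure idea; some_func and the shapes are made up for illustration and are not from the paper:)

import torch
import torch.nn as nn

m = nn.Linear(1024, 4096)
inp = torch.randn(128, 1024, requires_grad=True)
hooked_tensor = m(inp)

def some_func(data, grad):
    # hypothetical rule that rescales the gradient based on the activations
    return grad * (data.abs() < 1.0).float()

def hooked_fn(grad):
    data = hooked_tensor.detach().clone() # visible here through the closure
    return some_func(data, grad)

h = hooked_tensor.register_hook(hooked_fn)
hooked_tensor.backward(torch.ones_like(hooked_tensor))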