How to assign numpy.array to the state_dict and update only part of the filters

I have found some perfect answers to my question, but I am really afraid of latent bugs such as Autograd problems, so could you please tell me:

  1. Is line 4 in the following code OK? Would torch.from_numpy() create a new node in the graph and change the gradient of ‘conv.weight’?
  2. Suppose temp_state_dict['conv.weight'].shape == [16,3,3,3]. What can I do to make the gradient of the first 8 filters 0, so that only part of the weight is updated?
temp_state_dict = model.state_dict()  # line1
weight_numpy = temp_state_dict['conv.weight'].cpu().numpy() # line 2
# ....Processing weight_numpy such as Pruning.... # line3
temp_state_dict['conv.weight'] = torch.from_numpy(weight_numpy) # line4
model.load_state_dict(temp_state_dict) # line5

I’m not sure I understand the use case completely, but loading parameters into the model via a new or manipulated state_dict is independent from the computation graph and should be done before or after the training step. I.e. performing the forward pass will use the current parameter set (P0) to compute the intermediate forward activations (A0), which will be used to compute the gradients (G0). Changing the parameters to P1 would try to use A0 to calculate the gradients, which is wrong, so could you explain the use case a bit more?
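To make the point about the computation graph concrete, here is a minimal sketch. It assumes a small hypothetical model (TinyNet, not from the original post) and the default load_state_dict behaviour, which, as far as I know, copies values into the existing parameters in place without recording an autograd operation. Under those assumptions the parameter stays the same leaf tensor and torch.from_numpy() adds no graph node.

```python
import numpy as np
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the model in the question."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3)

    def forward(self, x):
        return self.conv(x)


torch.manual_seed(0)
model = TinyNet()

# One forward/backward pass so that conv.weight has a gradient.
model(torch.randn(2, 3, 8, 8)).mean().backward()
param_before = model.conv.weight
grad_before = model.conv.weight.grad.clone()

# Round trip through NumPy, as in the question. On CPU, .numpy() shares
# memory with the tensor, so .copy() keeps the edit out of the live weights.
sd = model.state_dict()
weight_numpy = sd["conv.weight"].cpu().numpy().copy()
weight_numpy[:8] = 0.0  # placeholder for the pruning step
sd["conv.weight"] = torch.from_numpy(weight_numpy)
model.load_state_dict(sd)

same_object = model.conv.weight is param_before
is_leaf = model.conv.weight.is_leaf and model.conv.weight.grad_fn is None
grad_kept = torch.equal(model.conv.weight.grad, grad_before)
first_filters_zero = bool((model.conv.weight[:8] == 0).all())
print(same_object, is_leaf, model.conv.weight.requires_grad, grad_kept, first_filters_zero)
```

Note that a gradient computed before the load still refers to the old parameter values, which is the P0/A0/G0 mismatch described above, so the load belongs before the forward pass or after the optimizer step.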

Sorry for my vague description. I actually asked two separate questions, which may have misled you.

  1. The background of the first question is filter pruning, specifically Soft Filter Pruning, for which I would like to write code as shown below. As you can see, before each training epoch I manipulate the weight and load the new_state_dict into the model. Is my code OK according to your preliminary judgment? Could it cause latent bugs such as an Autograd problem?
for loop in training_epochs:
        # manipulate conva.weight of state_dict, then load the new_state_dict
        new_state_dict = model.state_dict()
        # a manipulated example. The key here is assigning a from_numpy tensor to the weight
        new_state_dict['conva.weight'] = torch.from_numpy(np.zeros(new_state_dict['conva.weight'].shape))
        model.load_state_dict(new_state_dict)

  2. The second question is separate from the first. Say state_dict['conva.weight'].shape == [16,3,3,3], i.e., there are 16 filters. I want to update only the last 8 filters during training while keeping the first 8 filters unchanged. My method is to zero the gradient of the first 8 filters in every epoch, as follows.
    Is the code OK? (I am afraid of latent bugs.)
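for loop in training_epochs:
        for item in model.named_parameters():
             # item is a (name, parameter) tuple
             if item[0]=='conva.weight':
                     grad_mask = np.ones(item[1].grad.shape, dtype=np.float32)
                     grad_mask[:8,:,:,:] = 0  # zero the gradient of the first 8 filters
                     item[1].grad.data *= torch.from_numpy(grad_mask)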
  1. Yes, your approach looks wrong, as the new parameter assignment would not keep the gradients. You should copy them back manually if you want to apply the previously calculated gradients to a new parameter set (assuming that’s what the linked method really does).

  2. Don’t use the .data attribute. Instead, wrap the code in a with torch.no_grad() block and set weight.grad[:8] to 0, i.e. the gradient of the first 8 filters, which should stay fixed.
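The second suggestion can be sketched as follows. This is only an illustration under stated assumptions: a standalone conva layer with 16 filters and a plain SGD optimizer without momentum or weight decay are my choices, not the original poster's setup. The masking has to happen after every backward() call and before optimizer.step(), not once per epoch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conva = nn.Conv2d(3, 16, kernel_size=3)  # weight shape [16, 3, 3, 3]
# Assumption: plain SGD, no momentum and no weight decay.
optimizer = torch.optim.SGD(conva.parameters(), lr=0.1)

weight_before = conva.weight.detach().clone()

optimizer.zero_grad()
loss = conva(torch.randn(4, 3, 8, 8)).pow(2).mean()
loss.backward()

# Zero the gradient of the first 8 filters before the optimizer step.
with torch.no_grad():
    conva.weight.grad[:8] = 0

optimizer.step()

first_unchanged = torch.equal(conva.weight[:8], weight_before[:8])
last_changed = not torch.equal(conva.weight[8:], weight_before[8:])
print(first_unchanged, last_changed)
```

With momentum or weight decay, a zero gradient alone does not guarantee that the masked filters stay fixed, because the optimizer can still move them. The bias of those filters is also not masked in this sketch.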

Could you please tell me what I should copy back manually? The code in question 1 was written according to your advice here, and I think I only change the data of one particular weight while the other parameters stay the same, so I don’t know what to copy back.

Your approach could work, but I don’t think there is a guarantee to keep other data, such as the gradients, after loading a new state_dict, so something like this should work:

import torch
from torchvision import models

model = models.resnet18()
out = model(torch.randn(1, 3, 224, 224))
loss = out.mean()
loss.backward()  # populate param.grad; it is None before the backward pass

# save the gradients
grads = {}
for name, param in model.named_parameters():
    grads[name] = param.grad.clone()

# change the parameters here
sd = models.resnet18().state_dict()
sd['fc.weight'] = torch.rand_like(sd['fc.weight'])
model.load_state_dict(sd)

# reload grads
for name, param in model.named_parameters():
    param.grad = grads[name]
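For the pruning use case itself, one alternative that avoids the state_dict and NumPy round trip is to edit the weight in place under torch.no_grad(). The sketch below is not the method from the thread: the helper zero_smallest_filters is hypothetical, and ranking filters by L2 norm is an assumption about how the filters to prune are chosen.

```python
import torch
import torch.nn as nn


def zero_smallest_filters(conv: nn.Conv2d, num_to_zero: int) -> torch.Tensor:
    """Zero the filters with the smallest L2 norm in place; return their indices."""
    with torch.no_grad():
        norms = conv.weight.flatten(1).norm(dim=1)  # one norm per output filter
        idx = torch.argsort(norms)[:num_to_zero]
        conv.weight[idx] = 0
    return idx


torch.manual_seed(0)
conva = nn.Conv2d(3, 16, kernel_size=3)
pruned = zero_smallest_filters(conva, num_to_zero=4)
print(pruned.tolist())
```

Because the edit happens in place, the parameter object registered with the optimizer stays the same, and the zeroed filters can still receive gradient updates in later epochs.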

Thanks for your help. You are the light of this community.
