Freezing a part of the tensor

Hi!

I am trying to freeze all my network weights except some of the output embeddings (nn.Linear columns) using register_backward_hook.

This discussion was very helpful, but I still do not understand what the right way to do it is.

The backward hook for the layer receives grad_input and grad_output. As far as I saw, grad_input was always equal to grad_output (this is the final layer of the model), and they both have shape (batch_size, seq_len, vocab_size).
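For context, here is roughly what my setup looks like; the sizes are made up and lm_head just stands in for the real output layer of my model:

import torch
import torch.nn as nn

hidden_size, vocab_size = 16, 100
batch_size, seq_len = 4, 8

lm_head = nn.Linear(hidden_size, vocab_size)

def inspect_hook(module, grad_input, grad_output):
    # grad_output[0] is the gradient w.r.t. the layer output: (batch_size, seq_len, vocab_size)
    print([g.shape if g is not None else None for g in grad_output])
    print([g.shape if g is not None else None for g in grad_input])
    print(module.weight.grad)  # question 1 below: this was always None for me

lm_head.register_backward_hook(inspect_hook)

hidden = torch.randn(batch_size, seq_len, hidden_size, requires_grad=True)
lm_head(hidden).sum().backward()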

I have three questions:

  1. Ideally, I would like to modify the value of the weight gradient (hidden_size, vocab_size) directly, and I tried to access it inside the hook via self.weight.grad. However, it is always None, even though this happens inside the function registered via .register_backward_hook, which I assumed is called after .backward() has computed the gradients. Why does this happen, and how can I access the gradient?

  2. Ok, let’s forget about the weight gradient tensor. I can modify all values corresponding to a particular embedding in grad_input, right? (According to the documentation, though, I should not modify them in place, but copy them and return new values.)
    So it would look like this:

def hook(self, grad_input, grad_output):
    # copy instead of modifying in place, as the docs recommend
    new_grads = grad_input[0].detach().clone()
    # zero the gradient entries for the embeddings that should stay frozen
    new_grads[:, :, freeze_ids] = 0.
    # return a tuple with the same structure as grad_input
    return (new_grads,) + grad_input[1:]

Should I modify grad_input or grad_output in this case?

  3. Is Adam going to update the values even after we zeroed the gradients? Adam keeps running averages that depend on the previous gradients, not only on the current one, and my setup implies that Adam will have nonzero momentum buffers (exp_avg) for these parameters (see the small sketch below).
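To illustrate what I mean in question 3, here is a tiny sketch of the behaviour I am worried about: after a few steps with real gradients, Adam's running averages are nonzero, so a step with an all-zero gradient still moves the parameter:

import torch

w = torch.nn.Parameter(torch.ones(3))
opt = torch.optim.Adam([w], lr=0.1)

for _ in range(5):          # a few steps with nonzero gradients build up the running averages
    opt.zero_grad()
    w.sum().backward()
    opt.step()

before = w.detach().clone()
w.grad = torch.zeros_like(w)   # as if the hook zeroed this gradient
opt.step()
print((w.detach() - before).abs().max())  # > 0, even though the gradient was zero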

UPD:

  1. Still have no idea about question 1.
  2. This hook seems to work with plain SGD (as far as my tests go).
  3. Adam and SGD with momentum do update these values if you register the hook in the middle of the training process. Help me plz =(

register_backward_hook might not work properly for “complex” modules, as described in the Warning box in the docs.
I would thus recommend using param.register_hook directly, which will only get the gradient as its input.
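Something along these lines should work; lm_head and freeze_ids are placeholders for your setup, and note that nn.Linear stores its weight as (out_features, in_features), i.e. (vocab_size, hidden_size), so the frozen output embeddings are rows of the weight here:

import torch
import torch.nn as nn

hidden_size, vocab_size = 16, 100
lm_head = nn.Linear(hidden_size, vocab_size)
freeze_ids = torch.tensor([0, 5, 42])       # vocabulary ids whose embeddings should stay frozen

def zero_frozen_rows(grad):
    grad = grad.clone()                     # don't modify the incoming gradient in place
    grad[freeze_ids] = 0.0
    return grad

lm_head.weight.register_hook(zero_frozen_rows)

hidden = torch.randn(4, 8, hidden_size)
lm_head(hidden).sum().backward()
print(lm_head.weight.grad[freeze_ids].abs().max())   # tensor(0.)

This only zeroes the gradient, so the issue from your UPD still applies: optimizers with state such as Adam or SGD with momentum can keep moving these rows based on previously accumulated statistics.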