Tensor hook triggered despite handle.remove() when replacing hook with a new one

For my purposes, I want only a specific subset of the weights in a Conv2d kernel to be updated for a given batch. Let's say that for a kernel with dimensions (40, C_in, k, k), I want batch 1 to update only weights[0:20] and batch 2 to update only weights[20:]. For this, I define the following hooks:

def hook_batch_1(grad):
    grad_clone = grad.clone()
    grad_clone[20:] = 0  # zero the gradient of output channels 20 onward
    print("Hook for Batch 1 Working")
    return grad_clone

def hook_batch_2(grad):
    grad_clone = grad.clone()
    grad_clone[:20] = 0  # zero the gradient of output channels 0..19
    print("Hook for Batch 2 Working")
    return grad_clone

I have a list where the corresponding handles are stored:

tensor_hook_handles = []
tensor_hook_handles.append(layer.weight.register_hook(hook_batch_1))

Once batch 1 has been processed, I call the .remove() method on the handles and then register the new hook for batch 2:

for handle in tensor_hook_handles:
    handle.remove()
tensor_hook_handles.clear()  # empty the list
tensor_hook_handles.append(layer.weight.register_hook(hook_batch_2))
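A minimal sketch that checks the gradient itself after swapping hooks, assuming a recent PyTorch and a hypothetical nn.Conv2d(3, 40, 3) layer with random input standing in for the real model and batches. If handle.remove() works, the second backward pass should mask the opposite half of the output channels:

```python
import torch
import torch.nn as nn

def hook_batch_1(grad):
    grad_clone = grad.clone()
    grad_clone[20:] = 0  # keep only output channels 0..19
    return grad_clone

def hook_batch_2(grad):
    grad_clone = grad.clone()
    grad_clone[:20] = 0  # keep only output channels 20..39
    return grad_clone

torch.manual_seed(0)
# Hypothetical layer: C_in=3, 40 output channels, 3x3 kernels
layer = nn.Conv2d(3, 40, kernel_size=3)
x = torch.randn(2, 3, 8, 8)

tensor_hook_handles = [layer.weight.register_hook(hook_batch_1)]
layer(x).sum().backward()
grad_after_batch_1 = layer.weight.grad.clone()

# Remove the first hook, then register the second one
for handle in tensor_hook_handles:
    handle.remove()
tensor_hook_handles.clear()
tensor_hook_handles.append(layer.weight.register_hook(hook_batch_2))

# Reset the accumulated gradient (optimizer.zero_grad() in a real training loop)
layer.weight.grad = None
layer(x).sum().backward()
grad_after_batch_2 = layer.weight.grad.clone()

print(grad_after_batch_1[20:].abs().sum().item())  # 0.0
print(grad_after_batch_2[:20].abs().sum().item())  # 0.0
```

Note the gradient reset between the two passes: without it, .grad accumulates, so the rows filled in during batch 1 would still be non-zero after batch 2, which can look as if the old hook were still active.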

However, when the new hook is set, the weights that get updated are still the ones defined by hook_batch_1. Interestingly, the print I get is “Hook for Batch 2 Working”, which indicates that the new hook is registered and running, yet it seems to have no effect on the gradients, as if the previous hook still prevails.

Any ideas on how to fix this, or on what I may be missing here? I am under the impression that tensor hooks are modifiable, as mentioned here: backward_hook triggered despite RemovableHandle.remove() · Issue #25723 · pytorch/pytorch · GitHub

Best

Hi,

Your two hooks do exactly the same thing apart from the print. So isn't it expected that the only difference you see is the print?

Hello,

The hooks target distinct, non-overlapping partitions of the kernels along the output-channel dimension, so that's not the issue. Notice the difference in the indexing: grad_clone[:20] vs grad_clone[20:]. If I begin with hook_batch_2 and then set hook_batch_1, a different set of weights (the ones that correspond to kernel_weight[:20]) gets updated, so it's a matter of which hook gets set first.
After the first hook is set, regardless of which one it is, its behavior persists when I set the next one, even though the print indicates that the new hook has been set.
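A quick way to confirm that the two masks are complementary is to apply the hooks from the question (prints removed) to a hypothetical all-ones gradient of shape (40, 3, 3, 3) and list which output channels survive each one:

```python
import torch

def hook_batch_1(grad):
    grad_clone = grad.clone()
    grad_clone[20:] = 0  # zero output channels 20..39
    return grad_clone

def hook_batch_2(grad):
    grad_clone = grad.clone()
    grad_clone[:20] = 0  # zero output channels 0..19
    return grad_clone

# Hypothetical stand-in for a weight gradient of shape (40, C_in, k, k), C_in=3, k=3
fake_grad = torch.ones(40, 3, 3, 3)

# True for every output channel that still has a non-zero gradient after masking
kept_1 = hook_batch_1(fake_grad).flatten(1).ne(0).any(dim=1)
kept_2 = hook_batch_2(fake_grad).flatten(1).ne(0).any(dim=1)

print(kept_1.nonzero().flatten().tolist())  # channels 0..19
print(kept_2.nonzero().flatten().tolist())  # channels 20..39
# Disjoint and jointly covering all 40 channels
print(bool((kept_1 & kept_2).any()), bool((kept_1 | kept_2).all()))  # False True
```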

Thanks for the clarification.

However, when the new hook is set, the weights that get updated are still the ones that hook_batch_1 defined

One thing that can cause this is that some optimizers (Adam, SGD with momentum, etc.) still update the weights even when their gradients are all zeros, because of their momentum terms.
So masking out the gradient won't be enough to prevent the weights from changing when you take an optimizer step.
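To see this effect in isolation, here is a hedged sketch using torch.optim.SGD and torch.optim.Adam on a hypothetical 4-element parameter: one step with a real gradient fills the optimizer state, then a second step is taken with an all-zero gradient (what the mask hook produces for the "frozen" rows). The helper name moved_on_zero_grad is made up for illustration. The weight_decay case goes slightly beyond the thread: SGD adds weight_decay * param to the gradient, so a zero gradient does not freeze the weights there either.

```python
import torch

def moved_on_zero_grad(make_optimizer):
    """Hypothetical helper: True if a parameter changes on a step whose gradient is all zeros."""
    w = torch.nn.Parameter(torch.ones(4))  # stand-in for the masked kernel rows
    opt = make_optimizer([w])

    # Step 1: a real gradient populates the optimizer state (momentum / moment estimates)
    w.grad = torch.ones(4)
    opt.step()
    before = w.detach().clone()

    # Step 2: the gradient is completely masked, as the hook does
    w.grad = torch.zeros(4)
    opt.step()
    return not torch.equal(before, w.detach())

results = {
    "SGD, momentum=0": moved_on_zero_grad(lambda p: torch.optim.SGD(p, lr=0.1, momentum=0.0)),
    "SGD, momentum=0.9": moved_on_zero_grad(lambda p: torch.optim.SGD(p, lr=0.1, momentum=0.9)),
    "SGD, momentum=0, weight_decay=0.01": moved_on_zero_grad(
        lambda p: torch.optim.SGD(p, lr=0.1, momentum=0.0, weight_decay=0.01)
    ),
    "Adam": moved_on_zero_grad(lambda p: torch.optim.Adam(p, lr=0.1)),
}
for name, moved in results.items():
    print(f"{name}: weights changed on zero-grad step = {moved}")
```

Only plain SGD without momentum or weight decay leaves the parameter untouched on the zero-gradient step; the other three configurations keep moving it.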

I can confirm that the choice of optimizer caused the behavior: both Adam and SGD with momentum behaved as I previously described. SGD with momentum=0 does what I want it to do.

Thank you very much for the help!
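A minimal end-to-end sketch of the setup that ended up working, assuming hypothetical layer sizes and random inputs in place of the real batches. Hooks are swapped between two batches, the optimizer is torch.optim.SGD with momentum=0 (and the default weight_decay=0), and after each step the code records which output channels of the weight actually moved. make_mask_hook is a hypothetical helper that generalizes the two hooks from the question; the bias is not masked, so only the weight is checked.

```python
import torch
import torch.nn as nn

def make_mask_hook(keep):
    """Hypothetical helper: returns a hook that zeroes every output channel outside `keep`."""
    def hook(grad):
        masked = torch.zeros_like(grad)
        masked[keep] = grad[keep]
        return masked
    return hook

torch.manual_seed(0)
layer = nn.Conv2d(3, 40, kernel_size=3)  # hypothetical sizes: C_in=3, 40 output channels
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1, momentum=0.0)

schedule = [slice(0, 20), slice(20, 40)]  # batch 1 updates rows 0..19, batch 2 rows 20..39
tensor_hook_handles = []
changed_rows = []

for keep in schedule:
    # Swap hooks: remove the previous one before registering the next
    for handle in tensor_hook_handles:
        handle.remove()
    tensor_hook_handles.clear()
    tensor_hook_handles.append(layer.weight.register_hook(make_mask_hook(keep)))

    before = layer.weight.detach().clone()
    optimizer.zero_grad()
    layer(torch.randn(2, 3, 8, 8)).sum().backward()
    optimizer.step()

    # Which output channels of the weight actually moved during this step?
    diff = (layer.weight.detach() - before).flatten(1).abs().sum(dim=1)
    changed_rows.append(diff.nonzero().flatten().tolist())

print(changed_rows[0] == list(range(20)), changed_rows[1] == list(range(20, 40)))  # True True
```

Swapping momentum=0.0 for momentum=0.9, or the optimizer for torch.optim.Adam, should make rows 0..19 keep moving during the second batch, which matches the behavior described in the question.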
