For my purposes, I want only a specific subset of a Conv2d layer's kernel weights to update on a given batch. Let's say for a kernel with dimensions (40, C_in, k, k), I want batch 1 to update only weights[0:20] and batch 2 to update only weights[20:]. For this, I define the following hooks:
def hook_batch_1(grad):
    # Zero the gradient for output channels 20: so only weights[0:20] get updated
    grad_clone = grad.clone()
    grad_clone[20:] = 0
    print("Hook for Batch 1 Working")
    return grad_clone

def hook_batch_2(grad):
    # Zero the gradient for output channels :20 so only weights[20:] get updated
    grad_clone = grad.clone()
    grad_clone[:20] = 0
    print("Hook for Batch 2 Working")
    return grad_clone
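As a quick sanity check (just for illustration, with an arbitrary dummy shape), calling the hooks directly does mask the intended slices:

import torch

dummy_grad = torch.ones(40, 3, 3, 3)
assert hook_batch_1(dummy_grad)[20:].abs().sum() == 0   # channels 20: zeroed
assert hook_batch_2(dummy_grad)[:20].abs().sum() == 0   # channels :20 zeroed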
I have a list where the corresponding handles are stored:
tensor_hook_handles = []
tensor_hook_handles.append(layer.weight.register_hook(hook_batch_1))
Once batch 1 is processed, I call the .remove() method on the handles and proceed to set the new hooks for batch 2:
for handle in tensor_hook_handles:
    handle.remove()
tensor_hook_handles.clear()  # empty the list
tensor_hook_handles.append(layer.weight.register_hook(hook_batch_2))
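For context, here is a minimal sketch of how I am driving this per batch. The names (layer, optimizer, criterion) and the random tensors are simplified stand-ins for my actual setup, not the real model:

import torch
import torch.nn as nn

# Stand-in for my model: a single Conv2d with 40 output channels
layer = nn.Conv2d(3, 40, kernel_size=3, padding=1)
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)
criterion = nn.MSELoss()

tensor_hook_handles = []

def run_batch(inputs, targets, hook_fn):
    # Register this batch's hook and keep the handle so it can be removed later
    tensor_hook_handles.append(layer.weight.register_hook(hook_fn))

    optimizer.zero_grad()
    loss = criterion(layer(inputs), targets)
    loss.backward()   # hook_fn should fire here and mask part of the gradient
    optimizer.step()

    # Remove this batch's hook before the next one is registered
    for handle in tensor_hook_handles:
        handle.remove()
    tensor_hook_handles.clear()

x1, y1 = torch.randn(8, 3, 16, 16), torch.randn(8, 40, 16, 16)
x2, y2 = torch.randn(8, 3, 16, 16), torch.randn(8, 40, 16, 16)

run_batch(x1, y1, hook_batch_1)  # expect only weights[0:20] to change
run_batch(x2, y2, hook_batch_2)  # expect only weights[20:] to change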
However, when the new hook is set, the weights that get updated are still the ones that hook_batch_1 allows. Interestingly, the print I get is "Hook for Batch 2 Working", which indicates that the new hook was registered, yet it has no effect on the gradients: the previous hook's masking somehow still prevails.
Any ideas on how to fix this, or on what I may be missing here? I am under the impression that tensor hooks can be removed and replaced, as mentioned here: backward_hook triggered despite RemovableHandle.remove() · Issue #25723 · pytorch/pytorch · GitHub
Best