Hi everyone,
I have an FSDP model that has zeros in some of the `torch.nn.Linear.weight`
parameters. During training I would like to keep those parameters fixed at zero and to zero out their gradients as well. The specific use case is: I am loading a pruned model and I want to fine-tune it with FSDP while keeping the pruning mask fixed.
To achieve this I need to do two things:
- multiply the parameters with the mask before the forward pass (so that all pruned weights remain pruned),
- multiply the gradients with the mask after the backward pass (so that the gradients of pruned weights are zero); see the toy sketch after this list.
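To make the invariant concrete, here is a toy, self-contained illustration (not my training code; the shapes and mask here are made up) of applying a fixed 0/1 mask to a weight before the forward pass and to its gradient after the backward pass:

```python
import torch

# Toy illustration: a fixed 0/1 mask is applied to the weight before the
# forward pass and to its gradient after the backward pass.
w = torch.randn(4, 4, requires_grad=True)
mask = (torch.rand(4, 4) > 0.5).float()  # 1 = keep, 0 = pruned

with torch.no_grad():
    w.mul_(mask)          # pruned weights are (and stay) zero

loss = w.sum()            # stand-in for a real forward pass + loss
loss.backward()           # w.grad is all ones here
with torch.no_grad():
    w.grad.mul_(mask)     # pruned entries receive zero gradient

assert torch.all(w.detach()[mask == 0] == 0)
assert torch.all(w.grad[mask == 0] == 0)
```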
In standard DDP training I would achieve this by:
- registering a forward pre-hook on `torch.nn.Linear` modules and multiplying the weights with the mask before each forward pass,
- registering a hook on the parameter `torch.nn.Linear.weight` and multiplying its gradient with the mask.
For example:
```python
import torch
from functools import partial

def keep_param_pruned(mask, module, input):
    # Forward pre-hook: re-apply the mask so pruned weights stay zero.
    with torch.no_grad():
        module.weight.mul_(mask.to(module.weight.device))

def keep_grad_pruned(mask, grad):
    # Return a masked gradient; the hook argument must not be modified in place.
    return grad * mask.to(grad.device)

for n, m in model.named_modules():
    if isinstance(m, torch.nn.Linear):
        mask = m.weight.abs() > threshold  # abs() so negative kept weights survive
        m.register_forward_pre_hook(partial(keep_param_pruned, mask))
        m.weight.register_hook(partial(keep_grad_pruned, mask))
```
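For reference, this is roughly how I adapt the above to FSDP: I register the hooks on the plain modules first and then wrap the model. A simplified sketch (the tiny stand-in model and `threshold = 0.0` are placeholders, and I assume the process group is already initialized; `keep_param_pruned`/`keep_grad_pruned` are as defined above):

```python
import torch
from functools import partial
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Stand-ins for my real pruned model and mask threshold.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 4))
threshold = 0.0  # keep every nonzero weight

# Register the masking hooks on the unwrapped modules first...
for n, m in model.named_modules():
    if isinstance(m, torch.nn.Linear):
        mask = m.weight.abs() > threshold
        m.register_forward_pre_hook(partial(keep_param_pruned, mask))
        m.weight.register_hook(partial(keep_grad_pruned, mask))

# ...and only then wrap the model with FSDP.
fsdp_model = FSDP(model)
```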
However, I am struggling to make this idea work with FSDP. The forward pre-hook does seem to be called, but unfortunately the hook on the gradients is not, so I can't zero them out after the backward pass. Any suggestions or ideas on what I am doing wrong, or on a simpler way to achieve this without playing with hooks?