Modify gradients of an FSDP model

Hi everyone,
I have an FSDP model which has zeros in some of the torch.nn.Linear.weight parameters. During training I would like to keep those parameters fixed at zero and to zero out their gradients as well. The specific use-case is: I am loading a pruned model and I want to fine-tune it with FSDP while keeping the pruning mask fixed.

To achieve this I need to do two things:

  1. multiply parameters with the mask before the forward pass (so that all pruned weights remain pruned),
  2. multiply gradients with the mask after the backward pass (so that gradients of pruned weights are zero).
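To make the two steps concrete, here is a minimal hook-free sketch on a single toy Linear layer (sizes and the "pruned row" are made up for illustration, no FSDP involved):

```python
import torch

# toy stand-in for one pruned Linear layer (sizes are arbitrary)
lin = torch.nn.Linear(4, 4)
with torch.no_grad():
    lin.weight[0].zero_()          # pretend row 0 was pruned
mask = lin.weight != 0             # fixed pruning mask (True = keep)

# step 1: re-apply the mask to the weight before the forward pass
with torch.no_grad():
    lin.weight.mul_(mask)

loss = lin(torch.randn(2, 4)).sum()
loss.backward()

# step 2: zero out gradients at pruned positions after the backward pass
lin.weight.grad.mul_(mask)
```

After this, both `lin.weight[0]` and `lin.weight.grad[0]` are all zeros, so an optimizer step leaves the pruned row at zero.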

In the standard DDP training I would achieve this by:

  1. registering forward pre-hook on torch.nn.Linear modules and multiplying weights with the mask before each forward pass,
  2. registering a hook on the parameter torch.nn.Linear.weight and multiplying its gradient with the mask.

For example:

import torch
from functools import partial

def keep_param_pruned(mask, module, input):
    # re-apply the pruning mask to the weight before the forward pass
    with torch.no_grad():
        module.weight.mul_(mask)

def keep_grad_pruned(mask, grad):
    # zero out gradients at pruned positions (out-of-place, since
    # modifying the grad in-place inside a hook is discouraged)
    return grad * mask

for n, m in model.named_modules():
    if isinstance(m, torch.nn.Linear):
        mask = m.weight.abs() > threshold  # True for unpruned weights
        m.register_forward_pre_hook(partial(keep_param_pruned, mask))
        m.weight.register_hook(partial(keep_grad_pruned, mask))

However, I am struggling to adapt this idea to FSDP. The forward pre-hook does seem to be called, but the gradient hook never fires, so I can't zero out the gradients after the backward pass. Any suggestions on what I am doing wrong, or is there a simpler way to achieve this without playing with hooks?
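For completeness, the hook-free fallback I could live with if the gradient hooks turn out to be a dead end is to simply re-apply the mask after every optimizer step. This sketch uses a plain module; my (untested) assumption is that under FSDP the same loop would require wrapping with `use_orig_params=True` so that `named_parameters()` yields the original per-module parameters:

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(8, 8)  # stands in for the real network
with torch.no_grad():
    # fake a "loaded pruned" model by zeroing small weights
    model.weight[model.weight.abs() < 0.1] = 0.0

# fixed pruning masks, captured once right after loading
masks = {n: (p != 0).float()
         for n, p in model.named_parameters() if n.endswith("weight")}

opt = torch.optim.SGD(model.parameters(), lr=0.1)
for _ in range(3):
    loss = model(torch.randn(4, 8)).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
    # re-apply the mask so pruned weights snap back to zero after the update
    with torch.no_grad():
        for n, p in model.named_parameters():
            if n in masks:
                p.mul_(masks[n])
```

This avoids hooks entirely, at the cost of one extra pass over the parameters per step, and the gradients at pruned positions are never actually zeroed, only their effect is undone.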