Masking module parameters


How would one effectively mask the parameters of a module without losing either their link to the optimizer or the gradient flowing back to the mask?
In more detail:

Approach 1:

backups = []
for weights in model.parameters():
    mask = sample_mask(some_arguments...)  # this comes from weights of a wrapper module that needs to be trained jointly
    backups.append(weights.data.clone())
    weights.data *= mask
loss = loss_func(model(data), target)
for backup, weights in zip(backups, model.parameters()):
    weights.data = backup

This seemed good, but the masked weights of model do not carry the grad_fn that the mask accumulated during its generation process, and I need that for my training!
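Why Approach 1 cuts the mask off from training can be checked in a few lines. This is a minimal sketch, assuming a sigmoid over trainable logits as a hypothetical stand-in for sample_mask. Because the masked values are written into .data, autograd never records the multiplication, so the weight still receives a gradient but the mask parameters receive none.

```python
import torch

torch.manual_seed(0)
weight = torch.nn.Parameter(torch.randn(4, 3))       # a model weight
mask_logits = torch.nn.Parameter(torch.zeros(4, 3))  # hypothetical mask parameters
x = torch.randn(2, 3)

# Approach 1: write the masked values into .data, then restore them afterwards
backup = weight.data.clone()
mask = torch.sigmoid(mask_logits)   # hypothetical stand-in for sample_mask
weight.data.mul_(mask.detach())     # detach() only makes explicit what .data implies: no graph
loss = (x @ weight.t()).sum()
loss.backward()
weight.data.copy_(backup)

print(weight.grad is not None)  # True: the weight itself still gets a gradient
print(mask_logits.grad)         # None: the mask never entered the graph
```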

Approach 2:

for weights in model.parameters():
    mask = sample_mask(some_arguments...)
    weights.data *= mask
    weights.grad_fn = mask.grad_fn  # ERROR, grad_fn is not writable!
loss = loss_func(model(data), target)

This doesn't work, as PyTorch forbids assigning to grad_fn (why? It may not be common or best practice, but it certainly has its use cases). If it did work, it would probably be my preferred approach, since it also makes clear what I need.
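The restriction is easy to reproduce. A minimal sketch, assuming a recent PyTorch release: grad_fn is exposed as a read-only attribute, so the assignment from Approach 2 raises AttributeError (the exact message text varies between versions, so it is not shown here).

```python
import torch

w = torch.nn.Parameter(torch.ones(3))
mask = torch.tensor([1.0, 0.0, 1.0], requires_grad=True)  # stand-in for a sampled mask
masked = w * mask  # non-leaf result that carries a grad_fn

grad_fn_is_writable = True
try:
    w.grad_fn = masked.grad_fn  # what Approach 2 attempts
except AttributeError as err:
    grad_fn_is_writable = False
    print("AttributeError:", err)

print(grad_fn_is_writable)  # False
```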

Approach 3:

for weights in model.parameters():
    # (pseudocode) remove `weights` from the optimizer
    mask = sample_mask(some_arguments...)
    weights = weights * mask
    # (pseudocode) add the masked `weights` back to the optimizer, in the right place somehow
loss = loss_func(model(data), target)

This is basically very hacky, and it is unclear how to remove from the optimizer exactly the weights I want removed, and then put them back in the correct spot afterwards…

Since approach 3 seems impractical for a larger composite model, I am a bit at a loss as to how to simply mask parameters and still train the parameters and the mask generation process jointly. Tips would be appreciated.


nn.Parameters are always leaf nodes, which means they cannot have their .grad_fn field set.
This is because the optimizers use the .grad field to update the weights, and if a tensor is not a leaf, its .grad field is not populated during the backward pass, so the optimizer would not be able to update it.
So do you actually want to update both the weights and the parameters (using the naming from your code), or just the parameters?
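The leaf/non-leaf distinction can be seen directly. A minimal sketch, assuming a recent PyTorch release: the built-in optimizers reject a non-leaf tensor with a ValueError, and backward() fills .grad on the leaf Parameter, which is the field the optimizer reads.

```python
import torch

w = torch.nn.Parameter(torch.ones(3))       # leaf: created directly, no grad_fn
masked = w * torch.tensor([1.0, 0.0, 1.0])  # non-leaf: result of an op on w
print(w.is_leaf, masked.is_leaf)            # True False

# The built-in optimizers refuse non-leaf tensors
try:
    torch.optim.SGD([masked], lr=0.1)
    accepted_non_leaf = True
except ValueError as err:
    accepted_non_leaf = False
    print("ValueError:", err)

# backward() populates .grad on the leaf, which is what the optimizer reads
masked.sum().backward()
print(w.grad)  # tensor([1., 0., 1.])
```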

Hey, thanks for the reply,

sorry, the naming was confusing. I want gradient updates on:

  1. The parameters behind the mask sampling process (to better sample the appropriate mask)
  2. The weights of the underlying model (here simply called model, with its weights called weights), which are masked by mask.
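Training both sets jointly mostly comes down to putting both into the optimizer and applying the mask out of place, so that autograd sees it. A minimal sketch, assuming a sigmoid over trainable logits as a hypothetical stand-in for sample_mask (the thread does not say how the mask is sampled); `model` and `mask_generator` are hypothetical names, and the two parameter groups with arbitrary learning rates just show that each set can be tuned separately.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins: `model` is the network being masked and
# `mask_generator` holds the parameters behind the mask sampling process.
model = nn.Linear(10, 10, bias=False)
mask_generator = nn.ParameterDict({"logits": nn.Parameter(torch.zeros(10, 10))})

# One optimizer, two parameter groups (the learning rates are arbitrary)
optimizer = torch.optim.Adam(
    [
        {"params": model.parameters()},                       # 2. masked model weights
        {"params": mask_generator.parameters(), "lr": 1e-2},  # 1. mask parameters
    ],
    lr=1e-3,
)

x = torch.randn(4, 10)
mask = torch.sigmoid(mask_generator["logits"])      # differentiable stand-in for sample_mask
out = nn.functional.linear(x, model.weight * mask)  # out-of-place, so autograd tracks it
out.pow(2).mean().backward()
optimizer.step()  # updates both the model weights and the mask logits
```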

I will quickly edit the code to remove the confusing double use of the name parameters.
Does this clear it up?

So here you have two sets of parameters in your module:

  • original weights of the module
  • mask_params that are used to compute the mask

I would modify the module so that it holds all the right Parameters, and recompute weight for each forward pass.

import torch.nn as nn

# Example for a Linear (handle the bias the same way if you want one)
mod = nn.Linear(10, 10, bias=False)
mod.mask_params = nn.Parameter(whatever)
mod.original_weight = mod.weight
del mod.weight  # re-populate it for each forward using a forward pre-hook

def repopulate_weight(mod, _):
    mask = sample_mask(mod.mask_params)
    # assign a plain Tensor (not a Parameter) so it keeps the autograd graph
    mod.weight = mod.original_weight * mask

mod.register_forward_pre_hook(repopulate_weight)

# Use `mod` as any Linear now
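A self-contained, runnable version of the hook idea, for checking that gradients reach both parameter sets. It assumes a sigmoid over trainable logits as a hypothetical sample_mask and zero-initialised mask parameters in place of `whatever`; a real sampler (for example a relaxed Bernoulli) would replace it.

```python
import torch
import torch.nn as nn

def sample_mask(mask_params):
    # Hypothetical stand-in for the thread's sample_mask: a soft,
    # differentiable mask with values in (0, 1).
    return torch.sigmoid(mask_params)

mod = nn.Linear(10, 10, bias=False)
mod.mask_params = nn.Parameter(torch.zeros(10, 10))  # registered as a Parameter
mod.original_weight = mod.weight  # same Parameter, registered under a new name
del mod.weight                    # "weight" is no longer a registered Parameter

def repopulate_weight(module, inputs):
    # Runs before every forward; stores a plain tensor that still carries
    # the graph back to both original_weight and mask_params.
    module.weight = module.original_weight * sample_mask(module.mask_params)

mod.register_forward_pre_hook(repopulate_weight)

optimizer = torch.optim.SGD(mod.parameters(), lr=0.1)
loss = mod(torch.randn(4, 10)).pow(2).mean()
loss.backward()
optimizer.step()

print(sorted(name for name, _ in mod.named_parameters()))
# ['mask_params', 'original_weight']
```

Since mod.parameters() now yields both original_weight and mask_params, a single optimizer built from it trains the weights and the mask parameters together.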

Hey alban,
thanks a lot for the help. The hooks were a good idea. I've now set a forward pre-hook that samples the mask and backs up the .data, and a backward hook that reinstalls the backups for the gradient updates.
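For readers on recent PyTorch releases, torch.nn.utils.parametrize packages the same recompute-on-every-access idea as the forward pre-hook: the parametrization module's forward is applied whenever layer.weight is read, and its own parameters are registered on the layer, so the mask stays inside the autograd graph. This is a swapped-in alternative, not what the thread used; SoftMask and its sigmoid mask are hypothetical stand-ins for sample_mask.

```python
import torch
import torch.nn as nn
from torch.nn.utils import parametrize

class SoftMask(nn.Module):
    # Hypothetical mask generator: a sigmoid over trainable logits.
    def __init__(self, shape):
        super().__init__()
        self.mask_logits = nn.Parameter(torch.zeros(shape))

    def forward(self, weight):
        # Applied every time `layer.weight` is accessed
        return weight * torch.sigmoid(self.mask_logits)

layer = nn.Linear(10, 10, bias=False)
parametrize.register_parametrization(layer, "weight", SoftMask(layer.weight.shape))

# layer.parameters() now holds the original weight and the mask logits
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)
loss = layer(torch.randn(4, 10)).pow(2).mean()
loss.backward()
optimizer.step()
```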