Masking module parameters

So here you have two sets of parameters in your module:

  • original weights of the module
  • mask_params that are used to compute the mask

I would modify the module to hold all the right Parameters and recompute the weight on each forward.

# Example for a Linear (handle the bias the same way if you want to mask it too)
mod = nn.Linear(10, 10, bias=False)
mod.mask_params = nn.Parameter(whatever)
mod.original_weight = mod.weight
del mod.weight  # re-populated on each forward by the pre-hook below

def repopulate_weight(mod, _):
    mask = sample_mask(mod.mask_params)
    mod.weight = mod.original_weight * mask  # plain Tensor, not a Parameter

mod.register_forward_pre_hook(repopulate_weight)

# Use `mod` as any Linear now
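Here is a self-contained, runnable version of the sketch above. `sample_mask` is a hypothetical stand-in for your mask-sampling function; I use a plain `torch.sigmoid` so the mask stays differentiable and gradients reach `mask_params` (the init with `torch.randn` is also just an example):

```python
import torch
import torch.nn as nn

mod = nn.Linear(10, 10, bias=False)
# hypothetical init; replace with however you initialize your mask parameters
mod.mask_params = nn.Parameter(torch.randn(10, 10))
mod.original_weight = mod.weight
del mod.weight  # re-populated on each forward by the pre-hook below

def sample_mask(mask_params):
    # stand-in for your mask computation; sigmoid keeps it differentiable
    return torch.sigmoid(mask_params)

def repopulate_weight(mod, _):
    mask = sample_mask(mod.mask_params)
    mod.weight = mod.original_weight * mask  # plain Tensor, not a Parameter

mod.register_forward_pre_hook(repopulate_weight)

# use `mod` as any Linear: gradients flow to both sets of parameters
out = mod(torch.randn(4, 10))
out.sum().backward()
assert mod.original_weight.grad is not None
assert mod.mask_params.grad is not None
```

Note that the masked weight set inside the hook is an intermediate tensor, not a Parameter, so an optimizer should be given `mod.parameters()` as usual (which now contains `original_weight` and `mask_params`).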