So here you have two parameters in your module:
- the original weights of the module
- `mask_params` that are used to compute the mask

I would modify the module so that it holds all the right Parameters and recomputes the weight on each forward, using a forward pre-hook:
```python
import torch
import torch.nn as nn

# Example for a Linear (handle the bias the same way if you want it masked too)
mod = nn.Linear(10, 10, bias=False)
mod.mask_params = nn.Parameter(torch.randn(10, 10))  # shape/init of the mask parameters is up to you
mod.original_weight = mod.weight  # keep the original weight as a Parameter under a new name
del mod.weight  # re-populate it for each forward using a forward pre-hook

def repopulate_weight(mod, _):
    # `sample_mask` is your own function that builds a mask from the mask parameters
    mask = sample_mask(mod.mask_params)
    mod.weight = mod.original_weight * mask

mod.register_forward_pre_hook(repopulate_weight)

# Use `mod` as any Linear now
```
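
As a quick sanity check, here is a minimal sketch of how you could exercise this, assuming a hypothetical `sample_mask` that is just a sigmoid relaxation over `mask_params` (yours can be whatever sampling or relaxation you need). Gradients from a single forward/backward pass reach both `original_weight` and `mask_params`:

```python
# Hypothetical stand-in for `sample_mask`; replace with your own sampling logic.
def sample_mask(mask_params):
    return torch.sigmoid(mask_params)

out = mod(torch.randn(4, 10))  # the pre-hook rebuilds mod.weight before the linear runs
out.sum().backward()
print(mod.original_weight.grad is not None, mod.mask_params.grad is not None)  # True True
```

Note that after the `del`, the recomputed `weight` is a plain tensor attribute rather than a Parameter, so an optimizer built from `mod.parameters()` will only see `original_weight` and `mask_params`, which is what you want here.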