How would one effectively mask the parameters of a module without breaking their link to the optimizer?
In more detail:
Approach 1:
backups = []
for weights in model.parameters():
    backups.append(weights.detach().clone())
    mask = sample_mask(some_arguments...)  # comes from the weights of a wrapper module that needs to be trained jointly
    weights.data *= mask
loss = loss_func(model(data), target)
for backup, weights in zip(backups, model.parameters()):
    weights.data = backup
This seemed good, but the weights of model lose the grad_fn that the mask accumulated during its generation process, and I need that graph for my training!
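For reference, a minimal sketch of why the graph is lost (a random requires-grad tensor stands in for the sampled mask): mutating .data in place bypasses autograd entirely, so the mask never enters the graph and receives no gradient:

```python
import torch

w = torch.nn.Parameter(torch.ones(3))
mask = torch.rand(3, requires_grad=True)  # stand-in for sample_mask(...)

w.data.mul_(mask)         # mutates the storage in place, bypassing autograd
loss = w.sum()
loss.backward()
print(w.grad)             # tensor([1., 1., 1.]) -- w is still a leaf
print(mask.grad)          # None: the mask never entered the graph
```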
Approach 2:
backups = []
for weights in model.parameters():
    backups.append(weights.detach().clone())
    mask = sample_mask(some_arguments...)
    weights.data *= mask
    weights.grad_fn = mask.grad_fn  # ERROR: grad_fn is not writable!
loss = loss_func(model(data), target)
This doesn’t work, as PyTorch forbids assigning to grad_fn (why? It may not be common or best practice, but it certainly has its use cases). If it did work, this would be my preferred approach, I guess, as it also makes clear what I need.
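A quick demonstration that grad_fn is indeed read-only (again, a random requires-grad tensor stands in for the real mask-sampling process):

```python
import torch

w = torch.nn.Parameter(torch.ones(3))
mask = torch.rand(3, requires_grad=True)
masked = w * mask                 # non-leaf result, carries the mask's history

try:
    w.grad_fn = masked.grad_fn    # raises: grad_fn is not a writable attribute
except AttributeError as e:
    print("not writable:", e)
```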
Approach 3:
for weights in model.parameters():
    del weights from optimizer  # pseudocode
    mask = sample_mask(some_arguments...)
    weights = weights * mask
    optimizer.add_weights(weights)  # pseudocode: in the right place somehow
loss = loss_func(model(data), target)
Basically this is very hacky, and it is unclear how to delete only the exact weights I want removed and then assign them to the correct spot again afterwards…
So, since approach 3 appears pretty unthinkable for a bigger composite model, I am a bit at a loss as to how to simply mask parameters while still training the parameters and the mask-generation process jointly. Tips would be appreciated.
nn.Parameters are always leaf nodes, meaning their .grad_fn field cannot be set.
This is because the optimizers use the .grad field to perform updates on the weights; if a Tensor is not a leaf, its .grad field won’t be populated, so the optimizer won’t be able to perform the update.
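This leaf-vs-non-leaf distinction can be seen directly: after backward, only the leaf tensor’s .grad is populated:

```python
import torch

leaf = torch.ones(3, requires_grad=True)   # a leaf, like an nn.Parameter
non_leaf = leaf * 2                        # has a grad_fn, so not a leaf

non_leaf.sum().backward()
print(leaf.is_leaf, non_leaf.is_leaf)      # True False
print(leaf.grad)                           # tensor([2., 2., 2.])
print(non_leaf.grad)                       # None: an optimizer would see nothing
```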
So, do you actually want to perform updates on both the weights and the parameters (naming from your code), or just the parameters?
I would modify the module to have all the right Parameters and recompute weight for each forward.
# Example for a Linear (handle the bias the same way if you want it masked)
mod = nn.Linear(10, 10, bias=False)
mod.mask_params = nn.Parameter(whatever)
mod.original_weight = mod.weight
del mod.weight  # re-populate it for each forward using a forward pre-hook

def repopulate_weight(mod, _):
    mask = sample_mask(mod.mask_params)
    mod.weight = mod.original_weight * mask

mod.register_forward_pre_hook(repopulate_weight)
# Use `mod` as any Linear now
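Putting the pieces together, here is a self-contained sketch of the hook approach; sample_mask here is a hypothetical stand-in (a sigmoid gate) for the real mask-generation process, and the zero init of mask_params is arbitrary:

```python
import torch
from torch import nn

def sample_mask(mask_params):
    # hypothetical differentiable mask: a sigmoid gate (stand-in for the real sampler)
    return torch.sigmoid(mask_params)

mod = nn.Linear(10, 10, bias=False)
mod.mask_params = nn.Parameter(torch.zeros(10, 10))
mod.original_weight = mod.weight   # keep the trainable leaf Parameter around
del mod.weight                     # free the name so the hook can set a plain attribute

def repopulate_weight(mod, _):
    # recompute the (non-leaf) masked weight before every forward
    mod.weight = mod.original_weight * sample_mask(mod.mask_params)

mod.register_forward_pre_hook(repopulate_weight)

# the optimizer sees both leaf Parameters: original_weight and mask_params
opt = torch.optim.SGD(mod.parameters(), lr=0.1)
loss = mod(torch.randn(4, 10)).pow(2).sum()
loss.backward()                    # gradients flow to both weight and mask
opt.step()
```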
Hey alban,
thanks a lot for the help. The hooks were a good idea. I’ve now set up a forward pre-hook to sample the mask and back up the .data, and a backward hook to reinstall the backups for the gradient updates.