How can I freeze weights element-wise?

The conceptually clean way to fix part of a weight tensor is to use buffers (registered with `self.register_buffer('weight_update_mask', the_mask)` in the module initialization) for both the mask of what should be updated and the fixed weights, and then in the forward compute `weight = torch.where(self.weight_update_mask, self.weight_param, self.weight_fixed)`.
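Here is a minimal sketch of that pattern for a linear layer; the class and buffer names (`PartiallyFrozenLinear`, `weight_update_mask`, `weight_fixed`) are just illustrative, not an existing API:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PartiallyFrozenLinear(nn.Module):
    """Linear layer whose weight entries with update_mask == False stay at fixed values."""
    def __init__(self, in_features, out_features, update_mask, fixed_weight):
        super().__init__()
        # Trainable parameter; only the entries selected by the mask are actually used.
        self.weight_param = nn.Parameter(torch.randn(out_features, in_features) * 0.01)
        self.bias = nn.Parameter(torch.zeros(out_features))
        # Buffers move with the module (.to(), state_dict) but receive no gradients.
        self.register_buffer('weight_update_mask', update_mask.bool())
        self.register_buffer('weight_fixed', fixed_weight)

    def forward(self, x):
        # Effective weight: trainable where the mask is True, fixed values elsewhere.
        weight = torch.where(self.weight_update_mask, self.weight_param, self.weight_fixed)
        return F.linear(x, weight, self.bias)

# Usage example: fix the first column of a 4x3 weight at zero, train the rest.
mask = torch.ones(4, 3, dtype=torch.bool)
mask[:, 0] = False
layer = PartiallyFrozenLinear(3, 4, mask, fixed_weight=torch.zeros(4, 3))
out = layer(torch.randn(2, 3))
out.sum().backward()
# weight_param.grad is zero wherever weight_update_mask is False.
print(layer.weight_param.grad)
```

Because `torch.where` only routes gradients to `weight_param` where the mask is True, the fixed entries never change, regardless of which optimizer you use.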

Now you might get by with zeroing the gradients of the fixed entries instead, e.g. `with torch.no_grad(): param.grad[fixed_mask] = 0` before the optimizer step. But I would view this as an optimization attempt for the former, and it is not the same when you consider optimizers that don't work fully elementwise but take the weight/gradient in its entirety into consideration (e.g. LARS/LAMB etc.).
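For completeness, a sketch of that shortcut, assuming plain SGD and a hypothetical `fixed_mask` boolean tensor of the same shape as the parameter:

```python
import torch

# Hypothetical setup: a single parameter where fixed_mask marks the frozen entries.
param = torch.nn.Parameter(torch.randn(4, 3))
fixed_mask = torch.zeros(4, 3, dtype=torch.bool)
fixed_mask[:, 0] = True  # freeze the first column

optimizer = torch.optim.SGD([param], lr=0.1)

for step in range(3):
    optimizer.zero_grad()
    loss = (param ** 2).sum()  # dummy loss
    loss.backward()
    # Zero the gradients of the frozen entries before the optimizer step.
    with torch.no_grad():
        param.grad[fixed_mask] = 0
    optimizer.step()
```

With vanilla SGD (no momentum, no weight decay) the frozen entries are never touched; with momentum, weight decay, or optimizers such as LARS/LAMB that use whole-tensor statistics, this is no longer equivalent to truly fixing them, which is the caveat above.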

Best regards

Thomas
