How do I manually “mask out” some weights (element-wise) so they are not penalized by weight_decay in a PyTorch optimizer (e.g. optim.SGD)? For example, weights that are not involved in the forward computation.
Here is a minimal runnable example (MRE) to explain what I wish to do:
import torch
from torch import nn
from torch import optim

# parameters
weight = nn.Parameter(torch.ones((256, 256, 3, 3)))
bias = nn.Parameter(torch.ones((256)))
print("before weight_decay:", weight[128, 128, :, :].detach().numpy())

# build optimizer
optimizer = optim.SGD([weight, bias], lr=1e-2, weight_decay=1e-2)  # large wd for demo
optimizer.zero_grad()

# slices of the parameters (weight/bias)
# Note that weight[128, 128, :, :] is not involved in the forward computation
weight_ = weight[:128, :128, :, :]
bias_ = bias[:128]

# one training step
x = torch.ones((1, 128, 14, 14))
y = nn.functional.conv2d(
    x, weight_, bias_, stride=1, padding=1, dilation=1, groups=1)
l = y.sum()
l.backward()
optimizer.step()

# TODO: prevent this part from being penalized by weight_decay, since it is not used in the computation
print("after weight_decay:", weight[128, 128, :, :].detach().numpy())
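For reference, here is a minimal sketch of one possible workaround I am aware of (not necessarily the cleanest): set weight_decay=0 in the optimizer and manually add the L2 decay term only to the gradients of the slice that actually participates in the forward pass, before calling optimizer.step(). The variable names (wd, weight_, bias_) mirror the MRE above and are only illustrative.

optimizer = optim.SGD([weight, bias], lr=1e-2, weight_decay=0.0)  # no built-in decay
wd = 1e-2  # decay factor applied manually

optimizer.zero_grad()
weight_ = weight[:128, :128, :, :]
bias_ = bias[:128]

x = torch.ones((1, 128, 14, 14))
y = nn.functional.conv2d(x, weight_, bias_, stride=1, padding=1)
l = y.sum()
l.backward()

with torch.no_grad():
    # decay only the elements that took part in the forward computation;
    # the rest of weight/bias keeps a plain gradient step (zero here)
    weight.grad[:128, :128, :, :] += wd * weight[:128, :128, :, :]
    bias.grad[:128] += wd * bias[:128]

optimizer.step()

This only emulates SGD-style decoupled-from-nothing decay (grad += wd * param); whether it is acceptable for my use case, or whether there is a more idiomatic way (e.g. parameter groups or masking), is exactly what I am asking about.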