How do I dynamically switch weight_decay on/off?

TL;DR

How do I manually “mask out” some weights (at the element level) so that they are not penalized by weight_decay in a PyTorch optimizer (e.g. optim.SGD)? For example, weights that are not involved in the forward computation.

Here is a minimal runnable example (MRE) to explain what I want to do:

import torch
from torch import nn
from torch import optim
# parameter
weight = nn.Parameter(torch.ones((256, 256, 3, 3)))
bias = nn.Parameter(torch.ones((256)))
print("before weight_decay:", weight[128, 128, :, :].detach().numpy())
# build optimizer
optimizer = optim.SGD([weight, bias], lr=1e-2, weight_decay=1e-2)  # large wd for demo.
optimizer.zero_grad()
# cropped from parameter (weight/bias)
#   Note that weight[128, 128, :, :] is not involved in the forward computation
weight_ = weight[:128, :128, :, :]
bias_ = bias[:128]
# one training step
x = torch.ones((1, 128, 14, 14))
y = nn.functional.conv2d(
            x, weight_, bias_, stride=1, padding=1, dilation=1, groups=1)
l = y.sum()
l.backward()
optimizer.step()
# TODO: prevent this slice from being penalized by weight_decay, since it is not used in the computation
print("after weight_decay:", weight[128, 128, :, :].detach().numpy())

I think the easier approach is to compute the weight decay for the “used” parameters manually, instead of trying to disable it for the “unused” ones.
Here is a simple example of adding a custom regularization term.
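
Below is a minimal sketch of that idea, reusing the shapes and hyperparameters from the MRE above (the wd name and the 0.5 factor are my own choices, picked so the manual penalty matches what SGD’s built-in weight_decay would add to the gradient): weight_decay is set to 0 in the optimizer, and an explicit L2 penalty on the sliced views weight_ and bias_ is added to the loss, so only the elements that actually take part in the forward pass get decayed.

import torch
from torch import nn
from torch import optim
# parameters (same shapes as above)
weight = nn.Parameter(torch.ones((256, 256, 3, 3)))
bias = nn.Parameter(torch.ones(256))
print("before manual weight_decay:", weight[128, 128, :, :].detach().numpy())
# build optimizer with weight_decay=0; the decay is added to the loss by hand
wd = 1e-2  # decay strength, same value as in the MRE
optimizer = optim.SGD([weight, bias], lr=1e-2, weight_decay=0.0)
optimizer.zero_grad()
# slices that are actually used in the forward computation
weight_ = weight[:128, :128, :, :]
bias_ = bias[:128]
# one training step
x = torch.ones((1, 128, 14, 14))
y = nn.functional.conv2d(
            x, weight_, bias_, stride=1, padding=1, dilation=1, groups=1)
# L2 penalty on the used slices only: the gradient of 0.5 * wd * ||p||^2 is wd * p,
# which mirrors what weight_decay=wd would add for exactly those elements
reg = 0.5 * wd * (weight_.pow(2).sum() + bias_.pow(2).sum())
l = y.sum() + reg
l.backward()
optimizer.step()
# the unused slice receives no decay and no gradient, so it is unchanged
print("after manual weight_decay:", weight[128, 128, :, :].detach().numpy())

Because weight_ and bias_ are slices taken from the parameters inside the autograd graph, the gradient of the penalty flows back only to those elements, so weight[128, 128, :, :] keeps its value after optimizer.step(). The 0.5 factor is there because SGD adds weight_decay * param to the gradient; dropping it simply doubles the effective decay.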