I want to apply a trainable mask to my model's parameters before the forward pass. To accomplish this, I tried using a forward pre-hook that does the following:
- copy network weights into a parameter called orig
- make a mask if one does not exist already
- use the masked weights as the weights for the forward pass
However, I am unable to get the gradients to flow back to the mask. Can anyone point me to where I am going wrong?
import torch
import torch.nn as nn
class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(1234)
        self.lin = nn.Linear(4, 1, bias=False)

    def forward(self, x):
        return self.lin(x)
def mask_hook(module: torch.nn.Module, input):
    parameter_path = "lin.weight"
    module_path, _, name = parameter_path.rpartition(".")  # resolve the parameter path
    sub_mod = module.get_submodule(module_path)
    param = getattr(sub_mod, name)
    del sub_mod._parameters[name]

    mask_name = name + "_mask"  # name for the mask parameter
    orig_name = name + "_orig"  # name for the parameter holding the original weights
    if not hasattr(sub_mod, mask_name):  # no mask present yet
        sub_mod.register_parameter(  # create and register the mask
            mask_name,
            nn.Parameter(torch.ones_like(param)),
        )
    mask = getattr(sub_mod, mask_name)

    setattr(sub_mod, orig_name, nn.Parameter(param.data))  # copy the current weights into orig
    orig = getattr(sub_mod, orig_name)
    sub_mod.register_parameter(name, nn.Parameter(orig * mask))  # put the masked weights back into param
## dummy training code
toy_model = ToyModel()
hook_handle = toy_model.register_forward_pre_hook(mask_hook)
optimizer = torch.optim.SGD(toy_model.parameters(), lr=1.0)

N = 5
inps = torch.arange(N * 4, dtype=torch.float32).view(N, 4)
tgts = torch.arange(N * 1, dtype=torch.float32).view(N, 1)
training_loader = torch.utils.data.TensorDataset(inps, tgts)

for i, (x, y) in enumerate(training_loader):
    pred = toy_model(x)
    loss = nn.MSELoss()(pred, y)
    # Backpropagation
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print([(name, torch.norm(param).item()) for name, param in toy_model.named_parameters()])
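For reference, here is a quick way to reproduce what I'm seeing without the training loop (it reuses ToyModel and mask_hook from above; the input and target are just placeholders):

import torch
import torch.nn as nn

model = ToyModel()
model.register_forward_pre_hook(mask_hook)

x = torch.ones(4)  # placeholder input
loss = nn.MSELoss()(model(x), torch.zeros(1))
loss.backward()

print(model.lin.weight_mask.grad)  # stays None, i.e. no gradient reaches the mask
print(model.lin.weight.grad)       # the re-registered weight itself does get a gradient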