Updating Features Before Forward using Prehooks

I want to update my model parameters with a trainable mask before the forward pass. To accomplish this, I tried to use a hook that does the following:

  1. copy network weights into a parameter called orig
  2. make a mask if one does not exist already
  3. use the masked weights as the weights for the forward pass

However, I am unable to get the gradients to flow back to the mask. Can anyone point me to where I am going wrong?

import torch
import torch.nn as nn

class ToyModel(torch.nn.Module):
  def __init__(self):
    self.lin = nn.Linear(4, 1, bias=False)

  def forward(self, x):
    return self.lin(x)

def mask_hook(module: torch.nn.Module, input):
    parameter_path = "lin.weight"
    module_path, _, name = parameter_path.rpartition(".")  # resolve param path
    sub_mod = module.get_submodule(module_path)

    param = getattr(sub_mod, name)
    del sub_mod._parameters[name]
    mask_name = name + "_mask" # lets make a mask param
    orig_name = name + "_orig" # lets also make a parameter to store original params

    if not hasattr(sub_mod, mask_name): # no mask present
      sub_mod.register_parameter(  # create and register mask

    mask = getattr(sub_mod, mask_name)
    setattr(sub_mod, orig_name, nn.Parameter(param.data)) # copy param into orig
    orig = getattr(sub_mod, orig_name)
    sub_mod.register_parameter(name, nn.Parameter(orig * mask)) # put masked params into param

## dummy training code

toy_model = ToyModel()
hook_handle = toy_model.register_forward_pre_hook(mask_hook)

optimizer = torch.optim.SGD(toy_model.parameters(), lr=1.0)

N = 5
inps = torch.arange(N * 4, dtype=torch.float32).view(N, 4)
tgts = torch.arange(N * 1, dtype=torch.float32).view(N, 1)
training_loader = torch.utils.data.TensorDataset(inps, tgts)

for i, (x, y) in enumerate(training_loader):
  pred = toy_model(x)
  loss = nn.MSELoss()(pred, y)
  # Backpropagation
  print([(name, torch.norm(param).item()) for name, param in toy_model.named_parameters()])

It seems you are using a forward_pre_hook to delete the original parameter and replace it with a new one in each forward step.
This will break the reference to the optimizer and the optimizer.step() call will not update the originally passed parameter as it was already deleted.

That could definitely be one issue. I tried to simplify the problem down further into seeing whether gradient flowed when multiplying 2 parameters, and found that it does not:

import torch
import torch.nn as nn

class PlayModel(torch.nn.Module):
  def __init__(self):
    self.a = nn.Parameter(torch.tensor([[1.0, 2.0]]), requires_grad=True)
    self.b = nn.Parameter(torch.tensor([[3.0, 4.0]]), requires_grad=True)

    self.lin = nn.Linear(2, 1, bias=False)

  def forward(self, x):
    self.lin.weight = nn.Parameter(self.a * self.b, requires_grad=True)

    return self.lin(x)

def train():
  play_model = PlayModel()
  optimizer = torch.optim.Adam(play_model.parameters(), lr=1.0)
  loss_fn = nn.MSELoss()

  n = 10
  inps = torch.arange(n * 2, dtype=torch.float32).view(n, 2)
  tgts = torch.arange(n * 1, dtype=torch.float32).view(n, 1)
  training_loader = torch.utils.data.TensorDataset(inps, tgts)

  for i, (x, y) in enumerate(training_loader):
    pred = play_model(x)
    loss = loss_fn(pred, y)
    print([(name, "grad:", param.grad) for name, param in play_model.named_parameters()])