How to modify a module's inputs in the computation graph after the forward pass

Hi, I am reproducing a paper on model pruning for generating adversarial examples. The paper is: [2208.08677] Enhancing Targeted Attack Transferability via Diversified Weight Pruning. Section 3.2 describes the model pruning as follows:

That is, the original network is used for the forward pass and its pruned counterpart is used for the backward pass when calculating the gradient.

The key to pruning the network before the gradient is computed in backpropagation is to modify the module's inputs (including the model's weights) in the computation graph after the forward pass.

However, the existing hook methods cannot solve this problem:

  • register_module_forward_pre_hook can modify the input before the forward, but we want to do it after the forward.
  • register_module_forward_hook can only modify the output instead of input after the forward.
  • backward_hook can modify the gradient, but we want to modify the inputs (including the model's weights) in the computation graph.

Is there any way to modify the inputs in the computation graph after the forward pass?

I'm not sure what you mean by "use a pruned network for backward", but if you want to prune saved activations before they are used to compute the backward pass, you could use the hooks described in the "Hooks for autograd saved tensors" tutorial (PyTorch Tutorials 2.0.0+cu117 documentation). With these hooks, the gradients are computed as if the forward pass had been run with the pruned inputs. Note that some operations save their output rather than their input for the backward pass.
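To make the effect of the saved-tensor hooks concrete, here is a minimal sketch. The scaling by 0.5 is only a stand-in for pruning and is an assumption for illustration, not something from the paper. The forward value is computed from the original tensor, while the backward pass sees whatever the unpack hook returns.

```python
import torch

def pack(t):
    # Called when autograd saves a tensor during the forward pass.
    return t

def unpack(t):
    # Called when the backward pass needs the saved tensor.
    # Halving is a stand-in for "pruning" (illustrative assumption).
    return t * 0.5

a = torch.tensor([2.0, 3.0], requires_grad=True)

with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = a.pow(2).sum()  # forward uses the unmodified a: 4 + 9 = 13

y.backward()

print(y.item())  # 13.0
print(a.grad)    # tensor([2., 3.]); without the hooks it would be tensor([4., 6.])
```

The gradient of pow(2) is 2 * saved_input, so halving the saved input halves the gradient while leaving the forward result untouched.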

Thanks a lot. Now I know I can use unpack_hook to access the saved tensors in the backward pass.
For my pruning task, take a linear layer y = W*x as an example: W and x are packed and saved during the forward pass. In unpack_hook I can unpack W and multiply it by a mask matrix, so that the gradient is calculated with the pruned weights.
I am trying this idea and will post the code if it works.
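The linear-layer idea can be sketched as follows. The mask values are hypothetical, and torch.mv is used here (rather than nn.Linear) so that the saved weight keeps the same shape as W; other operators may save a transposed or otherwise reshaped copy, in which case a shape check like this one would not match.

```python
import torch

W = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
mask = torch.tensor([[1.0, 0.0], [0.0, 1.0]])  # hypothetical pruning mask
x = torch.ones(2, requires_grad=True)

def unpack(t):
    # Identify the saved weight by shape (a heuristic) and prune it.
    if t.shape == W.shape:
        return t * mask
    return t

with torch.autograd.graph.saved_tensors_hooks(lambda t: t, unpack):
    y = torch.mv(W, x)  # forward uses the full W: [3., 7.]

y.sum().backward()

print(y.detach())  # tensor([3., 7.])
print(x.grad)      # tensor([1., 4.]), i.e. (W * mask)^T @ [1, 1]
```

Without the hook, x.grad would be W^T @ [1, 1] = [4., 6.], so the input gradient here is the one a pruned weight matrix would produce, even though the forward output came from the unpruned W.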

Here is the code. Replace each Conv2d module with Conv2dWithHooks:

import torch
import torch.nn as nn


class Conv2dWithHooks(nn.Conv2d):
    def forward(self, input):
        # Tensors saved for backward inside this block go through pack/unpack.
        with torch.autograd.graph.saved_tensors_hooks(self.pack, self.unpack):
            return super().forward(input)

    def pack(self, x):  # nothing to modify in the forward pass
        return x

    def unpack(self, x):  # modify the weight before the backward pass
        # Shape match is a heuristic: an input with the same shape as the
        # weight would also be masked.
        if x.shape == self.weight.shape:  # x is the weight
            weight_mask = torch.ones_like(self.weight)  # example mask
            x = x * weight_mask
        return x
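
For anyone who wants to apply this to an existing model, below is a rough sketch of the replacement step. MaskedBackwardConv2d, its weight_mask buffer, and swap_conv2d are hypothetical names introduced for illustration; the class is a variant of Conv2dWithHooks that stores the mask as a buffer so it can be changed per layer. It keeps the same shape-based heuristic for recognizing the saved weight.

```python
import torch
import torch.nn as nn

class MaskedBackwardConv2d(nn.Conv2d):
    """Hypothetical variant of Conv2dWithHooks with a configurable mask."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # All-ones mask means "no pruning" until the caller edits it.
        self.register_buffer("weight_mask", torch.ones_like(self.weight))

    def forward(self, input):
        with torch.autograd.graph.saved_tensors_hooks(lambda t: t, self._unpack):
            return super().forward(input)

    def _unpack(self, t):
        # Heuristic: a saved tensor with the weight's shape is the weight.
        if t.shape == self.weight.shape:
            return t * self.weight_mask
        return t

def swap_conv2d(module):
    """Recursively replace plain nn.Conv2d children, copying their parameters."""
    for name, child in module.named_children():
        if type(child) is nn.Conv2d:
            new = MaskedBackwardConv2d(
                child.in_channels, child.out_channels, child.kernel_size,
                stride=child.stride, padding=child.padding,
                dilation=child.dilation, groups=child.groups,
                bias=child.bias is not None, padding_mode=child.padding_mode,
            ).to(child.weight.device, child.weight.dtype)
            # strict=False because the source layer has no weight_mask buffer.
            new.load_state_dict(child.state_dict(), strict=False)
            setattr(module, name, new)
        else:
            swap_conv2d(child)
    return module

torch.manual_seed(0)
model = nn.Sequential(
    nn.Conv2d(3, 4, 3, padding=1), nn.ReLU(), nn.Conv2d(4, 2, 3, padding=1)
)
x = torch.randn(1, 3, 8, 8)
ref_out = model(x).detach()

swap_conv2d(model)
out_after_swap = model(x).detach()  # forward pass should be unchanged

# Prune every weight of the first layer in the backward pass only.
model[0].weight_mask.zero_()
xg = x.clone().requires_grad_(True)
model(xg).sum().backward()
```

In this sketch the forward output is the same before and after the swap, and zeroing the first layer's mask makes the gradient with respect to the input zero, since that gradient flows through the (masked) saved weight. Note that activations such as ReLU still use what was saved from the unpruned forward pass, so the result is not necessarily identical to backpropagating through a separately pruned network.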