Forward_pre_hooks not called after applying nn.utils.prune methods

Hello,

I am working with the newly released pruning functionality in torch.nn.utils.prune, and I am extending this implementation of the MS-D network:

This is a network with densely connected 3x3 convolutions followed by a final layer of 1x1 convolutions. Some simplified code:

import msd_pytorch
import torch.nn.utils.prune as prune

def pytorch_L1_prune(model, depth, perc):
    for k in range(depth):
        wname ="weight"+str(k)
        prune.l1_unstructured(model.msd.msd_block, name=wname, amount=perc)
    prune.l1_unstructured(model.msd.final_layer.linear, name='weight', amount=perc)
    return model

d = 100
model = msd_pytorch.MSDSegmentationModel(c_in, num_labels, d, width, dilations=dilations)
# Load model here
model = pytorch_L1_prune(model, d, 0.3)
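
For context, this is my understanding of what each l1_unstructured call does, sketched on a toy module rather than the MS-D code:

import torch.nn as nn
import torch.nn.utils.prune as prune

conv = nn.Conv2d(4, 4, kernel_size=3)
prune.l1_unstructured(conv, name="weight", amount=0.3)

# The original parameter is renamed, a mask buffer is added, and 'weight' itself
# becomes a plain attribute recomputed as weight_orig * weight_mask by a
# forward_pre_hook every time conv(x) is called.
print([n for n, _ in conv.named_parameters()])  # ['bias', 'weight_orig']
print([n for n, _ in conv.named_buffers()])     # ['weight_mask']
print(type(conv.weight))                        # <class 'torch.Tensor'>, not a Parameter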

What I have noticed:
The network runs fine, but the pruning has no effect on the 3x3 convolutions, while it does affect the 1x1 final layer. My hypothesis is that the forward_pre_hooks are not applied during the forward() pass of the 3x3 convolutions.

Hypothesis:
I found the following note in the docs (torch.nn.modules.module):

    Although the recipe for forward pass needs to be defined within
    this function, one should call the :class:`Module` instance afterwards
    instead of this since the former takes care of running the
    registered hooks while the latter silently ignores them.

I don’t really understand what this note means or how to act on it, which is why I thought it would be good to clear it up here.

The network has an MSDModule2d, which contains two types of modules: MSDBlock2d for the 3x3 convolutions and MSDFinalLayer at the end. Its forward pass is as follows:

class MSDModule2d(torch.nn.Module):
    def __init__(self, c_in, c_out, depth, width, dilations=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]):
        super(MSDModule2d, self).__init__()
        ... # initialize dilations etc.
        self.msd_block = MSDBlock2d(self.c_in, self.dilations, self.width)
        self.final_layer = MSDFinalLayer(c_in=c_in + width * depth, c_out=c_out)
        ...

    def forward(self, input):
        output = self.msd_block(input)
        output = self.final_layer(output)
        return output

The latter basically wraps nn.Conv1d, so the above note probably does not apply to it, which would explain why its forward_pre_hooks are applied. The MSDBlock2d module implements a somewhat more complicated forward pass that I do not want to post in full here:

class MSDBlockImpl2d(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, dilations, bias, *weights):
          ... # complicated

    @staticmethod
    def backward(ctx, grad_output):
          ... # complicated

class MSDBlock2d(torch.nn.Module):
    def __init__(self, in_channels, dilations, width=1):
        super().__init__()
        ... # initialize weights etc.

    def forward(self, input):
        bias, *weights = self.parameters()
        return MSDBlockImpl2d.apply(input, self.dilations, bias, *weights)

Any idea what is going on? Is the above hypothesis correct, and if so, what does it mean and what should I add to the MSDBlock2d module?

Kind regards,

Richard

EDIT: CHECK ON PRUNING METHODS
I checked whether the pruning method does its job.

print(list(model.msd.final_layer.linear.named_buffers()))
print(list(model.msd.msd_block.named_buffers()))

The model does have the masks etc. Furthermore, the following

print(model.msd.final_layer.linear._forward_pre_hooks)
print(model.msd.msd_block._forward_pre_hooks)

shows that the hooks are registered. It just seems that they are not applied in the msd_block case.

cc @Michela who is the original author of that module.

I’m not the author of that quote, but what it says is that you should always call module() rather than module.forward() if you don’t want the hooks to be ignored. Does the code make an explicit call to forward() by chance? Quickly skimming through it, I found a forward() call here: https://github.com/ahendriksen/msd_pytorch/blob/master/msd_pytorch/msd_model.py#L213 and here: https://github.com/ahendriksen/msd_pytorch/blob/master/msd_pytorch/msd_model.py#L250
I’m not sure how, or whether, these are used anywhere in your code. BTW, thanks for finding that quote in the docs and for doing a lot of research yourself before raising this issue.
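
A minimal illustration of what that note means, on a toy module:

import torch
import torch.nn as nn

lin = nn.Linear(2, 2)
lin.register_forward_pre_hook(lambda module, inp: print("pre-hook ran"))

x = torch.randn(1, 2)
out = lin(x)          # prints "pre-hook ran": __call__ runs the registered hooks
out = lin.forward(x)  # prints nothing: calling forward() directly skips the hooks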

Another potential problem I see is that in MSDBlock2d’s forward (https://github.com/ahendriksen/msd_pytorch/blob/master/msd_pytorch/msd_block.py#L179-L180) you grab the weights from self.parameters() and then pass them to the complicated forward() above. However, the parameters store the unpruned versions of the weight tensors; the pruned versions are just attributes attached to the corresponding Modules. So this might also be the cause of the problem you’re seeing.
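
A quick toy check of that point (a pruned nn.Conv2d, nothing to do with your model):

import torch.nn as nn
import torch.nn.utils.prune as prune

conv = nn.Conv2d(1, 1, kernel_size=3)
prune.l1_unstructured(conv, name="weight", amount=0.5)

unpruned = dict(conv.named_parameters())["weight_orig"]
print((unpruned == 0).sum().item())     # typically 0: parameters() holds the original values
print((conv.weight == 0).sum().item())  # > 0: the attribute set by the hook is the masked tensor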

Other than these quick observations, from the details you provided (i.e., the fact that the hooks do exist and the net is indeed in a pruned state), it sounds as if this is not an issue with pruning, but rather an issue with the specific model you’re using and the way it computes the forward pass, which is probably causing it to ignore the hooks for certain parts of the net. I’m not sure I’m the best person to debug that, but I can take a closer look if needed.


Hello,

Thanks for the elaborate answer and for taking the time to look into the code. The lines I had to change were indeed:

# In msd_block.py lines 179-180, in the forward pass method
bias, *weights = self.parameters()

# CHANGED TO:
bias = self.bias
weights = (self.__dict__['weight{}'.format(i)] for i in range(len(self.weights)))

So is it true that when the parameters are renamed to weight{i}_orig, the “pointers” returned by parameters() still point to the originals, regardless of the name change? The changed line of code seems to work, as pruning is now applied, but it does not look very nice. Is this the way it should be done, in your opinion?

Kind regards,

Richard

PS: I ran into the following curious change. The following line runs fine in PyTorch 1.1.0:

# This runs fine when I run the network (unpruned) in 1.1.0
weights = (self.__getattr__("weight{}".format(i)) for i in range(len(self.weights)))

# But throws the following Error for the pruned network in 1.4.0
AttributeError: 'MSDBlock2d' object has no attribute 'weight0'

It is probably something minor, but I found it curious and therefore changed it to the above. When you print the __dir__() of the object, you do see the attributes weight, weight_mask, weight_orig, etc.

Glad we figured out how to make it work!
Yes, the parameters will always be the unpruned versions of the tensors, until you .remove() the reparametrization.
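
For completeness, a minimal sketch of what removing the reparametrization does (toy module again):

import torch.nn as nn
import torch.nn.utils.prune as prune

lin = nn.Linear(4, 4)
prune.l1_unstructured(lin, name="weight", amount=0.5)
print([n for n, _ in lin.named_parameters()])  # ['bias', 'weight_orig']: still unpruned underneath

prune.remove(lin, "weight")                    # bake the mask in and drop the hook
print([n for n, _ in lin.named_parameters()])  # ['bias', 'weight']: a single, permanently pruned Parameter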

One option: grab both the unpruned parameters from self.parameters() and the masks from self.buffers() and do the multiplication manually (con: would require a new logic block to be executed only if the mask buffers exist).
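
Roughly along these lines (just a sketch; self.weights and the weight{i} naming come from your module, the helper name is made up):

def _effective_weights(self):
    """Return the weight tensors to use in forward(), applying masks manually if present."""
    ws = []
    for i in range(len(self.weights)):
        orig = getattr(self, "weight{}_orig".format(i), None)
        if orig is not None:
            # pruned case: original parameter times its mask buffer
            ws.append(orig * getattr(self, "weight{}_mask".format(i)))
        else:
            # unpruned case: the plain parameter is still registered as weight{i}
            ws.append(getattr(self, "weight{}".format(i)))
    return ws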

Another option, without directly accessing __dict__: do something like weights = [module.weight for module in self.modules() if hasattr(module, 'weight')], or whatever this needs to be for your model to work. In your specific case, since you don’t have single tensors called 'weight' spread across a variety of modules but rather a single module with a variety of tensors called 'weight{i}', it would look something like (getattr(self, 'weight{}'.format(i)) for i in range(len(self.weights))) instead. This should work fine for both the pruned and unpruned versions of the model.
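
Concretely, the block’s forward could then look something like this (a sketch, assuming self.bias, self.weights and self.dilations exist as in your current MSDBlock2d):

class MSDBlock2d(torch.nn.Module):
    ...

    def forward(self, input):
        bias = self.bias
        # getattr goes through normal attribute lookup, so it returns the pruned
        # tensor when the pruning hook has set it, and the plain Parameter otherwise.
        weights = (getattr(self, "weight{}".format(i)) for i in range(len(self.weights)))
        return MSDBlockImpl2d.apply(input, self.dilations, bias, *weights)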

Re this:

you should do (getattr(self, "weight{}".format(i)) for i in range(len(self.weights))) instead of calling __getattr__ directly. nn.Module.__getattr__ only searches _parameters, _buffers, and _modules, whereas the pruned tensor is set as a plain attribute in the instance __dict__; plain getattr() goes through normal attribute lookup, which finds it there and only falls back to __getattr__ when that fails. See this for reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L580
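
A toy reproduction of the difference, using the same mechanism on an nn.Linear:

import torch.nn as nn
import torch.nn.utils.prune as prune

m = nn.Linear(3, 3)
prune.l1_unstructured(m, name="weight", amount=0.5)

w = getattr(m, "weight")  # works: normal lookup finds the pruned tensor in m.__dict__
m.__getattr__("weight")   # AttributeError: this only searches _parameters/_buffers/_modules,
                          # which now contain 'weight_orig' and 'weight_mask' instead of 'weight'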
