I’ve been trying out PyTorch for a while and have a somewhat contrived use case: I am trying to change the shape of the weight tensors inside Conv2d layers in order to do some filter pruning on a pre-trained model.
I wrote some code to change the shape of all the conv layers in the model:
for index, layer in enumerate(model.modules()):
    if isinstance(layer, nn.Conv2d):
        print("Pruning Conv-{}".format(index))
        filters = layer.weight.data.numpy()
        new_filters = prune_layer(filters)  # reshape the tensor
        # Update layer with new tensor
        layer.weight.data = torch.from_numpy(new_filters).float()
        layer.out_channels = new_filters.shape[0]
The problem I have is that since the out_channels of the i-th layer changes, I also need to adjust the in_channels of the (i+1)-th layer, and with modules() I can’t guarantee that the order of the nn.Conv2d modules matches the order in which they are used in the forward function, so I might end up changing the channels of the wrong layers.
Is there a better approach to this? Am I doing it completely wrong?
If you don’t have skip-layer connections, the order should usually be the same. Just note that you need to prune weight and bias in Conv2d and BatchNorm2d (if you use them).
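In code, pruning a conv’s output side together with a following BatchNorm might look roughly like this (a minimal sketch; `prune_conv_out_channels` is a hypothetical helper, and `keep` holds the indices of the filters you decide to retain):

```python
import torch
import torch.nn as nn

def prune_conv_out_channels(conv, bn, keep):
    """Hypothetical helper: keep only the filters at indices `keep`.

    Prunes the conv's weight and bias along the output dimension, and
    the matching BatchNorm parameters/buffers if one follows the conv.
    """
    keep = torch.as_tensor(keep, dtype=torch.long)
    conv.weight.data = conv.weight.data[keep].clone()
    if conv.bias is not None:
        conv.bias.data = conv.bias.data[keep].clone()
    conv.out_channels = len(keep)
    if bn is not None:
        bn.weight.data = bn.weight.data[keep].clone()
        bn.bias.data = bn.bias.data[keep].clone()
        bn.running_mean = bn.running_mean[keep].clone()
        bn.running_var = bn.running_var[keep].clone()
        bn.num_features = len(keep)
    return conv, bn

# Example: keep 4 of the 8 filters of a conv and its BatchNorm
conv = nn.Conv2d(3, 8, 3)
bn = nn.BatchNorm2d(8)
prune_conv_out_channels(conv, bn, [0, 2, 4, 6])
print(conv.weight.shape)  # torch.Size([4, 3, 3, 3])
```

Note that the BatchNorm buffers (running_mean/running_var) need the same slicing as its weight and bias, or the pruned model will produce wrong statistics at eval time.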
If you have skip-layer connections, you can build a dependency graph of which module passes its output to which, and then use it to trace which convs take input from which convs. One example that builds such a graph can be found here (look at pytorch_visualize.py); it performs graph visualization.
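Short of a full graph trace, a quick way to get the order in which convs actually execute (rather than the modules() registration order) is to register forward hooks and run a dummy input through the model. This is just a sketch of that idea, and it only gives you pairwise dependencies for purely sequential models:

```python
import torch
import torch.nn as nn

def conv_execution_order(model, input_shape=(1, 3, 32, 32)):
    """Record the order in which Conv2d modules run during forward().

    Hypothetical helper: hooks fire in execution order, which for a
    sequential model is exactly the layer dependency order. Models
    with skip connections still need a real graph trace.
    """
    order = []
    hooks = [
        m.register_forward_hook(lambda mod, inp, out: order.append(mod))
        for m in model.modules() if isinstance(m, nn.Conv2d)
    ]
    with torch.no_grad():
        model(torch.zeros(input_shape))
    for h in hooks:
        h.remove()
    return order

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 16, 3))
order = conv_execution_order(model)
print([m.out_channels for m in order])  # [8, 16]
```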
A much simpler way, however, is to “hijack” the forward function of Conv2d, BatchNorm2d, activation functions and so on, like so:
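The original snippet isn’t shown in this thread, but the idea can be sketched as follows: temporarily wrap each leaf module’s forward so that every output tensor is tagged with the last conv that produced it; each Conv2d then reads the tag off its input to discover which conv feeds it. This assumes the layers between convs are elementwise (ReLU, BatchNorm) and there are no skip connections:

```python
import torch
import torch.nn as nn

def trace_conv_dependencies(model, input_shape=(1, 3, 32, 32)):
    """Sketch of the 'hijack forward' idea (hypothetical helper):
    returns {conv: conv_that_feeds_it_or_None}."""
    deps = {}
    originals = {}
    leaves = [m for m in model.modules() if not list(m.children())]
    for m in leaves:
        originals[m] = m.forward
        def hijacked(x, _m=m, _fwd=m.forward):
            if isinstance(_m, nn.Conv2d):
                # record which conv produced this conv's input
                deps[_m] = getattr(x, "_last_conv", None)
            out = _fwd(x)
            # tag the output: this conv itself, or pass the tag through
            # elementwise layers (assumes no skip connections / concat)
            out._last_conv = (_m if isinstance(_m, nn.Conv2d)
                              else getattr(x, "_last_conv", None))
            return out
        m.forward = hijacked
    with torch.no_grad():
        model(torch.zeros(input_shape))
    for m, fwd in originals.items():   # restore the original forwards
        m.forward = fwd
    return deps

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 16, 3))
deps = trace_conv_dependencies(model)
print(deps[model[2]] is model[0])  # True: the second conv eats the first's output
```

With `deps` in hand you know exactly which layer’s in_channels to adjust when you shrink a given conv’s out_channels.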
The hijacking approach looks nice, but I still need to keep track of the dependencies between layers, right?
In order to update the in_channels of some of them, I mean.
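For reference, the in_channels update itself is just a slice along dim 1 of the weight. A minimal sketch, assuming you know the `keep` indices of the filters retained in the previous conv (helper name is hypothetical):

```python
import torch
import torch.nn as nn

def prune_conv_in_channels(conv, keep):
    """Hypothetical helper: keep only the input channels at indices
    `keep`, matching the filters kept in the preceding conv layer."""
    keep = torch.as_tensor(keep, dtype=torch.long)
    # weight shape is (out_channels, in_channels, kH, kW): slice dim 1
    conv.weight.data = conv.weight.data[:, keep].clone()
    conv.in_channels = len(keep)
    return conv

# Example: the previous conv kept filters [0, 2, 4, 6] of its original 8
conv = nn.Conv2d(8, 16, 3)
prune_conv_in_channels(conv, [0, 2, 4, 6])
print(conv.weight.shape)  # torch.Size([16, 4, 3, 3])
```

The bias is untouched here since it has one entry per output channel, not per input channel.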