How to automatically remove weights from network after pruning?

Hi all,

I'm trying to implement simple iterative pruning in PyTorch, and I have one question:
If I prune some channels from a layer, how can I automatically prune the dependent layers (the following BatchNorm, the next convolution layer, etc.) without hard-coding information about the graph topology?
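One possible way to recover the topology automatically is to trace the model with torch.fx (available since PyTorch 1.8) and walk each convolution's users in the traced graph. This is a minimal sketch, not a complete solution. It assumes a plain chain of Conv2d, BatchNorm2d and ReLU modules. `channel_dependents` and `PASS_THROUGH` are hypothetical names. Branching ops such as residual additions or torch.cat are only flagged, not handled. Tracing also fails on models with data-dependent control flow.

```python
import torch.nn as nn
from torch import fx

# Hypothetical list of modules assumed to keep the channel layout unchanged,
# so a channel removed upstream must also be removed from them.
PASS_THROUGH = (nn.BatchNorm2d, nn.ReLU)


def channel_dependents(model: nn.Module) -> dict:
    """Map each Conv2d (by module name) to the modules that must shrink with its output."""
    gm = fx.symbolic_trace(model)
    modules = dict(gm.named_modules())
    deps = {}
    for node in gm.graph.nodes:
        if node.op != "call_module" or not isinstance(modules[node.target], nn.Conv2d):
            continue
        found, frontier = [], list(node.users)
        while frontier:
            user = frontier.pop()
            mod = modules.get(user.target) if user.op == "call_module" else None
            if isinstance(mod, PASS_THROUGH):
                found.append(user.target)      # prune its channels, then keep walking
                frontier.extend(user.users)
            elif isinstance(mod, nn.Conv2d):
                found.append(user.target)      # next conv: prune its input channels, stop
            else:
                found.append("?" + user.name)  # output, add, cat, flatten, ...: not handled here
        deps[node.target] = found
    return deps


net = nn.Sequential(nn.Conv2d(3, 10, 1), nn.BatchNorm2d(10), nn.Conv2d(10, 20, 1))
deps = channel_dependents(net)
print(deps)
```

For a Conv -> BN -> Conv chain built with nn.Sequential, the entry for the first convolution ("0") lists the BatchNorm ("1") and the second convolution ("2"). Those are exactly the layers that need slicing when one of its output channels is removed.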

For example, consider this simple network:
nn.Conv2d(3, 10, kernel_size=1) -> nn.BatchNorm2d(10) -> nn.Conv2d(10, 20, kernel_size=1)
If I prune one output channel from the first convolution, I also have to remove the corresponding channel from the BatchNorm and the corresponding input channel from the second convolution to keep the shapes consistent. But I'm not sure how to do this automatically…
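Once you know which layers are affected, the removal itself comes down to slicing the right tensor dimension. You take rows (dim 0) of the first conv's weight, every per-channel BatchNorm parameter and buffer, and columns (dim 1) of the next conv's weight. Below is a minimal sketch for exactly this Conv2d -> BatchNorm2d -> Conv2d chain. It assumes groups=1 and default BatchNorm2d settings (affine, tracking running stats). `prune_conv_bn_conv` is a hypothetical helper, and the L1-norm rule for choosing which channel to drop is only an example criterion.

```python
import copy

import torch
import torch.nn as nn


def prune_conv_bn_conv(conv1, bn, conv2, keep):
    """Keep only conv1's output channels listed in `keep` and slice bn/conv2 to match.

    Assumes groups=1, an affine BatchNorm2d with running stats, and that bn and
    conv2 consume conv1's output directly (no residual add or concat in between).
    """
    n = len(keep)
    new_conv1 = nn.Conv2d(conv1.in_channels, n, conv1.kernel_size, stride=conv1.stride,
                          padding=conv1.padding, dilation=conv1.dilation,
                          bias=conv1.bias is not None, padding_mode=conv1.padding_mode)
    new_bn = nn.BatchNorm2d(n, eps=bn.eps, momentum=bn.momentum)
    new_conv2 = nn.Conv2d(n, conv2.out_channels, conv2.kernel_size, stride=conv2.stride,
                          padding=conv2.padding, dilation=conv2.dilation,
                          bias=conv2.bias is not None, padding_mode=conv2.padding_mode)
    with torch.no_grad():
        # conv1: keep whole filters (dim 0 = output channels)
        new_conv1.weight.copy_(conv1.weight[keep])
        if conv1.bias is not None:
            new_conv1.bias.copy_(conv1.bias[keep])
        # BatchNorm: every per-channel parameter and buffer
        new_bn.weight.copy_(bn.weight[keep])
        new_bn.bias.copy_(bn.bias[keep])
        new_bn.running_mean.copy_(bn.running_mean[keep])
        new_bn.running_var.copy_(bn.running_var[keep])
        new_bn.num_batches_tracked.copy_(bn.num_batches_tracked)
        # conv2: keep the matching input channels (dim 1)
        new_conv2.weight.copy_(conv2.weight[:, keep])
        if conv2.bias is not None:
            new_conv2.bias.copy_(conv2.bias)
    return new_conv1, new_bn, new_conv2


net = nn.Sequential(nn.Conv2d(3, 10, 1), nn.BatchNorm2d(10), nn.Conv2d(10, 20, 1)).eval()

# Example criterion (an assumption): drop the conv1 filter with the smallest L1 norm.
norms = net[0].weight.abs().sum(dim=(1, 2, 3))
drop = int(norms.argmin())
keep = torch.tensor([c for c in range(10) if c != drop])

pruned = nn.Sequential(*prune_conv_bn_conv(net[0], net[1], net[2], keep)).eval()

# Sanity check: the pruned net should match the original net with conv2's
# input channel `drop` zeroed out, since that channel then contributes nothing.
ref = copy.deepcopy(net)
with torch.no_grad():
    ref[2].weight[:, drop] = 0
x = torch.randn(2, 3, 8, 8)
print(torch.allclose(ref(x), pruned(x), atol=1e-5))  # expected: True
```

For iterative pruning you would call the helper once per step, recomputing `keep` from the current weights, and fine-tune between steps.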
Thanks in advance!

Hi @SerB, I am trying to do the same. Have you been able to figure it out?