Tracing a graph of layers in a model

Trying to implement structured filter pruning from scratch. Is there any way to get the dependency graph of the layers in a model?

I need a graph because if I remove a filter from an nn.Conv2d layer, I will have to modify the structure of all downstream layers that directly depend on the layer I am pruning.

Is there a straightforward way to do this in PyTorch?
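One way to get such a graph, offered here as a suggestion rather than something from the thread, is torch.fx. Its symbolic_trace records the forward pass as a graph of nodes, and each node's users lists the nodes that consume its output. The sketch below walks forward from a conv node and collects the modules whose shapes depend on that conv's output channels, stopping at the next Conv2d on each path. SmallNet and dependent_modules are hypothetical names. The sketch assumes the model is traceable (no data-dependent control flow). It also ignores subtler constraints, such as the coupling a residual add creates (conv2's output channels must match conv1's) and Linear layers after a flatten.

```python
import torch
import torch.nn as nn
import torch.fx as fx


class SmallNet(nn.Module):
    """Hypothetical toy model: conv1 feeds conv2 and, via a residual add, conv3."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(8)
        self.conv2 = nn.Conv2d(8, 8, 3, padding=1)
        self.conv3 = nn.Conv2d(8, 4, 3, padding=1)

    def forward(self, x):
        a = torch.relu(self.bn1(self.conv1(x)))
        b = torch.relu(self.conv2(a))
        return self.conv3(a + b)


def dependent_modules(gm: fx.GraphModule, conv_name: str) -> list:
    """Names of modules whose channel count depends on conv_name's output.

    Walks node.users forward from the conv, passing through element-wise ops
    and normalization layers, and stops at the next nn.Conv2d on each path.
    """
    modules = dict(gm.named_modules())
    start = next(n for n in gm.graph.nodes
                 if n.op == "call_module" and n.target == conv_name)
    found, seen, stack = set(), set(), list(start.users)
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        if node.op == "call_module":
            # e.g. BatchNorm2d (num_features) or Conv2d (in_channels)
            found.add(node.target)
            if isinstance(modules[node.target], nn.Conv2d):
                continue  # the next conv re-defines the channel count; stop here
        stack.extend(node.users)
    return sorted(found)


gm = fx.symbolic_trace(SmallNet())
print(dependent_modules(gm, "conv1"))  # ['bn1', 'conv2', 'conv3']
print(dependent_modules(gm, "conv2"))  # ['conv3']
```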

@Vedant_Dalimkar The following post can be helpful

This didn’t help me.

Let’s say I want to remove a filter from an nn.Conv2d layer. The layer has ‘out_channels’ filters, each spanning ‘in_channels’ input channels. If I remove one of those filters, the layer produces one fewer output channel, so I also have to reduce ‘in_channels’ on every Conv2d layer that consumes this layer's output. Because of residual connections, there may be more than one such downstream layer.

Hope that makes my question clearer.
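For the direct case described above (conv A feeding conv B, with only element-wise ops such as ReLU in between), physically removing filters means slicing dim 0 of A's weight and bias and dim 1 of B's weight. The following is a minimal sketch under two assumptions: groups=1, and no BatchNorm between the two layers. A BatchNorm2d there would need its weight, bias, running_mean and running_var sliced the same way. prune_conv_pair is a hypothetical helper name.

```python
import torch
import torch.nn as nn


def prune_conv_pair(conv_a: nn.Conv2d, conv_b: nn.Conv2d, keep):
    """Return new (conv_a, conv_b) keeping only the output filters `keep` of conv_a.

    Assumes conv_b consumes conv_a's output directly (element-wise ops only in
    between) and that both layers use groups=1.
    """
    keep = torch.as_tensor(keep, dtype=torch.long)
    new_a = nn.Conv2d(conv_a.in_channels, len(keep), conv_a.kernel_size,
                      stride=conv_a.stride, padding=conv_a.padding,
                      dilation=conv_a.dilation, bias=conv_a.bias is not None)
    new_b = nn.Conv2d(len(keep), conv_b.out_channels, conv_b.kernel_size,
                      stride=conv_b.stride, padding=conv_b.padding,
                      dilation=conv_b.dilation, bias=conv_b.bias is not None)
    with torch.no_grad():
        # Conv2d weight layout: (out_channels, in_channels, kH, kW)
        new_a.weight.copy_(conv_a.weight[keep])        # drop filters of A
        if conv_a.bias is not None:
            new_a.bias.copy_(conv_a.bias[keep])
        new_b.weight.copy_(conv_b.weight[:, keep])     # drop matching inputs of B
        if conv_b.bias is not None:
            new_b.bias.copy_(conv_b.bias)
    return new_a, new_b


torch.manual_seed(0)
a = nn.Conv2d(3, 8, 3, padding=1)
b = nn.Conv2d(8, 4, 3, padding=1)
with torch.no_grad():
    a.weight[2].zero_()  # make filter 2 contribute nothing, so pruning it
    a.bias[2].zero_()    # should leave the pair's output unchanged
keep = [i for i in range(8) if i != 2]
new_a, new_b = prune_conv_pair(a, b, keep)

x = torch.randn(1, 3, 16, 16)
ref = b(torch.relu(a(x)))
out = new_b(torch.relu(new_a(x)))
print(new_a.out_channels, new_b.in_channels)  # 7 7
print(torch.allclose(ref, out, atol=1e-5))    # True
```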

@Vedant_Dalimkar Understood.
What you want has to be coded up yourself, roughly like this:

  1. Create a new nn.Conv2d layer and copy the parameters (weights and biases) from the old layer, excluding the filters you want to prune.
  2. Create new layers before and after this layer that ensure the upstream and downstream layers are not impacted.
  3. Replace the old nn.Conv2d layer with this new nn.Sequential.
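The three steps above could be sketched as follows. Step 2 is open to interpretation. Here it is read as adding a layer after the shrunken conv that scatters the kept channels back to their original positions and fills the pruned ones with zeros, so downstream layers still receive the original channel count. RestoreChannels and shrink_conv are hypothetical names. With this reading the downstream layers keep their full size, so only the pruned conv itself becomes cheaper.

```python
import torch
import torch.nn as nn


class RestoreChannels(nn.Module):
    """Hypothetical helper: put kept channels back at their original indices,
    zero-filling the pruned ones, so the output has `total` channels again."""

    def __init__(self, keep, total: int):
        super().__init__()
        self.register_buffer("keep", torch.as_tensor(keep, dtype=torch.long))
        self.total = total

    def forward(self, x):
        out = x.new_zeros(x.shape[0], self.total, *x.shape[2:])
        out[:, self.keep] = x
        return out


def shrink_conv(conv: nn.Conv2d, keep) -> nn.Sequential:
    # Step 1: a smaller conv holding only the kept filters.
    keep = torch.as_tensor(keep, dtype=torch.long)
    small = nn.Conv2d(conv.in_channels, len(keep), conv.kernel_size,
                      stride=conv.stride, padding=conv.padding,
                      dilation=conv.dilation, bias=conv.bias is not None)
    with torch.no_grad():
        small.weight.copy_(conv.weight[keep])
        if conv.bias is not None:
            small.bias.copy_(conv.bias[keep])
    # Step 2: restore the original channel count for downstream layers.
    return nn.Sequential(small, RestoreChannels(keep, conv.out_channels))


model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                      nn.Conv2d(8, 4, 3, padding=1))
with torch.no_grad():
    model[0].weight[3].zero_()  # filter 3 is the one we will prune
    model[0].bias[3].zero_()
x = torch.randn(2, 3, 8, 8)
ref = model(x)

# Step 3: swap the old conv for the new Sequential.
model[0] = shrink_conv(model[0], [0, 1, 2, 4, 5, 6, 7])
print(torch.allclose(model(x), ref, atol=1e-5))  # True
```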

Off-the-shelf pruning utilities, such as PyTorch's built-in one, mask the weights but do not physically remove the filter, e.g.
https://pytorch.org/docs/stable/generated/torch.nn.utils.prune.ln_structured.html
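To see the masking behaviour concretely: torch.nn.utils.prune.ln_structured zeroes whole filters but leaves the weight shape, and therefore out_channels, unchanged. Even prune.remove only makes the zeros permanent. A small sketch:

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

conv = nn.Conv2d(3, 8, 3)
# Zero the 2 filters (dim=0, i.e. output channels) with the smallest L2 norm (n=2).
prune.ln_structured(conv, name="weight", amount=2, n=2, dim=0)
zeroed = int((conv.weight.flatten(1).abs().sum(dim=1) == 0).sum())
print(conv.weight.shape, zeroed)  # torch.Size([8, 3, 3, 3]) 2

# Make the pruning permanent: the mask is folded into the weight,
# but the layer still has 8 filters.
prune.remove(conv, "weight")
print(conv.weight.shape, conv.out_channels)  # torch.Size([8, 3, 3, 3]) 8
```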

At runtime we don’t know which filter in which layer will be pruned, which is why a dependency graph is a must here.

For example, take the last conv layer in the contracting path of a U-Net. Its output is passed to the next layer and also, through a skip connection, to a layer in the expanding path. Your approach won’t work here, because the pruned layer and all of its possible downstream layers are not contained in a single nn.Sequential.
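The U-Net case can be made concrete with torch.fx. The activation after the encoder conv has two users: the max-pool leading deeper into the network, and the torch.cat of the skip connection. TinyUNet is a hypothetical two-level toy, not a real U-Net.

```python
import torch
import torch.nn as nn
import torch.fx as fx


class TinyUNet(nn.Module):
    """Hypothetical two-level U-Net-like toy with one skip connection."""

    def __init__(self):
        super().__init__()
        self.enc = nn.Conv2d(3, 8, 3, padding=1)
        self.down = nn.MaxPool2d(2)
        self.mid = nn.Conv2d(8, 16, 3, padding=1)
        self.up = nn.ConvTranspose2d(16, 8, 2, stride=2)
        self.dec = nn.Conv2d(16, 4, 3, padding=1)  # 8 (skip) + 8 (upsampled) channels

    def forward(self, x):
        e = torch.relu(self.enc(x))
        m = torch.relu(self.mid(self.down(e)))
        u = self.up(m)
        return self.dec(torch.cat([e, u], dim=1))


gm = fx.symbolic_trace(TinyUNet())
enc_node = next(n for n in gm.graph.nodes if n.target == "enc")
act = next(iter(enc_node.users))         # the ReLU applied to enc's output
print(sorted(u.name for u in act.users))  # ['cat', 'down']
```

Pruning filter i of enc therefore means removing input channel i of mid and input channel i of dec. The latter index is only right because e comes first in the torch.cat, so a general tool would also have to track concatenation offsets.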

@Vedant_Dalimkar Okay. Maybe I'm not understanding why you want to touch layers of the model other than the one you are pruning.

As for seeing the dependency graph, you can refer to this post.

It has several examples using different libraries.
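If the goal is just to inspect the graph, a dependency-free option is to trace with torch.fx and list each node together with the nodes that consume its output. This is an alternative to the libraries in that post, not taken from it. print(gm.graph) also dumps the traced IR. Net is a hypothetical model.

```python
import torch.nn as nn
import torch.fx as fx


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3)
        self.act = nn.ReLU()
        self.conv2 = nn.Conv2d(8, 4, 3)

    def forward(self, x):
        return self.conv2(self.act(self.conv1(x)))


gm = fx.symbolic_trace(Net())
# One (producer, consumer) pair per data dependency in the forward pass.
edges = [(node.name, user.name) for node in gm.graph.nodes for user in node.users]
for src, dst in edges:
    print(f"{src} -> {dst}")
```

For Net this prints x -> conv1, conv1 -> act, act -> conv2 and conv2 -> output. The same edge list could be fed to a graph-drawing tool.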

Happy coding

Read this again and let me know if it’s still unclear.