Hi! I apologize in advance if this question is a duplicate; I couldn't find any similar queries.
Is it possible to detect skip connections in an arbitrary model? For example, suppose a block in the model uses the function below as its forward pass:
```python
def forward(self, x):
    out = self.layers(x)
    return x + out
```
Can I somehow detect that a particular set of blocks has this kind of forward pass with a residual connection, for any given model? I'm open to any type of solution: hooks, graph tracing, or anything else.
My motivation is pruning: if I prune layer_1, whose output is, say, connected to layer_3, I want to identify that connection and prune layer_3 accordingly. Pruning here means channel/filter pruning.
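One graph-based approach is to trace the model with `torch.fx` and look for elementwise-add nodes, which is how residual connections usually appear in the traced graph. Below is a minimal sketch under the assumption that the model is symbolically traceable (the `ResBlock`/`Net` classes are just a toy stand-in for your own model):

```python
import operator

import torch
import torch.nn as nn
import torch.fx as fx

class ResBlock(nn.Module):
    """Toy block with the forward pass from the question: x + self.layers(x)."""
    def __init__(self, ch):
        super().__init__()
        self.layers = nn.Sequential(nn.Conv2d(ch, ch, 3, padding=1), nn.ReLU())

    def forward(self, x):
        out = self.layers(x)
        return x + out

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(3, 8, 3, padding=1)
        self.block1 = ResBlock(8)
        self.block2 = ResBlock(8)

    def forward(self, x):
        return self.block2(self.block1(self.stem(x)))

def find_residual_adds(model):
    """Trace the model with torch.fx and return every graph node that
    performs an elementwise add -- the usual signature of a skip connection."""
    gm = fx.symbolic_trace(model)
    adds = []
    for node in gm.graph.nodes:
        # residual additions typically show up as operator.add / torch.add
        # call_function nodes, or as a Tensor.add method call
        if node.op == "call_function" and node.target in (operator.add, torch.add):
            adds.append(node)
        elif node.op == "call_method" and node.target == "add":
            adds.append(node)
    return adds

adds = find_residual_adds(Net())
# each add node's inputs are the two producers whose channel counts
# must stay in sync when you prune
for node in adds:
    print(node.name, "<-", [str(a) for a in node.all_input_nodes])
```

The caveat is that `fx.symbolic_trace` fails on models with data-dependent control flow, so this will not cover literally every model, but it works across many architectures without any per-model logic.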
While walking `output.grad_fn.next_functions` in autograd,
I do see some `AddBackward` gradient functions, which are presumably residual connections (and even if one is not, I think we should keep the channel dimensions of those tensors in sync anyway, since they are being added together; please correct me if that is wrong).
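For completeness, here is a minimal sketch of that autograd-based idea: walk backwards from the output's `grad_fn` and collect every `AddBackward` node. The toy model below is hypothetical; note that an `AddBackward` can also come from additions that are not skip connections (e.g. explicit bias adds), so this only finds candidates:

```python
import torch
import torch.nn as nn

def find_add_nodes(output):
    """Walk the autograd graph backwards from `output` and collect every
    AddBackward node -- candidate residual connections."""
    seen, adds, stack = set(), [], [output.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        if type(fn).__name__.startswith("AddBackward"):
            adds.append(fn)
        # next_functions holds (grad_fn, input_index) pairs for each input
        stack.extend(next_fn for next_fn, _ in fn.next_functions)
    return adds

# hypothetical toy model with one skip connection
layers = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4))
x = torch.randn(2, 4, requires_grad=True)
out = x + layers(x)

adds = find_add_nodes(out)
print([type(f).__name__ for f in adds])
```

One practical annoyance is mapping a `grad_fn` node back to the `nn.Module` that produced it; that mapping is not stored in the autograd graph, which is part of why this route gets daunting.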
However, implementing this as an algorithm seems quite daunting.
My colleague, who is highly skilled, had quite a lot of trouble writing an algorithm to find these residual connections, and it only works for ResNet (we also tested it on MobileNet, EfficientNet (timm), and so on).
I would love to hear about any ideas or existing code for this.
It seems like knowing where the residual connections exist in the first place is the best starting point.
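Once the add nodes are located, the pruning-relevant question is which layers feed each operand, since those layers' output channels must be pruned together. A possible sketch, again using `torch.fx` on a toy model (the `ResBlock`/`Net` classes and `conv_producers` helper are illustrative, not from any library):

```python
import operator

import torch
import torch.nn as nn
import torch.fx as fx

class ResBlock(nn.Module):
    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, 3, padding=1)

    def forward(self, x):
        return x + self.conv(x)

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(3, 8, 3, padding=1)
        self.block = ResBlock(8)

    def forward(self, x):
        return self.block(self.stem(x))

def conv_producers(node, gm, seen=None):
    """Walk backwards from `node` until hitting call_module nodes whose
    module is a Conv2d; those convs' output channels feed this tensor."""
    if seen is None:
        seen = set()
    if node in seen:
        return []
    seen.add(node)
    if node.op == "call_module" and isinstance(gm.get_submodule(node.target), nn.Conv2d):
        return [node.target]
    found = []
    for inp in node.all_input_nodes:
        found.extend(conv_producers(inp, gm, seen))
    return found

gm = fx.symbolic_trace(Net())
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is operator.add:
        # convs feeding either side of the add must keep matching channel counts
        group = sorted({c for arg in node.all_input_nodes
                          for c in conv_producers(arg, gm)})
        print(group)  # ['block.conv', 'stem']
```

The grouped layers form one "pruning group": remove filter k from every conv in the group and the addition stays shape-consistent.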