Detecting residual connections

Hi! I apologize in advance if this question has been asked before; I couldn't find any similar queries.

Is it possible to detect skip connections in an arbitrary model? For example, a block in the model might use the following function as its forward pass:

def forward(self, x):
    out = self.layers(x)
    return x + out
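For concreteness, here is a minimal, self-contained block with that forward pass. The `ResidualBlock` name and the conv/batch-norm/ReLU stack inside `layers` are illustrative assumptions; the only real requirement is that `layers` preserves the tensor shape so that `x + out` is valid.

```python
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    """Hypothetical block matching the forward pass above."""

    def __init__(self, channels: int):
        super().__init__()
        # padding=1 with a 3x3 kernel keeps the spatial size unchanged,
        # and in/out channels are equal, so `layers` preserves the shape
        self.layers = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
        )

    def forward(self, x):
        out = self.layers(x)
        return x + out  # skip connection: the input is added back to the output


block = ResidualBlock(8)
y = block(torch.randn(2, 8, 16, 16))
print(y.shape)  # torch.Size([2, 8, 16, 16])
```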

Can I somehow detect which blocks in a given model have this kind of forward pass with a residual connection? I'm open to any type of solution: hooks, the graph, or anything else.

TIA

Hi,

In general I can’t think of a way to do this.
Can you give a bit more context on why you want to do that?

I was looking for a solution for pruning. If I prune layer_1, whose output is, say, connected to layer_3 through a skip connection, I want to identify this connection and prune layer_3 accordingly. Pruning here means channel/filter pruning.
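To make the pruning constraint concrete, here is a small sketch; the tensor sizes and the `keep` indices are hypothetical. Pruning only the output channels of the layer on one side of the addition breaks the add, while keeping the same channel indices on both operands keeps it valid. For simplicity, "layer_3" still receives the unpruned input here; in a real network its input channels would have to be pruned as well.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(1, 8, 4, 4)               # output of "layer_1": 8 channels
conv = nn.Conv2d(8, 8, 3, padding=1)      # "layer_3", whose output is added to x
keep = torch.tensor([0, 1, 2, 4, 5, 7])   # hypothetical channels that survive pruning

# Prune layer_3's output channels only, by copying the kept filters.
pruned = nn.Conv2d(8, len(keep), 3, padding=1)
with torch.no_grad():
    pruned.weight.copy_(conv.weight[keep])
    pruned.bias.copy_(conv.bias[keep])

try:
    x + pruned(x)                         # 8 channels + 6 channels
except RuntimeError as err:
    print("residual add fails:", err)

# Keeping the same channel indices on the skip path restores a valid add,
# and it matches the unpruned block on the surviving channels.
y = x[:, keep] + pruned(x)
print(torch.allclose(y, (x + conv(x))[:, keep], atol=1e-6))  # True
```

In other words, every layer whose output reaches the same addition forms one group that has to be pruned with identical channel indices.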

A rough way to detect a residual connection could be:

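import torch

def forward(self, x):
    # keep a reference to the block's output so it can be inspected afterwards
    self.out = self.layers(x)
    return x + self.out

outputs = model(features)
# with a skip connection, outputs = features + model.out,
# so subtracting the stored output should recover the input
if torch.allclose(outputs - model.out, features):
    print('there is a residual connection')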

Again, this is not a recommended approach, but it should work for this particular example.
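The same check can be done without editing `forward`, by registering forward hooks. This is a minimal sketch that assumes, as in the question, that each candidate block exposes its inner path as a `layers` attribute; `find_residual_blocks`, `ResBlock` and `PlainBlock` are hypothetical names. Because the hook re-runs `layers`, the sketch puts the model in eval mode (and leaves it there) and runs under `torch.no_grad()` to avoid dropout noise and batch-norm statistic updates.

```python
import torch
import torch.nn as nn


class ResBlock(nn.Module):
    def __init__(self, c):
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(c, c), nn.ReLU(), nn.Linear(c, c))

    def forward(self, x):
        return x + self.layers(x)


class PlainBlock(ResBlock):
    def forward(self, x):
        return self.layers(x)  # same layers, no skip connection


def find_residual_blocks(model, example_input):
    """Names of submodules with a `layers` child whose output equals
    input + layers(input). Heuristic; assumes `layers` preserves the shape."""
    found, handles = [], []

    def make_hook(name):
        def hook(module, inputs, output):
            x = inputs[0]
            # output - layers(x) recovers x only if the block adds x back
            if torch.allclose(output - module.layers(x), x, atol=1e-6):
                found.append(name)
        return hook

    for name, module in model.named_modules():
        if isinstance(getattr(module, "layers", None), nn.Module):
            handles.append(module.register_forward_hook(make_hook(name)))

    model.eval()
    with torch.no_grad():
        model(example_input)
    for handle in handles:
        handle.remove()
    return found


model = nn.Sequential(ResBlock(4), PlainBlock(4), ResBlock(4))
print(find_residual_blocks(model, torch.randn(3, 4)))  # ['0', '2']
```

This still only works when blocks follow a known naming convention, which is why it does not generalize to arbitrary models.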

While looking at output.grad_fn.next_functions in autograd, I do see some AddBackward0 nodes, which are presumably the residual connections. (Even if they are not, I think the two operands must keep matching channel dimensions, since they are added elementwise. Please correct me if I'm wrong.)
However, turning this into an algorithm seems quite daunting.
A highly skilled colleague of mine had a lot of trouble writing an algorithm to find these residual connections, and it only works for ResNet (we also tested it on MobileNet, EfficientNet (timm), and so on).
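Walking `grad_fn.next_functions` can be sketched as a depth-first traversal that counts node types. Note that an `AddBackward0` node is a necessary but not sufficient sign of a residual connection, since any tensor addition produces one, and the autograd graph does not record which module created each node. `collect_grad_fn_names` is a hypothetical helper.

```python
import torch
import torch.nn as nn


def collect_grad_fn_names(output):
    """Walk the autograd graph backwards from `output`, counting node types."""
    counts = {}
    seen = set()
    stack = [output.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:  # None marks inputs that need no grad
            continue
        seen.add(fn)
        name = type(fn).__name__
        counts[name] = counts.get(name, 0) + 1
        # next_functions is a tuple of (node, input_index) pairs
        stack.extend(next_fn for next_fn, _ in fn.next_functions)
    return counts


layers = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4))
x = torch.randn(2, 4, requires_grad=True)
print("AddBackward0" in collect_grad_fn_names(x + layers(x)))  # True
print("AddBackward0" in collect_grad_fn_names(layers(x)))      # False
```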

I would love to see any ideas or existing code for this.
Knowing where the residual connections are in the first place seems like the best starting point.
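One model-agnostic heuristic, assuming the model can be traced with `torch.fx.symbolic_trace` (models with data-dependent control flow cannot), is to look for additions where one operand is an ancestor of the other in the traced graph, which is exactly the `x + layers(x)` pattern. This is a sketch rather than a full pruning dependency analysis; `find_residual_adds` and `Block` are hypothetical names. It matches `operator.add`, `operator.iadd` (for in-place `out += identity`, as torchvision's ResNet writes it), `torch.add`, and the `Tensor.add` method.

```python
import operator

import torch
import torch.fx as fx
import torch.nn as nn

ADD_FUNCTIONS = (operator.add, operator.iadd, torch.add)


def _is_add(node: fx.Node) -> bool:
    if node.op == "call_function":
        return node.target in ADD_FUNCTIONS
    return node.op == "call_method" and node.target == "add"


def find_residual_adds(model: nn.Module):
    """Return (add, skip, branch) node triples for every addition in which
    one operand (the skip) is an ancestor of the other (the branch)."""
    graph = fx.symbolic_trace(model).graph

    # ancestors[n] = all nodes that n transitively depends on;
    # graph.nodes is in topological order, so inputs are processed first
    ancestors = {}
    for node in graph.nodes:
        deps = set()
        for inp in node.all_input_nodes:
            deps.add(inp)
            deps |= ancestors[inp]
        ancestors[node] = deps

    residuals = []
    for node in graph.nodes:
        if not _is_add(node):
            continue
        operands = [a for a in node.args[:2] if isinstance(a, fx.Node)]
        if len(operands) != 2:
            continue  # e.g. adding a constant
        a, b = operands
        if a in ancestors[b]:
            residuals.append((node, a, b))
        elif b in ancestors[a]:
            residuals.append((node, b, a))
    return residuals


class Block(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.layers = nn.Sequential(nn.Conv2d(channels, channels, 3, padding=1), nn.ReLU())

    def forward(self, x):
        out = self.layers(x)
        return x + out


model = nn.Sequential(nn.Conv2d(3, 8, 1), Block(8), Block(8))
for add, skip, branch in find_residual_adds(model):
    print(f"{add.name}: skip={skip.name}, branch={branch.name}")
```

For channel pruning, the layers producing the `skip` and `branch` operands of each reported add would then be pruned with the same channel indices.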
