How to automatically detect residual connections in a model and apply a user-defined function to them?

I’m currently working on a project where I have a model that includes residual connections, and I’d like to change the behavior of these connections with a custom function. Specifically, I want to replace the usual res = res + conv(x) with res = f(res) + conv(x), where f is a user-defined function.

Here’s the current implementation of my model:

```python
import torch.nn as nn


class CONV(nn.Module):

    def __init__(self, input_channels, output_channels):
        super(CONV, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, 32, 3, bias=False)
        self.batch_norm1 = nn.BatchNorm2d(32)

        self.conv2 = nn.Conv2d(32, 32, 3, bias=False)
        # padding=2 grows the map by 2 pixels, undoing conv2's shrink,
        # so conv3's output matches x_res in the residual add below
        self.conv3 = nn.Conv2d(32, 32, 3, bias=False, padding=2)
        self.maxpool = nn.MaxPool2d(3, stride=2)
        self.conv4 = nn.Conv2d(32, output_channels, 3, bias=False)

    def forward(self, x):
        x = self.conv1(x)
        x_res = x  # skip branch of the residual connection

        x = self.batch_norm1(x)
        x = nn.functional.relu(x)

        x = self.conv2(x)
        x = nn.functional.relu(x)

        x = x_res + self.conv3(x)  # residual connection: res + conv(x)
        x = nn.functional.relu(x)

        x = self.maxpool(x)
        x = self.conv4(x)

        return x
```

After instantiation:

```python
model = CONV(3, 3)
```

I’d like to write a function residual(model, f) that modifies the residual connections in the model after it has been created (by modifying named_modules …), because f evolves over time and changes after each epoch. How can I do this?
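One reason named_modules alone cannot do this is that a residual addition such as x_res + self.conv3(x) is an operation inside forward(), not a submodule, so it never appears in that list. The quick check below illustrates this on a hypothetical two-line block (TinyRes is an assumption, not the model above) and uses torch.fx.symbolic_trace to show that the addition does become visible as an operator.add node in the traced graph, assuming the model is traceable by FX.

```python
import operator

import torch.fx as fx
import torch.nn as nn


class TinyRes(nn.Module):
    """Hypothetical stand-in for a model with one residual connection."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(4, 4, 3, padding=1)

    def forward(self, x):
        return x + self.conv(x)  # the residual add lives here, in forward()


model = TinyRes()

# named_modules() only lists submodules; the "+" is not one of them
print([name for name, _ in model.named_modules()])  # ['', 'conv']

# After tracing, the addition shows up as a call_function node
gm = fx.symbolic_trace(model)
adds = [n for n in gm.graph.nodes
        if n.op == "call_function" and n.target is operator.add]
for n in adds:
    print(n.name, [str(a) for a in n.args])
```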

Any help would be greatly appreciated. Thank you!

Why not just build the model with x = f(x_res) + self.conv3(x)? I don’t understand the use case. You could even rewrite your model’s __init__ to take a boolean flag use_f and then write:

```python
if self.use_f:
    x = f(x_res) + self.conv3(x)
else:
    x = x_res + self.conv3(x)
```

Then initialize two models for comparison:

```python
model = CONV(3, 3, False)
model_f = CONV(3, 3, True)
```
Is there some reason this simpler idea won’t work for you?
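If you do control the class, the flag idea can also cover an f that changes every epoch: store f as an attribute of the module and reassign it between epochs. A minimal sketch with a hypothetical ResBlock (not the CONV model above), assuming f is a plain tensor-to-tensor callable; the per-epoch scaling schedule is only an illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResBlock(nn.Module):
    """Hypothetical block whose skip branch goes through a swappable self.f."""

    def __init__(self, channels, f=None):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        # Identity by default, which gives the usual res + conv(x)
        self.f = f if f is not None else (lambda t: t)

    def forward(self, x):
        res = x
        return F.relu(self.f(res) + self.conv(x))


model = ResBlock(4)
for epoch in range(3):
    # Hypothetical schedule: scale the skip branch by 1 / (epoch + 1);
    # the default argument freezes the current epoch's value
    model.f = lambda t, s=1.0 / (epoch + 1): s * t
    # ... the training loop for this epoch would go here ...
```

Because f is a plain attribute rather than part of the architecture, reassigning it takes effect on the next forward call without rebuilding the model.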

@absolved I want to apply the function dynamically after the model is created, because f evolves over time (it changes after each epoch), and I’m not the one who writes the model class: I use a model provided by a user.

@ptrblck is it feasible?
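One possible approach, assuming the user’s model can be traced with torch.fx: trace it, look for add nodes whose two inputs are both graph values, and treat one as a skip branch when it is an ancestor of the other (for x_res + self.conv3(x), x_res feeds into conv3’s input). That skip input is then routed through a small submodule that holds f. This is a heuristic sketch, not an established API; ResidualFn, the submodule name residual_fn, and the ancestor rule are all hypothetical choices.

```python
import operator

import torch
import torch.fx as fx
import torch.nn as nn

# Functions FX records for "a + b", "a += b" or torch.add(a, b)
ADD_TARGETS = {operator.add, operator.iadd, torch.add}


class ResidualFn(nn.Module):
    """Hypothetical holder for the user-defined f; reassign .f to change it."""

    def __init__(self, f=None):
        super().__init__()
        # f is assumed to be a plain callable; identity keeps res + conv(x)
        self.f = f if f is not None else (lambda t: t)

    def forward(self, x):
        return self.f(x)


def _ancestors(node):
    """All graph nodes that `node` transitively depends on."""
    seen, stack = set(), list(node.all_input_nodes)
    while stack:
        n = stack.pop()
        if n not in seen:
            seen.add(n)
            stack.extend(n.all_input_nodes)
    return seen


def residual(model, f):
    """Return a GraphModule (reusing the model's submodules) in which each
    detected skip input `res` of `res + branch` becomes `residual_fn(res)`."""
    gm = fx.symbolic_trace(model)
    gm.add_submodule("residual_fn", ResidualFn(f))
    graph = gm.graph

    for node in list(graph.nodes):
        if node.op != "call_function" or node.target not in ADD_TARGETS:
            continue
        if len(node.args) < 2:
            continue
        a, b = node.args[0], node.args[1]
        if not (isinstance(a, fx.Node) and isinstance(b, fx.Node)):
            continue  # e.g. x + 1 is not a residual connection
        # Heuristic: the skip input is the operand the other one was computed from
        if a in _ancestors(b):
            res = a
        elif b in _ancestors(a):
            res = b
        else:
            continue
        with graph.inserting_before(node):
            wrapped = graph.call_module("residual_fn", args=(res,))
        # Only this add sees f(res); other users of res are unchanged
        node.args = tuple(wrapped if arg is res else arg for arg in node.args)

    graph.lint()
    gm.recompile()
    return gm
```

Usage would be along the lines of gm = residual(CONV(3, 3), f), then gm.residual_fn.f = new_f after each epoch; gm is the module to train. Models with data-dependent control flow in forward() cannot be traced by symbolic_trace, and any addition of two related tensors is treated as a residual, so printing gm.code is a simple way to check what was rewritten.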