I’m currently working on a project where I have a model that includes residual connections, and I’d like to customize the behavior of these connections using a custom function. Specifically, I want to replace the usual `res = res + conv(x)` operation with `res = f(res) + conv(x)`, where `f` is a user-defined function.

Here’s the current implementation of my model:

```python
import torch.nn as nn


class CONV(nn.Module):
    def __init__(self, input_channels, output_channels):
        super(CONV, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, 32, 3, bias=False)
        self.batch_norm1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 32, 3, bias=False)
        # padding=2 undoes the size reduction of conv2, so conv3's output matches x_res
        self.conv3 = nn.Conv2d(32, 32, 3, bias=False, padding=2)
        self.maxpool = nn.MaxPool2d(3, stride=2)
        self.conv4 = nn.Conv2d(32, output_channels, 3, bias=False)

    def forward(self, x):
        x = self.conv1(x)
        x_res = x  # skip branch: the conv1 output
        x = self.batch_norm1(x)
        x = nn.functional.relu(x)
        x = self.conv2(x)
        x = nn.functional.relu(x)
        x = x_res + self.conv3(x)  # residual connection I want to turn into f(x_res) + conv3(x)
        x = nn.functional.relu(x)
        x = self.maxpool(x)
        x = self.conv4(x)
        return x
```
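The `x_res + self.conv3(x)` line only works because both tensors have the same shape: `conv2` (3x3 kernel, no padding) trims one pixel from each border, and `conv3` (`padding=2`) adds it back. Whatever `f` returns therefore has to keep the shape of `x_res`. A quick shape check, using only the two layers involved and an arbitrary 30x30 feature map standing in for the `conv1` output of a 32x32 input (batch norm and ReLU are omitted because they do not change the shape):

```python
import torch
import torch.nn as nn

conv2 = nn.Conv2d(32, 32, 3, bias=False)             # no padding: H and W shrink by 2
conv3 = nn.Conv2d(32, 32, 3, bias=False, padding=2)  # padding=2: H and W grow back by 2

x_res = torch.randn(1, 32, 30, 30)  # stand-in for conv1's output on a 1x3x32x32 input
branch = conv3(conv2(x_res))
print(tuple(x_res.shape), tuple(branch.shape))  # (1, 32, 30, 30) (1, 32, 30, 30)
```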

After instantiating the model:

`model = CONV(3, 3)`

I’d like to write a function `residual(model, f)` that modifies the residual connections in the model after it has been created (by modifying `named_modules` …), because the function `f` evolves over time, changing after each epoch. How can I do this?

Any help would be greatly appreciated. Thank you!
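One possible approach, sketched under assumptions: `named_modules()` only yields `nn.Module` instances, and the `+` inside `forward` is not a module, so in the original `CONV` there is nothing for `residual(model, f)` to find. If the skip connection is moved into a small module that stores `f` as a plain attribute, `residual` can walk `named_modules()` and swap `f` in place between epochs. `Residual`, `res1`, and the per-epoch scaling schedule are hypothetical names chosen for illustration, not PyTorch APIs, and the sketch assumes `f` is a plain Python callable that returns a tensor with the same shape as its input.

```python
import torch
import torch.nn as nn


class Residual(nn.Module):
    """Hypothetical helper: computes f(res) + branch with a swappable f."""

    def __init__(self, f=None):
        super().__init__()
        # Identity by default, which reproduces the usual res + conv(x).
        self.f = f if f is not None else (lambda t: t)

    def forward(self, res, branch):
        return self.f(res) + branch


class CONV(nn.Module):
    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, 32, 3, bias=False)
        self.batch_norm1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 32, 3, bias=False)
        self.conv3 = nn.Conv2d(32, 32, 3, bias=False, padding=2)
        self.res1 = Residual()  # the skip connection is now a named submodule
        self.maxpool = nn.MaxPool2d(3, stride=2)
        self.conv4 = nn.Conv2d(32, output_channels, 3, bias=False)

    def forward(self, x):
        x = self.conv1(x)
        x_res = x
        x = nn.functional.relu(self.batch_norm1(x))
        x = nn.functional.relu(self.conv2(x))
        x = self.res1(x_res, self.conv3(x))  # f(x_res) + conv3(x)
        x = nn.functional.relu(x)
        x = self.maxpool(x)
        return self.conv4(x)


def residual(model, f):
    """Set f on every Residual submodule and return the names that were updated."""
    updated = []
    for name, module in model.named_modules():
        if isinstance(module, Residual):
            module.f = f
            updated.append(name)
    return updated


model = CONV(3, 3)
for epoch in range(3):
    scale = 1.0 / (epoch + 1)  # hypothetical per-epoch schedule for f
    # The default argument binds the current scale (avoids late-binding closures).
    residual(model, lambda t, s=scale: s * t)
    # ... train for one epoch here ...

with torch.no_grad():
    print(model(torch.randn(1, 3, 32, 32)).shape)  # torch.Size([1, 3, 12, 12])
```

Because `residual` only rebinds an attribute, it can be called between epochs without rebuilding the model or the optimizer. If `f` were instead an `nn.Module` with learnable parameters, assigning it would register it as a submodule, so its parameters would need to be added to the optimizer, and PyTorch raises a `TypeError` if a plain function is later assigned to that same attribute; keeping `f` consistently one kind of object avoids this.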