I’m currently working on a project where I have a model that includes residual connections, and I’d like to customize the behavior of these connections using a custom function. Specifically, I want to replace the usual `res = res + conv(x)` operation with `res = f(res) + conv(x)`, where `f` is a user-defined function.
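Here is a minimal tensor-level illustration of the two variants. It assumes `f` is shape-preserving, because `f(res)` must have the same shape as `conv(x)` for the addition to work. The `f`, the padded `conv` and the tensor sizes below are hypothetical stand-ins, not taken from the model.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins: a 3x3 conv with padding=1 keeps the spatial size,
# so conv(x) has the same shape as res and the two can be added.
conv = nn.Conv2d(32, 32, 3, padding=1, bias=False)
res = torch.randn(1, 32, 8, 8)
x = torch.randn(1, 32, 8, 8)


def f(t):
    # Hypothetical user-defined f; any shape-preserving function works here.
    return 0.5 * torch.tanh(t)


standard = res + conv(x)   # usual residual connection
custom = f(res) + conv(x)  # residual with f applied to the skip branch
print(standard.shape, custom.shape)
```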
Here’s the current implementation of my model:
```python
import torch
import torch.nn as nn


class CONV(nn.Module):
    def __init__(self, input_channels, output_channels):
        super(CONV, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, 32, 3, bias=False)
        self.batch_norm1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 32, 3, bias=False)
        # padding=2 restores the 2 pixels lost in conv2, so the residual add lines up
        self.conv3 = nn.Conv2d(32, 32, 3, bias=False, padding=2)
        self.maxpool = nn.MaxPool2d(3, stride=2)
        self.conv4 = nn.Conv2d(32, output_channels, 3, bias=False)

    def forward(self, x):
        x = self.conv1(x)
        x_res = x  # skip branch
        x = self.batch_norm1(x)
        x = nn.functional.relu(x)
        x = self.conv2(x)
        x = nn.functional.relu(x)
        x = x_res + self.conv3(x)  # residual connection: res + conv(x)
        x = nn.functional.relu(x)
        x = self.maxpool(x)
        x = self.conv4(x)
        return x
```
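As written, the skip connection is a plain `+` on tensors inside `forward`, not a submodule, so it has no entry in `named_modules()`. The quick check below confirms this and also verifies that the shapes line up. The 32x32 RGB input size is an assumption chosen for illustration.

```python
import torch
import torch.nn as nn


class CONV(nn.Module):
    # Same model as above, reformatted.
    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, 32, 3, bias=False)
        self.batch_norm1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 32, 3, bias=False)
        self.conv3 = nn.Conv2d(32, 32, 3, bias=False, padding=2)
        self.maxpool = nn.MaxPool2d(3, stride=2)
        self.conv4 = nn.Conv2d(32, output_channels, 3, bias=False)

    def forward(self, x):
        x = self.conv1(x)
        x_res = x
        x = nn.functional.relu(self.batch_norm1(x))
        x = nn.functional.relu(self.conv2(x))
        x = x_res + self.conv3(x)
        x = self.maxpool(nn.functional.relu(x))
        return self.conv4(x)


model = CONV(3, 3)
names = [name for name, _ in model.named_modules()]
print(names)  # ['', 'conv1', 'batch_norm1', 'conv2', 'conv3', 'maxpool', 'conv4']

# 32 -> 30 (conv1) -> 28 (conv2) -> 30 (conv3, padding=2) -> 14 (maxpool) -> 12 (conv4)
out = model(torch.randn(1, 3, 32, 32))
print(out.shape)  # torch.Size([1, 3, 12, 12])
```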
The model is then instantiated with:

```python
model = CONV(3, 3)
```
I’d like to write a function `residual(model, f)` that modifies the residual connections of the model after it has been created (by modifying `named_modules` …), because the function `f` evolves over time and changes after each epoch. How can I do this?
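One possible approach, sketched under stated assumptions, uses `torch.fx` to trace the already-created model. It then routes the skip branch of each addition through a small submodule that holds the current `f`. `ResidualFn`, `residual` and the per-epoch schedule are hypothetical names chosen for illustration. The matcher rewrites every `operator.add` node, which here is only the residual, so a model with other additions would need a stricter match.

```python
import operator

import torch
import torch.fx as fx
import torch.nn as nn


class CONV(nn.Module):
    # The model from the question, unchanged apart from formatting.
    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, 32, 3, bias=False)
        self.batch_norm1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 32, 3, bias=False)
        self.conv3 = nn.Conv2d(32, 32, 3, bias=False, padding=2)
        self.maxpool = nn.MaxPool2d(3, stride=2)
        self.conv4 = nn.Conv2d(32, output_channels, 3, bias=False)

    def forward(self, x):
        x = self.conv1(x)
        x_res = x
        x = nn.functional.relu(self.batch_norm1(x))
        x = nn.functional.relu(self.conv2(x))
        x = x_res + self.conv3(x)
        x = self.maxpool(nn.functional.relu(x))
        return self.conv4(x)


class ResidualFn(nn.Module):
    """Hypothetical holder for the current f; reassign `self.f` to change it."""

    def __init__(self, f):
        super().__init__()
        self.f = f

    def forward(self, t):
        return self.f(t)


def residual(model, f):
    """Hypothetical helper: trace `model` and rewrite every `a + b` as `res_fn(a) + b`."""
    gm = fx.symbolic_trace(model)
    gm.add_module("res_fn", ResidualFn(f))  # now visible in gm.named_modules()
    for node in list(gm.graph.nodes):
        # `x_res + self.conv3(x)` is recorded as a call_function node targeting operator.add
        if node.op == "call_function" and node.target is operator.add:
            skip, branch = node.args
            with gm.graph.inserting_before(node):
                new_skip = gm.graph.call_module("res_fn", args=(skip,))
            node.args = (new_skip, branch)
    gm.graph.lint()
    gm.recompile()  # regenerate forward() from the edited graph
    return gm


model = CONV(3, 3)
gm = residual(model, lambda t: t)  # identity f reproduces the original model
for epoch in range(3):
    # Hypothetical schedule: f changes after every epoch without re-tracing.
    alpha = 1.0 / (epoch + 1)
    gm.res_fn.f = lambda t, a=alpha: a * t
    # ... train `gm` for one epoch here ...
```

Because `f` lives on the `res_fn` submodule, changing it after each epoch is a plain attribute assignment rather than a re-trace. Training should then use the returned `gm`, since only its `forward` contains the rewritten residual.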
Any help would be greatly appreciated. Thank you!