How to concatenate two layers using self.add_module?

Hello all, I have a network architecture as follows:


input --> conv1 (3,1,1) --> bn --> relu --> conv2 (3,1,1)
                  |                               ^
                  |-------------------------------|

where conv1(3,1,1) means kernel size is 3, stride 1 and padding 1

The output of conv1 will be concatenated with the output of conv2. It is easy to concatenate them with the torch.cat function. However, I am building the network using self.add_module. How could I apply the concatenation in this case? Here is my sample code:

self.add_module('conv1', nn.Conv3d(32, 64, kernel_size=3, stride=1, padding=1))
self.add_module('conv1_norm', nn.BatchNorm3d(64))
self.add_module('conv1_relu', nn.ReLU(inplace=True))
self.add_module('conv2', nn.Conv3d(64, 128, kernel_size=3, stride=1, padding=1))
# Concatenate ???

Thanks

Hi,

I guess you are doing this inside a Sequential module? A Sequential module can only perform sequential operations (it cannot split into two branches and merge the results). You will need to implement the forward function for this yourself, I'm afraid.
For example, you can wrap these four ops in a custom nn.Module whose forward concatenates the outputs, and then add this new module to your Sequential one.


Yes. Could you give me some example code? I am writing it in the __init__ function of the class.

I guess it's like a ResNet block:

import torch
import torch.nn as nn

class ConcatMod(nn.Module):
    def __init__(self, args):
        super().__init__()
        # Use args properly to set the conv params here, depending on your application.
        # padding=1 preserves the spatial size, so out1 and out2 can be concatenated.
        self.conv1 = nn.Conv3d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv1_norm = nn.BatchNorm3d(64)
        self.conv1_relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(64, 128, kernel_size=3, stride=1, padding=1)

    def forward(self, input):
        out1 = self.conv1(input)
        out1_renormed = self.conv1_relu(self.conv1_norm(out1))
        out2 = self.conv2(out1_renormed)
        # Concatenate along the channel dimension (64 + 128 = 192 channels)
        output = torch.cat([out1, out2], 1)
        return output

# In your Sequential, you can do
self.add_module("concat1", ConcatMod(args))

Thanks. I think I have to rewrite the class instead of using self.add_module.
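
Note that you don't strictly have to give up self.add_module: modules registered with it remain accessible as attributes, so a custom forward can still branch and concatenate. A minimal sketch (the class name ConcatBlock and the fixed channel sizes are just illustrative assumptions):

```python
import torch
import torch.nn as nn

class ConcatBlock(nn.Module):
    def __init__(self):
        super().__init__()
        # add_module registers each layer under the given name;
        # it is then reachable as self.<name> in forward.
        self.add_module('conv1', nn.Conv3d(32, 64, kernel_size=3, stride=1, padding=1))
        self.add_module('conv1_norm', nn.BatchNorm3d(64))
        self.add_module('conv1_relu', nn.ReLU(inplace=True))
        self.add_module('conv2', nn.Conv3d(64, 128, kernel_size=3, stride=1, padding=1))

    def forward(self, x):
        out1 = self.conv1(x)
        out2 = self.conv2(self.conv1_relu(self.conv1_norm(out1)))
        # padding=1 keeps spatial sizes equal, so channel-wise cat works
        return torch.cat([out1, out2], dim=1)
```

So the choice between plain attribute assignment and add_module is mostly cosmetic; either way, a non-sequential topology needs its own forward.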