My model summary is not in the order I anticipated

Hello, I found that when I run this simple model, the output from torchsummary shows doubled BatchNorm2d and doubled ReLU layers after the two convolution layers. Can anyone help me fix the order so it is conv, batch normalization, ReLU? Thanks.

import torch
import torch.nn as nn
from torchsummary import summary

class ConvBlock(nn.Module):
    def __init__(self, in_channels, repeat_conv_num=2, transpose=False, **kwargs):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels, **kwargs) \
            if not transpose else nn.ConvTranspose2d(in_channels, in_channels, **kwargs)
        self.conv2 = nn.Conv2d(in_channels, in_channels * 2, kernel_size=4, stride=2, padding=1) \
            if not transpose else nn.ConvTranspose2d(in_channels, int(in_channels / 2), kernel_size=4, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.bn2 = nn.BatchNorm2d(in_channels*2) if not transpose else nn.BatchNorm2d(int(in_channels/2))
        self.relu = nn.ReLU(inplace=True)
        self.block = nn.ModuleList([])
        for _ in range(repeat_conv_num):
            self.block.extend([self.conv1, self.bn1, self.relu])
        self.block.extend([self.conv2, self.bn2, self.relu])


    def forward(self, x):
        for layer in self.block:
            print(layer)
            x = layer(x)
        return x


class AutoEncoder(nn.Module):
    def __init__(self, in_channels=1, out_channels=1, features=32, **kwargs):
        super(AutoEncoder, self).__init__()
        self.init_layer = nn.Conv2d(in_channels, features, kernel_size=3, stride=1, padding=1)
        self.last_layer = nn.Conv2d(features, out_channels, kernel_size=3, stride=1, padding=1)
        self.model = nn.Sequential(
            self.init_layer,
            nn.BatchNorm2d(features),
            nn.ReLU(inplace=True),
            ConvBlock(in_channels=features, **kwargs),
            ConvBlock(in_channels=features*2, **kwargs),
            ConvBlock(in_channels=features*4, **kwargs),
            ConvBlock(in_channels=features*8, **kwargs),
            ConvBlock(in_channels=features*16, transpose=True, **kwargs),
            ConvBlock(in_channels=features * 8, transpose=True, **kwargs),
            ConvBlock(in_channels=features * 4, transpose=True, **kwargs),
            ConvBlock(in_channels=features * 2, transpose=True, **kwargs),
            self.last_layer,
            nn.Sigmoid()
        )


    def forward(self, x):
        return self.model(x)


if __name__ == "__main__":
    model = AutoEncoder(in_channels=1, out_channels=1, features=32, kernel_size=3, stride=1, padding=1, repeat_conv_num=2).to('cuda')
    x = torch.randn((2, 1, 64, 64)).to('cuda')
    summary(model, input_size=(1, 64, 64))
    print(model(x).shape)

    model2 = ConvBlock(in_channels=32, kernel_size=3, stride=1, padding=1, repeat_conv_num=2).to('cuda')
    x = torch.randn((2, 32, 64, 64)).to('cuda')
    summary(model2, input_size=(32, 64, 64))
    print(model2(x).shape)

result:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 32, 64, 64]             320
            Conv2d-2           [-1, 32, 64, 64]             320
       BatchNorm2d-3           [-1, 32, 64, 64]              64
              ReLU-4           [-1, 32, 64, 64]               0
            Conv2d-5           [-1, 32, 64, 64]           9,248
            Conv2d-6           [-1, 32, 64, 64]           9,248
       BatchNorm2d-7           [-1, 32, 64, 64]              64
       BatchNorm2d-8           [-1, 32, 64, 64]              64
              ReLU-9           [-1, 32, 64, 64]               0
             ReLU-10           [-1, 32, 64, 64]               0
           Conv2d-11           [-1, 32, 64, 64]           9,248
           Conv2d-12           [-1, 32, 64, 64]           9,248
      BatchNorm2d-13           [-1, 32, 64, 64]              64
      BatchNorm2d-14           [-1, 32, 64, 64]              64
             ReLU-15           [-1, 32, 64, 64]               0
             ReLU-16           [-1, 32, 64, 64]               0
           Conv2d-17           [-1, 64, 32, 32]          32,832
           Conv2d-18           [-1, 64, 32, 32]          32,832
      BatchNorm2d-19           [-1, 64, 32, 32]             128
      BatchNorm2d-20           [-1, 64, 32, 32]             128
             ReLU-21           [-1, 64, 32, 32]               0
             ReLU-22           [-1, 64, 32, 32]               0
        ConvBlock-23           [-1, 64, 32, 32]               0
           Conv2d-24           [-1, 64, 32, 32]          36,928
           Conv2d-25           [-1, 64, 32, 32]          36,928
      BatchNorm2d-26           [-1, 64, 32, 32]             128
      BatchNorm2d-27           [-1, 64, 32, 32]             128
             ReLU-28           [-1, 64, 32, 32]               0
             ReLU-29           [-1, 64, 32, 32]               0
           Conv2d-30           [-1, 64, 32, 32]          36,928
           Conv2d-31           [-1, 64, 32, 32]          36,928
      BatchNorm2d-32           [-1, 64, 32, 32]             128
      BatchNorm2d-33           [-1, 64, 32, 32]             128
             ReLU-34           [-1, 64, 32, 32]               0
             ReLU-35           [-1, 64, 32, 32]               0
           Conv2d-36          [-1, 128, 16, 16]         131,200
           Conv2d-37          [-1, 128, 16, 16]         131,200
      BatchNorm2d-38          [-1, 128, 16, 16]             256
      BatchNorm2d-39          [-1, 128, 16, 16]             256
             ReLU-40          [-1, 128, 16, 16]               0
             ReLU-41          [-1, 128, 16, 16]               0
        ConvBlock-42          [-1, 128, 16, 16]               0
           Conv2d-43          [-1, 128, 16, 16]         147,584
           Conv2d-44          [-1, 128, 16, 16]         147,584
      BatchNorm2d-45          [-1, 128, 16, 16]             256
      BatchNorm2d-46          [-1, 128, 16, 16]             256
             ReLU-47          [-1, 128, 16, 16]               0
             ReLU-48          [-1, 128, 16, 16]               0
           Conv2d-49          [-1, 128, 16, 16]         147,584
           Conv2d-50          [-1, 128, 16, 16]         147,584
      BatchNorm2d-51          [-1, 128, 16, 16]             256
      BatchNorm2d-52          [-1, 128, 16, 16]             256
             ReLU-53          [-1, 128, 16, 16]               0
             ReLU-54          [-1, 128, 16, 16]               0
           Conv2d-55            [-1, 256, 8, 8]         524,544
           Conv2d-56            [-1, 256, 8, 8]         524,544
      BatchNorm2d-57            [-1, 256, 8, 8]             512
      BatchNorm2d-58            [-1, 256, 8, 8]             512
             ReLU-59            [-1, 256, 8, 8]               0
             ReLU-60            [-1, 256, 8, 8]               0
        ConvBlock-61            [-1, 256, 8, 8]               0
           Conv2d-62            [-1, 256, 8, 8]         590,080
           Conv2d-63            [-1, 256, 8, 8]         590,080
      BatchNorm2d-64            [-1, 256, 8, 8]             512
      BatchNorm2d-65            [-1, 256, 8, 8]             512
             ReLU-66            [-1, 256, 8, 8]               0
             ReLU-67            [-1, 256, 8, 8]               0
           Conv2d-68            [-1, 256, 8, 8]         590,080
           Conv2d-69            [-1, 256, 8, 8]         590,080
      BatchNorm2d-70            [-1, 256, 8, 8]             512
      BatchNorm2d-71            [-1, 256, 8, 8]             512
             ReLU-72            [-1, 256, 8, 8]               0
             ReLU-73            [-1, 256, 8, 8]               0
           Conv2d-74            [-1, 512, 4, 4]       2,097,664
           Conv2d-75            [-1, 512, 4, 4]       2,097,664
      BatchNorm2d-76            [-1, 512, 4, 4]           1,024
      BatchNorm2d-77            [-1, 512, 4, 4]           1,024
             ReLU-78            [-1, 512, 4, 4]               0
             ReLU-79            [-1, 512, 4, 4]               0
        ConvBlock-80            [-1, 512, 4, 4]               0
  ConvTranspose2d-81            [-1, 512, 4, 4]       2,359,808
  ConvTranspose2d-82            [-1, 512, 4, 4]       2,359,808
      BatchNorm2d-83            [-1, 512, 4, 4]           1,024
      BatchNorm2d-84            [-1, 512, 4, 4]           1,024
             ReLU-85            [-1, 512, 4, 4]               0
             ReLU-86            [-1, 512, 4, 4]               0
  ConvTranspose2d-87            [-1, 512, 4, 4]       2,359,808
  ConvTranspose2d-88            [-1, 512, 4, 4]       2,359,808
      BatchNorm2d-89            [-1, 512, 4, 4]           1,024
      BatchNorm2d-90            [-1, 512, 4, 4]           1,024
             ReLU-91            [-1, 512, 4, 4]               0
             ReLU-92            [-1, 512, 4, 4]               0
  ConvTranspose2d-93            [-1, 256, 8, 8]       2,097,408
  ConvTranspose2d-94            [-1, 256, 8, 8]       2,097,408
      BatchNorm2d-95            [-1, 256, 8, 8]             512
      BatchNorm2d-96            [-1, 256, 8, 8]             512
             ReLU-97            [-1, 256, 8, 8]               0
             ReLU-98            [-1, 256, 8, 8]               0
        ConvBlock-99            [-1, 256, 8, 8]               0
 ConvTranspose2d-100            [-1, 256, 8, 8]         590,080
 ConvTranspose2d-101            [-1, 256, 8, 8]         590,080
     BatchNorm2d-102            [-1, 256, 8, 8]             512
     BatchNorm2d-103            [-1, 256, 8, 8]             512
            ReLU-104            [-1, 256, 8, 8]               0
            ReLU-105            [-1, 256, 8, 8]               0
 ConvTranspose2d-106            [-1, 256, 8, 8]         590,080
 ConvTranspose2d-107            [-1, 256, 8, 8]         590,080
     BatchNorm2d-108            [-1, 256, 8, 8]             512
     BatchNorm2d-109            [-1, 256, 8, 8]             512
            ReLU-110            [-1, 256, 8, 8]               0
            ReLU-111            [-1, 256, 8, 8]               0
 ConvTranspose2d-112          [-1, 128, 16, 16]         524,416
 ConvTranspose2d-113          [-1, 128, 16, 16]         524,416
     BatchNorm2d-114          [-1, 128, 16, 16]             256
     BatchNorm2d-115          [-1, 128, 16, 16]             256
            ReLU-116          [-1, 128, 16, 16]               0
            ReLU-117          [-1, 128, 16, 16]               0
       ConvBlock-118          [-1, 128, 16, 16]               0
 ConvTranspose2d-119          [-1, 128, 16, 16]         147,584
 ConvTranspose2d-120          [-1, 128, 16, 16]         147,584
     BatchNorm2d-121          [-1, 128, 16, 16]             256
     BatchNorm2d-122          [-1, 128, 16, 16]             256
            ReLU-123          [-1, 128, 16, 16]               0
            ReLU-124          [-1, 128, 16, 16]               0
 ConvTranspose2d-125          [-1, 128, 16, 16]         147,584
 ConvTranspose2d-126          [-1, 128, 16, 16]         147,584
     BatchNorm2d-127          [-1, 128, 16, 16]             256
     BatchNorm2d-128          [-1, 128, 16, 16]             256
            ReLU-129          [-1, 128, 16, 16]               0
            ReLU-130          [-1, 128, 16, 16]               0
 ConvTranspose2d-131           [-1, 64, 32, 32]         131,136
 ConvTranspose2d-132           [-1, 64, 32, 32]         131,136
     BatchNorm2d-133           [-1, 64, 32, 32]             128
     BatchNorm2d-134           [-1, 64, 32, 32]             128
            ReLU-135           [-1, 64, 32, 32]               0
            ReLU-136           [-1, 64, 32, 32]               0
       ConvBlock-137           [-1, 64, 32, 32]               0
 ConvTranspose2d-138           [-1, 64, 32, 32]          36,928
 ConvTranspose2d-139           [-1, 64, 32, 32]          36,928
     BatchNorm2d-140           [-1, 64, 32, 32]             128
     BatchNorm2d-141           [-1, 64, 32, 32]             128
            ReLU-142           [-1, 64, 32, 32]               0
            ReLU-143           [-1, 64, 32, 32]               0
 ConvTranspose2d-144           [-1, 64, 32, 32]          36,928
 ConvTranspose2d-145           [-1, 64, 32, 32]          36,928
     BatchNorm2d-146           [-1, 64, 32, 32]             128
     BatchNorm2d-147           [-1, 64, 32, 32]             128
            ReLU-148           [-1, 64, 32, 32]               0
            ReLU-149           [-1, 64, 32, 32]               0
 ConvTranspose2d-150           [-1, 32, 64, 64]          32,800
 ConvTranspose2d-151           [-1, 32, 64, 64]          32,800
     BatchNorm2d-152           [-1, 32, 64, 64]              64
     BatchNorm2d-153           [-1, 32, 64, 64]              64
            ReLU-154           [-1, 32, 64, 64]               0
            ReLU-155           [-1, 32, 64, 64]               0
       ConvBlock-156           [-1, 32, 64, 64]               0
          Conv2d-157            [-1, 1, 64, 64]             289
          Conv2d-158            [-1, 1, 64, 64]             289
         Sigmoid-159            [-1, 1, 64, 64]               0
================================================================
Total params: 26,835,522
Trainable params: 26,835,522
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.02
Forward/backward pass size (MB): 57.53
Params size (MB): 102.37
Estimated Total Size (MB): 159.92
----------------------------------------------------------------
torch.Size([2, 1, 64, 64])

Can anyone help me with this problem? I'm in a hurry, please. Thank you all!

You are assigning the modules directly first (e.g. via self.bn1 = nn.BatchNorm2d(in_channels)) and then again inside the self.block nn.ModuleList, so each layer is registered twice and torchsummary reports it twice. Since you are only using self.block in forward, you could remove the self.bn1 = ... assignments, use assignments without the self keyword (bn1 = ...), and add these layers to self.block afterwards.
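
Something like this (untested) sketch should work; it keeps your original structure and only removes the duplicated registrations. Note that reusing the same conv1/bn1 objects inside the loop means their weights are shared across repetitions, as in your original code:

import torch.nn as nn

class ConvBlock(nn.Module):
    def __init__(self, in_channels, repeat_conv_num=2, transpose=False, **kwargs):
        super(ConvBlock, self).__init__()
        # local variables only, so each layer is registered exactly once (via self.block)
        conv1 = nn.Conv2d(in_channels, in_channels, **kwargs) \
            if not transpose else nn.ConvTranspose2d(in_channels, in_channels, **kwargs)
        conv2 = nn.Conv2d(in_channels, in_channels * 2, kernel_size=4, stride=2, padding=1) \
            if not transpose else nn.ConvTranspose2d(in_channels, int(in_channels / 2), kernel_size=4, stride=2, padding=1)
        bn1 = nn.BatchNorm2d(in_channels)
        bn2 = nn.BatchNorm2d(in_channels * 2) if not transpose else nn.BatchNorm2d(int(in_channels / 2))
        relu = nn.ReLU(inplace=True)
        self.block = nn.ModuleList([])
        for _ in range(repeat_conv_num):
            # reuses the same conv1/bn1 modules (shared weights); create new modules
            # inside the loop instead if independent layers are intended
            self.block.extend([conv1, bn1, relu])
        self.block.extend([conv2, bn2, relu])

    def forward(self, x):
        for layer in self.block:
            x = layer(x)
        return x

With only self.block holding the layers, the summary should then list conv, BatchNorm2d, ReLU once per repetition, without the duplicated BatchNorm2d/ReLU rows.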

@ptrblck Sorry, I don’t know where to correct the code or what I’m misunderstanding. I removed the self assignments, but it still fails.

import torch
import torch.nn as nn
from torchsummary import summary

class ConvBlock(nn.Module):
    def __init__(self, in_channels, repeat_conv_num=2, transpose=False, **kwargs):
        super(ConvBlock, self).__init__()
        self.repeat_conv_num = repeat_conv_num
        conv1 = nn.Conv2d(in_channels, in_channels, **kwargs) \
            if not transpose else nn.ConvTranspose2d(in_channels, in_channels, **kwargs)
        conv2 = nn.Conv2d(in_channels, in_channels * 2, kernel_size=4, stride=2, padding=1) \
            if not transpose else nn.ConvTranspose2d(in_channels, int(in_channels / 2), kernel_size=4, stride=2, padding=1)
        bn1 = nn.BatchNorm2d(in_channels)
        bn2 = nn.BatchNorm2d(in_channels*2) if not transpose else nn.BatchNorm2d(int(in_channels/2))
        relu = nn.ReLU(inplace=True)
        self.block = nn.ModuleList([])
        for _ in range(repeat_conv_num):
            self.block.extend([conv1, bn1, relu])
        self.block.extend([conv2, bn2, relu])
        self.conv_block = nn.Sequential(*self.block)
        # print(self.conv_block)
        # for _ in self.block:
        #     print(_)

    def forward(self, x):
        print(self.conv_block)
        return self.conv_block(x)



class AutoEncoder(nn.Module):
    def __init__(self, in_channels=1, out_channels=1, features=32, **kwargs):
        super(AutoEncoder, self).__init__()
        self.init_layer = nn.Conv2d(in_channels, features, kernel_size=3, stride=1, padding=1)
        self.last_layer = nn.Conv2d(features, out_channels, kernel_size=3, stride=1, padding=1)
        self.model = nn.Sequential(
            self.init_layer,
            nn.BatchNorm2d(features),
            nn.ReLU(inplace=True),
            ConvBlock(in_channels=features, **kwargs),
            ConvBlock(in_channels=features*2, **kwargs),
            ConvBlock(in_channels=features*4, **kwargs),
            ConvBlock(in_channels=features*8, **kwargs),
            ConvBlock(in_channels=features*16, transpose=True, **kwargs),
            ConvBlock(in_channels=features * 8, transpose=True, **kwargs),
            ConvBlock(in_channels=features * 4, transpose=True, **kwargs),
            ConvBlock(in_channels=features * 2, transpose=True, **kwargs),
            self.last_layer,
            nn.Sigmoid()
        )


    def forward(self, x):
        return self.model(x)


if __name__ == "__main__":
    model = AutoEncoder(in_channels=1, out_channels=1, features=32, kernel_size=3, stride=1, padding=1, repeat_conv_num=2).to('cuda')
    x = torch.randn((2, 1, 64, 64)).to('cuda')
    summary(model, input_size=(1, 64, 64))
    print(model(x).shape)

    cblock = ConvBlock(in_channels=32, repeat_conv_num=2, transpose=False, kernel_size=3, stride=1, padding=1).to('cuda')
    x = torch.randn((2, 32, 64, 64)).to('cuda')
    summary(cblock, input_size=(32, 64, 64))
    print(cblock(x).shape)