Hello, I found that when I run this simple model structure, the output from torchsummary
is double batchnormalization and double relu after two convolution layer. Can anyone help me to fix the order to
conv, batchnormalization, relu? Thanks
import torch
import torch.nn as nn
from torchsummary import summary
class ConvBlock(nn.Module):
def __init__(self, in_channels, repeat_conv_num=2, transpose=False, **kwargs):
super(ConvBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels, in_channels, **kwargs) \
if not transpose else nn.ConvTranspose2d(in_channels, in_channels, **kwargs)
self.conv2 = nn.Conv2d(in_channels, in_channels * 2, kernel_size=4, stride=2, padding=1) \
if not transpose else nn.ConvTranspose2d(in_channels, int(in_channels / 2), kernel_size=4, stride=2, padding=1)
self.bn1 = nn.BatchNorm2d(in_channels)
self.bn2 = nn.BatchNorm2d(in_channels*2) if not transpose else nn.BatchNorm2d(int(in_channels/2))
self.relu = nn.ReLU(inplace=True)
self.block = nn.ModuleList([])
for _ in range(repeat_conv_num):
self.block.extend([self.conv1, self.bn1, self.relu])
self.block.extend([self.conv2, self.bn2, self.relu])
def forward(self, x):
for layer in self.block:
print(layer)
x = layer(x)
return x
class AutoEncoder(nn.Module):
def __init__(self, in_channels=1, out_channels=1, features=32, **kwargs):
super(AutoEncoder, self).__init__()
self.init_layer = nn.Conv2d(in_channels, features, kernel_size=3, stride=1, padding=1)
self.last_layer = nn.Conv2d(features, out_channels, kernel_size=3, stride=1, padding=1)
self.model = nn.Sequential(
self.init_layer,
nn.BatchNorm2d(features),
nn.ReLU(inplace=True),
ConvBlock(in_channels=features, **kwargs),
ConvBlock(in_channels=features*2, **kwargs),
ConvBlock(in_channels=features*4, **kwargs),
ConvBlock(in_channels=features*8, **kwargs),
ConvBlock(in_channels=features*16, transpose=True, **kwargs),
ConvBlock(in_channels=features * 8, transpose=True, **kwargs),
ConvBlock(in_channels=features * 4, transpose=True, **kwargs),
ConvBlock(in_channels=features * 2, transpose=True, **kwargs),
self.last_layer,
nn.Sigmoid()
)
def forward(self, x):
return self.model(x)
if __name__ == "__main__":
model = AutoEncoder(in_channels=1, out_channels=1, features=32, kernel_size=3, stride=1, padding=1, repeat_conv_num=2).to('cuda')
x = torch.randn((2, 1, 512, 512)).to('cuda')
summary(model, input_size=(1, 512, 512))
print(model(x).shape)
model2 = ConvBlock(in_channels=32, kernel_size=3, stride=1, padding=1, repeat_conv_num=2).to('cuda')
x = torch.randn((2, 32, 512, 512)).to('cuda')
summary(model2, input_size=(32, 512, 512))
print(model2(x).shape)
result:
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 32, 64, 64] 320
Conv2d-2 [-1, 32, 64, 64] 320
BatchNorm2d-3 [-1, 32, 64, 64] 64
ReLU-4 [-1, 32, 64, 64] 0
Conv2d-5 [-1, 32, 64, 64] 9,248
Conv2d-6 [-1, 32, 64, 64] 9,248
BatchNorm2d-7 [-1, 32, 64, 64] 64
BatchNorm2d-8 [-1, 32, 64, 64] 64
ReLU-9 [-1, 32, 64, 64] 0
ReLU-10 [-1, 32, 64, 64] 0
Conv2d-11 [-1, 32, 64, 64] 9,248
Conv2d-12 [-1, 32, 64, 64] 9,248
BatchNorm2d-13 [-1, 32, 64, 64] 64
BatchNorm2d-14 [-1, 32, 64, 64] 64
ReLU-15 [-1, 32, 64, 64] 0
ReLU-16 [-1, 32, 64, 64] 0
Conv2d-17 [-1, 64, 32, 32] 32,832
Conv2d-18 [-1, 64, 32, 32] 32,832
BatchNorm2d-19 [-1, 64, 32, 32] 128
BatchNorm2d-20 [-1, 64, 32, 32] 128
ReLU-21 [-1, 64, 32, 32] 0
ReLU-22 [-1, 64, 32, 32] 0
ConvBlock-23 [-1, 64, 32, 32] 0
Conv2d-24 [-1, 64, 32, 32] 36,928
Conv2d-25 [-1, 64, 32, 32] 36,928
BatchNorm2d-26 [-1, 64, 32, 32] 128
BatchNorm2d-27 [-1, 64, 32, 32] 128
ReLU-28 [-1, 64, 32, 32] 0
ReLU-29 [-1, 64, 32, 32] 0
Conv2d-30 [-1, 64, 32, 32] 36,928
Conv2d-31 [-1, 64, 32, 32] 36,928
BatchNorm2d-32 [-1, 64, 32, 32] 128
BatchNorm2d-33 [-1, 64, 32, 32] 128
ReLU-34 [-1, 64, 32, 32] 0
ReLU-35 [-1, 64, 32, 32] 0
Conv2d-36 [-1, 128, 16, 16] 131,200
Conv2d-37 [-1, 128, 16, 16] 131,200
BatchNorm2d-38 [-1, 128, 16, 16] 256
BatchNorm2d-39 [-1, 128, 16, 16] 256
ReLU-40 [-1, 128, 16, 16] 0
ReLU-41 [-1, 128, 16, 16] 0
ConvBlock-42 [-1, 128, 16, 16] 0
Conv2d-43 [-1, 128, 16, 16] 147,584
Conv2d-44 [-1, 128, 16, 16] 147,584
BatchNorm2d-45 [-1, 128, 16, 16] 256
BatchNorm2d-46 [-1, 128, 16, 16] 256
ReLU-47 [-1, 128, 16, 16] 0
ReLU-48 [-1, 128, 16, 16] 0
Conv2d-49 [-1, 128, 16, 16] 147,584
Conv2d-50 [-1, 128, 16, 16] 147,584
BatchNorm2d-51 [-1, 128, 16, 16] 256
BatchNorm2d-52 [-1, 128, 16, 16] 256
ReLU-53 [-1, 128, 16, 16] 0
ReLU-54 [-1, 128, 16, 16] 0
Conv2d-55 [-1, 256, 8, 8] 524,544
Conv2d-56 [-1, 256, 8, 8] 524,544
BatchNorm2d-57 [-1, 256, 8, 8] 512
BatchNorm2d-58 [-1, 256, 8, 8] 512
ReLU-59 [-1, 256, 8, 8] 0
ReLU-60 [-1, 256, 8, 8] 0
ConvBlock-61 [-1, 256, 8, 8] 0
Conv2d-62 [-1, 256, 8, 8] 590,080
Conv2d-63 [-1, 256, 8, 8] 590,080
BatchNorm2d-64 [-1, 256, 8, 8] 512
BatchNorm2d-65 [-1, 256, 8, 8] 512
ReLU-66 [-1, 256, 8, 8] 0
ReLU-67 [-1, 256, 8, 8] 0
Conv2d-68 [-1, 256, 8, 8] 590,080
Conv2d-69 [-1, 256, 8, 8] 590,080
BatchNorm2d-70 [-1, 256, 8, 8] 512
BatchNorm2d-71 [-1, 256, 8, 8] 512
ReLU-72 [-1, 256, 8, 8] 0
ReLU-73 [-1, 256, 8, 8] 0
Conv2d-74 [-1, 512, 4, 4] 2,097,664
Conv2d-75 [-1, 512, 4, 4] 2,097,664
BatchNorm2d-76 [-1, 512, 4, 4] 1,024
BatchNorm2d-77 [-1, 512, 4, 4] 1,024
ReLU-78 [-1, 512, 4, 4] 0
ReLU-79 [-1, 512, 4, 4] 0
ConvBlock-80 [-1, 512, 4, 4] 0
ConvTranspose2d-81 [-1, 512, 4, 4] 2,359,808
ConvTranspose2d-82 [-1, 512, 4, 4] 2,359,808
BatchNorm2d-83 [-1, 512, 4, 4] 1,024
BatchNorm2d-84 [-1, 512, 4, 4] 1,024
ReLU-85 [-1, 512, 4, 4] 0
ReLU-86 [-1, 512, 4, 4] 0
ConvTranspose2d-87 [-1, 512, 4, 4] 2,359,808
ConvTranspose2d-88 [-1, 512, 4, 4] 2,359,808
BatchNorm2d-89 [-1, 512, 4, 4] 1,024
BatchNorm2d-90 [-1, 512, 4, 4] 1,024
ReLU-91 [-1, 512, 4, 4] 0
ReLU-92 [-1, 512, 4, 4] 0
ConvTranspose2d-93 [-1, 256, 8, 8] 2,097,408
ConvTranspose2d-94 [-1, 256, 8, 8] 2,097,408
BatchNorm2d-95 [-1, 256, 8, 8] 512
BatchNorm2d-96 [-1, 256, 8, 8] 512
ReLU-97 [-1, 256, 8, 8] 0
ReLU-98 [-1, 256, 8, 8] 0
ConvBlock-99 [-1, 256, 8, 8] 0
ConvTranspose2d-100 [-1, 256, 8, 8] 590,080
ConvTranspose2d-101 [-1, 256, 8, 8] 590,080
BatchNorm2d-102 [-1, 256, 8, 8] 512
BatchNorm2d-103 [-1, 256, 8, 8] 512
ReLU-104 [-1, 256, 8, 8] 0
ReLU-105 [-1, 256, 8, 8] 0
ConvTranspose2d-106 [-1, 256, 8, 8] 590,080
ConvTranspose2d-107 [-1, 256, 8, 8] 590,080
BatchNorm2d-108 [-1, 256, 8, 8] 512
BatchNorm2d-109 [-1, 256, 8, 8] 512
ReLU-110 [-1, 256, 8, 8] 0
ReLU-111 [-1, 256, 8, 8] 0
ConvTranspose2d-112 [-1, 128, 16, 16] 524,416
ConvTranspose2d-113 [-1, 128, 16, 16] 524,416
BatchNorm2d-114 [-1, 128, 16, 16] 256
BatchNorm2d-115 [-1, 128, 16, 16] 256
ReLU-116 [-1, 128, 16, 16] 0
ReLU-117 [-1, 128, 16, 16] 0
ConvBlock-118 [-1, 128, 16, 16] 0
ConvTranspose2d-119 [-1, 128, 16, 16] 147,584
ConvTranspose2d-120 [-1, 128, 16, 16] 147,584
BatchNorm2d-121 [-1, 128, 16, 16] 256
BatchNorm2d-122 [-1, 128, 16, 16] 256
ReLU-123 [-1, 128, 16, 16] 0
ReLU-124 [-1, 128, 16, 16] 0
ConvTranspose2d-125 [-1, 128, 16, 16] 147,584
ConvTranspose2d-126 [-1, 128, 16, 16] 147,584
BatchNorm2d-127 [-1, 128, 16, 16] 256
BatchNorm2d-128 [-1, 128, 16, 16] 256
ReLU-129 [-1, 128, 16, 16] 0
ReLU-130 [-1, 128, 16, 16] 0
ConvTranspose2d-131 [-1, 64, 32, 32] 131,136
ConvTranspose2d-132 [-1, 64, 32, 32] 131,136
BatchNorm2d-133 [-1, 64, 32, 32] 128
BatchNorm2d-134 [-1, 64, 32, 32] 128
ReLU-135 [-1, 64, 32, 32] 0
ReLU-136 [-1, 64, 32, 32] 0
ConvBlock-137 [-1, 64, 32, 32] 0
ConvTranspose2d-138 [-1, 64, 32, 32] 36,928
ConvTranspose2d-139 [-1, 64, 32, 32] 36,928
BatchNorm2d-140 [-1, 64, 32, 32] 128
BatchNorm2d-141 [-1, 64, 32, 32] 128
ReLU-142 [-1, 64, 32, 32] 0
ReLU-143 [-1, 64, 32, 32] 0
ConvTranspose2d-144 [-1, 64, 32, 32] 36,928
ConvTranspose2d-145 [-1, 64, 32, 32] 36,928
BatchNorm2d-146 [-1, 64, 32, 32] 128
BatchNorm2d-147 [-1, 64, 32, 32] 128
ReLU-148 [-1, 64, 32, 32] 0
ReLU-149 [-1, 64, 32, 32] 0
ConvTranspose2d-150 [-1, 32, 64, 64] 32,800
ConvTranspose2d-151 [-1, 32, 64, 64] 32,800
BatchNorm2d-152 [-1, 32, 64, 64] 64
BatchNorm2d-153 [-1, 32, 64, 64] 64
ReLU-154 [-1, 32, 64, 64] 0
ReLU-155 [-1, 32, 64, 64] 0
ConvBlock-156 [-1, 32, 64, 64] 0
Conv2d-157 [-1, 1, 64, 64] 289
Conv2d-158 [-1, 1, 64, 64] 289
Sigmoid-159 [-1, 1, 64, 64] 0
================================================================
Total params: 26,835,522
Trainable params: 26,835,522
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.02
Forward/backward pass size (MB): 57.53
Params size (MB): 102.37
Estimated Total Size (MB): 159.92
----------------------------------------------------------------
torch.Size([2, 1, 64, 64])