Hello, I am trying to implement a encoder decoder network in Pytorch, like the Tensorflow implementation here:
def EncoderMiniBlock(inputs, n_filters=32, dropout_prob=0.3, max_pooling=True):
"""
This block uses multiple convolution layers, max pool, relu activation to create an architecture for learning.
Dropout can be added for regularization to prevent overfitting.
The block returns the activation values for next layer along with a skip connection which will be used in the decoder
"""
# Add 2 Conv Layers with relu activation and HeNormal initialization using TensorFlow
# Proper initialization prevents from the problem of exploding and vanishing gradients
# 'Same' padding will pad the input to conv layer such that the output has the same height and width (hence, is not reduced in size)
conv = Conv1D(n_filters,
9, # Kernel size
activation='relu',
padding='same',
kernel_initializer='HeNormal')(inputs)
conv = Conv1D(n_filters,
9, # Kernel size
activation='relu',
padding='same',
kernel_initializer='HeNormal')(conv)
# Batch Normalization will normalize the output of the last layer based on the batch's mean and standard deviation
conv = BatchNormalization()(conv, training=False)
# In case of overfitting, dropout will regularize the loss and gradient computation to shrink the influence of weights on output
if dropout_prob > 0:
conv = tf.keras.layers.Dropout(dropout_prob)(conv)
# Pooling reduces the size of the image while keeping the number of channels same
# Pooling has been kept as optional as the last encoder layer does not use pooling (hence, makes the encoder block flexible to use)
# Below, Max pooling considers the maximum of the input slice for output computation and uses stride of 2 to traverse across input image
if max_pooling:
next_layer = tf.keras.layers.MaxPooling1D(pool_size = 2)(conv)
else:
next_layer = conv
# skip connection (without max pooling) will be input to the decoder layer to prevent information loss during transpose convolutions
skip_connection = conv
return next_layer, skip_connection
def DecoderMiniBlock(prev_layer_input, skip_layer_input, n_filters=32):
"""
Decoder Block first uses transpose convolution to upscale the image to a bigger size and then,
merges the result with skip layer results from encoder block
Adding 2 convolutions with 'same' padding helps further increase the depth of the network for better predictions
The function returns the decoded layer output
"""
# Start with a transpose convolution layer to first increase the size of the image
up = Conv1DTranspose(
n_filters,
9, # Kernel size
strides=2,
padding='same')(prev_layer_input)
# Merge the skip connection from previous block to prevent information loss
merge = concatenate([up, skip_layer_input], axis=2)
# Add 2 Conv Layers with relu activation and HeNormal initialization for further processing
# The parameters for the function are similar to encoder
conv = Conv1D(n_filters,
9, # Kernel size
activation='relu',
padding='same',
kernel_initializer='HeNormal')(merge)
conv = Conv1D(n_filters,
9, # Kernel size
activation='relu',
padding='same',
kernel_initializer='HeNormal')(conv)
return conv
def UNetCompiled(input_size=(960, 1), n_filters=32):
# Input size represent the size of 1 image (the size used for pre-processing)
inputs = Input(input_size)
# Encoder includes multiple convolutional mini blocks with different maxpooling, dropout and filter parameters
# Observe that the filters are increasing as we go deeper into the network which will increasse the # channels of the image
cblock1 = EncoderMiniBlock(inputs, n_filters,dropout_prob=0, max_pooling=True)
cblock2 = EncoderMiniBlock(cblock1[0],n_filters*2,dropout_prob=0, max_pooling=True)
cblock3 = EncoderMiniBlock(cblock2[0], n_filters*4,dropout_prob=0, max_pooling=True)
cblock4 = EncoderMiniBlock(cblock3[0], n_filters*8,dropout_prob=0.3, max_pooling=True)
cblock5 = EncoderMiniBlock(cblock4[0], n_filters*16, dropout_prob=0.3, max_pooling=False)
# Decoder includes multiple mini blocks with decreasing number of filters
# Observe the skip connections from the encoder are given as input to the decoder
# Recall the 2nd output of encoder block was skip connection, hence cblockn[1] is used
ublock6 = DecoderMiniBlock(cblock5[0], cblock4[1], n_filters * 8)
ublock7 = DecoderMiniBlock(ublock6, cblock3[1], n_filters * 4)
ublock8 = DecoderMiniBlock(ublock7, cblock2[1], n_filters * 2)
ublock9 = DecoderMiniBlock(ublock8, cblock1[1], n_filters)
# Complete the model with 1 3x3 convolution layer (Same as the prev Conv Layers)
# Followed by a 1x1 Conv layer to get the image to the desired size.
# Observe the number of channels will be equal to number of output classes
conv9 = Conv1D(n_filters,
9,
activation='relu',
padding='same',
kernel_initializer='he_normal')(ublock9)
conv10 = Conv1D(1, 1, padding='same')(conv9)
# Define the model
model = tf.keras.Model(inputs=inputs, outputs=conv10)
return model
Which gives the following parameters:
Model: "model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, 960, 1)] 0 []
conv1d (Conv1D) (None, 960, 32) 320 ['input_1[0][0]']
conv1d_1 (Conv1D) (None, 960, 32) 9248 ['conv1d[0][0]']
batch_normalization (BatchNorm (None, 960, 32) 128 ['conv1d_1[0][0]']
alization)
max_pooling1d (MaxPooling1D) (None, 480, 32) 0 ['batch_normalization[0][0]']
conv1d_2 (Conv1D) (None, 480, 64) 18496 ['max_pooling1d[0][0]']
conv1d_3 (Conv1D) (None, 480, 64) 36928 ['conv1d_2[0][0]']
batch_normalization_1 (BatchNo (None, 480, 64) 256 ['conv1d_3[0][0]']
rmalization)
max_pooling1d_1 (MaxPooling1D) (None, 240, 64) 0 ['batch_normalization_1[0][0]']
conv1d_4 (Conv1D) (None, 240, 128) 73856 ['max_pooling1d_1[0][0]']
conv1d_5 (Conv1D) (None, 240, 128) 147584 ['conv1d_4[0][0]']
batch_normalization_2 (BatchNo (None, 240, 128) 512 ['conv1d_5[0][0]']
rmalization)
max_pooling1d_2 (MaxPooling1D) (None, 120, 128) 0 ['batch_normalization_2[0][0]']
conv1d_6 (Conv1D) (None, 120, 256) 295168 ['max_pooling1d_2[0][0]']
conv1d_7 (Conv1D) (None, 120, 256) 590080 ['conv1d_6[0][0]']
batch_normalization_3 (BatchNo (None, 120, 256) 1024 ['conv1d_7[0][0]']
rmalization)
dropout (Dropout) (None, 120, 256) 0 ['batch_normalization_3[0][0]']
max_pooling1d_3 (MaxPooling1D) (None, 60, 256) 0 ['dropout[0][0]']
conv1d_8 (Conv1D) (None, 60, 512) 1180160 ['max_pooling1d_3[0][0]']
conv1d_9 (Conv1D) (None, 60, 512) 2359808 ['conv1d_8[0][0]']
batch_normalization_4 (BatchNo (None, 60, 512) 2048 ['conv1d_9[0][0]']
rmalization)
dropout_1 (Dropout) (None, 60, 512) 0 ['batch_normalization_4[0][0]']
conv1d_transpose (Conv1DTransp (None, 120, 256) 1179904 ['dropout_1[0][0]']
ose)
concatenate (Concatenate) (None, 120, 512) 0 ['conv1d_transpose[0][0]',
'dropout[0][0]']
conv1d_10 (Conv1D) (None, 120, 256) 1179904 ['concatenate[0][0]']
conv1d_11 (Conv1D) (None, 120, 256) 590080 ['conv1d_10[0][0]']
conv1d_transpose_1 (Conv1DTran (None, 240, 128) 295040 ['conv1d_11[0][0]']
spose)
concatenate_1 (Concatenate) (None, 240, 256) 0 ['conv1d_transpose_1[0][0]',
'batch_normalization_2[0][0]']
conv1d_12 (Conv1D) (None, 240, 128) 295040 ['concatenate_1[0][0]']
conv1d_13 (Conv1D) (None, 240, 128) 147584 ['conv1d_12[0][0]']
conv1d_transpose_2 (Conv1DTran (None, 480, 64) 73792 ['conv1d_13[0][0]']
spose)
concatenate_2 (Concatenate) (None, 480, 128) 0 ['conv1d_transpose_2[0][0]',
'batch_normalization_1[0][0]']
conv1d_14 (Conv1D) (None, 480, 64) 73792 ['concatenate_2[0][0]']
conv1d_15 (Conv1D) (None, 480, 64) 36928 ['conv1d_14[0][0]']
conv1d_transpose_3 (Conv1DTran (None, 960, 32) 18464 ['conv1d_15[0][0]']
spose)
concatenate_3 (Concatenate) (None, 960, 64) 0 ['conv1d_transpose_3[0][0]',
'batch_normalization[0][0]']
conv1d_16 (Conv1D) (None, 960, 32) 18464 ['concatenate_3[0][0]']
conv1d_17 (Conv1D) (None, 960, 32) 9248 ['conv1d_16[0][0]']
conv1d_18 (Conv1D) (None, 960, 32) 9248 ['conv1d_17[0][0]']
conv1d_19 (Conv1D) (None, 960, 1) 33 ['conv1d_18[0][0]']
==================================================================================================
Total params: 8,643,137
Trainable params: 8,641,153
Non-trainable params: 1,984
__________________________________________________________________________________________________
This is what I managed to do in Pytorch:
class EncoderBlock(nn.Module):
def __init__(self,input,n_filters,dropout_prob=0.3, max_pooling = True,padding="same"):
super(EncoderBlock,self).__init__()
self.dropout_prob = dropout_prob
self.conv1 = nn.Conv1d(input,n_filters,kernel_size=9,padding=padding)
self.conv2 = nn.Conv1d(n_filters,n_filters,kernel_size=9,padding=padding)
self.relu1 = nn.ReLU()
self.relu2 = nn.ReLU()
self.dropout = nn.Dropout(dropout_prob)
self.max_pooling = max_pooling
self.max_pool = nn.MaxPool1d(2)
self.batch_norm = nn.BatchNorm1d(n_filters,affine=True)
init.kaiming_normal_(self.conv1.weight, mode='fan_in', nonlinearity='relu')
init.zeros_(self.conv1.bias)
init.kaiming_normal_(self.conv2.weight, mode='fan_in', nonlinearity='relu')
init.zeros_(self.conv2.bias)
def forward(self,x):
x = self.conv1(x)
x = self.relu1(x)
x = self.conv2(x)
x = self.relu2(x)
x = self.batch_norm(x)
if self.dropout_prob > 0:
x = self.dropout(x)
if self.max_pooling:
x = self.max_pool(x)
else:
x = x
skip_connection = x
return x, skip_connection
class DecoderBlock(nn.Module):
def __init__(self,input,n_filters,padding="same"):
super(DecoderBlock,self).__init__()
self.conv1 = nn.Conv1d(n_filters,n_filters,kernel_size=9,padding=padding)
self.conv2 = nn.Conv1d(n_filters,n_filters,kernel_size=9,padding=padding)
self.relu1 = nn.ReLU()
self.relu2 = nn.ReLU()
self.upsample = nn.ConvTranspose1d(input,n_filters,kernel_size=9,stride=1,padding=4)
init.kaiming_normal_(self.conv1.weight, mode='fan_in', nonlinearity='relu')
init.zeros_(self.conv1.bias)
init.kaiming_normal_(self.conv2.weight, mode='fan_in', nonlinearity='relu')
init.zeros_(self.conv2.bias)
def forward(self,x,skip_connection):
x = self.upsample(x)
merge = torch.cat([x,skip_connection],dim=2)
x = self.conv1(merge)
x = self.relu1(x)
x = self.conv2(x)
x = self.relu2(x)
return x
class WaveLike(nn.Module):
def __init__(self, n_filters=32, dropout_prob=0.3):
super(WaveLike, self).__init__()
self.n_filters = n_filters
self.dropout_prob = dropout_prob
self.encoder1 = EncoderBlock(1,n_filters,dropout_prob=0,max_pooling=True)
self.encoder2 = EncoderBlock(n_filters,n_filters*2,dropout_prob=0,max_pooling=True)
self.encoder3 = EncoderBlock(n_filters*2,n_filters*4,dropout_prob=0,max_pooling=True)
self.encoder4 = EncoderBlock(n_filters*4,n_filters*8,dropout_prob=0.3,max_pooling=True)
self.encoder5 = EncoderBlock(n_filters*8,n_filters*16,dropout_prob=0.3,max_pooling=False)
self.decoder1 = DecoderBlock(n_filters*16,n_filters*8)
self.decoder2 = DecoderBlock(n_filters*8,n_filters*4)
self.decoder3 = DecoderBlock(n_filters*4,n_filters*2)
self.decoder4 = DecoderBlock(n_filters*2,n_filters)
self.conv = nn.Conv1d(n_filters,n_filters,kernel_size=9,padding="same")
self.relu = nn.ReLU()
self.last_conv = nn.Conv1d(n_filters,1,kernel_size=1,padding="same")
init.kaiming_normal_(self.conv.weight, mode='fan_in', nonlinearity='relu')
init.zeros_(self.conv.bias)
def forward(self,x):
x1= self.encoder1(x)
x2 = self.encoder2(x1[0])
x3 = self.encoder3(x2[0])
x4 = self.encoder4(x3[0])
x5 = self.encoder5(x4[0])
z6 = self.decoder1(x5[0],x4[1])
z7 = self.decoder2(z6,x3[1])
z8 = self.decoder3(z7,x2[1])
z9 = self.decoder4(z8,x1[1])
conv = self.conv(z9)
conv = self.relu(conv)
conv = self.last_conv(conv)
return conv
which gives me the following summary:
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv1d-1 [-1, 32, 1024] 320
ReLU-2 [-1, 32, 1024] 0
Conv1d-3 [-1, 32, 1024] 9,248
ReLU-4 [-1, 32, 1024] 0
BatchNorm1d-5 [-1, 32, 1024] 0
MaxPool1d-6 [-1, 32, 512] 0
EncoderBlock-7 [[-1, 32, 512], [-1, 32, 512]] 0
Conv1d-8 [-1, 64, 512] 18,496
ReLU-9 [-1, 64, 512] 0
Conv1d-10 [-1, 64, 512] 36,928
ReLU-11 [-1, 64, 512] 0
BatchNorm1d-12 [-1, 64, 512] 0
MaxPool1d-13 [-1, 64, 256] 0
EncoderBlock-14 [[-1, 64, 256], [-1, 64, 256]] 0
Conv1d-15 [-1, 128, 256] 73,856
ReLU-16 [-1, 128, 256] 0
Conv1d-17 [-1, 128, 256] 147,584
ReLU-18 [-1, 128, 256] 0
BatchNorm1d-19 [-1, 128, 256] 0
MaxPool1d-20 [-1, 128, 128] 0
EncoderBlock-21 [[-1, 128, 128], [-1, 128, 128]] 0
Conv1d-22 [-1, 256, 128] 295,168
ReLU-23 [-1, 256, 128] 0
Conv1d-24 [-1, 256, 128] 590,080
ReLU-25 [-1, 256, 128] 0
BatchNorm1d-26 [-1, 256, 128] 0
Dropout-27 [-1, 256, 128] 0
MaxPool1d-28 [-1, 256, 64] 0
EncoderBlock-29 [[-1, 256, 64], [-1, 256, 64]] 0
Conv1d-30 [-1, 512, 64] 1,180,160
ReLU-31 [-1, 512, 64] 0
Conv1d-32 [-1, 512, 64] 2,359,808
ReLU-33 [-1, 512, 64] 0
BatchNorm1d-34 [-1, 512, 64] 0
Dropout-35 [-1, 512, 64] 0
EncoderBlock-36 [[-1, 512, 64], [-1, 512, 64]] 0
ConvTranspose1d-37 [-1, 256, 64] 1,179,904
Conv1d-38 [-1, 256, 128] 590,080
ReLU-39 [-1, 256, 128] 0
Conv1d-40 [-1, 256, 128] 590,080
ReLU-41 [-1, 256, 128] 0
DecoderBlock-42 [-1, 256, 128] 0
ConvTranspose1d-43 [-1, 128, 128] 295,040
Conv1d-44 [-1, 128, 256] 147,584
ReLU-45 [-1, 128, 256] 0
Conv1d-46 [-1, 128, 256] 147,584
ReLU-47 [-1, 128, 256] 0
DecoderBlock-48 [-1, 128, 256] 0
ConvTranspose1d-49 [-1, 64, 256] 73,792
Conv1d-50 [-1, 64, 512] 36,928
ReLU-51 [-1, 64, 512] 0
Conv1d-52 [-1, 64, 512] 36,928
ReLU-53 [-1, 64, 512] 0
DecoderBlock-54 [-1, 64, 512] 0
ConvTranspose1d-55 [-1, 32, 512] 18,464
Conv1d-56 [-1, 32, 1024] 9,248
ReLU-57 [-1, 32, 1024] 0
Conv1d-58 [-1, 32, 1024] 9,248
ReLU-59 [-1, 32, 1024] 0
DecoderBlock-60 [-1, 32, 1024] 0
Conv1d-61 [-1, 32, 1024] 9,248
ReLU-62 [-1, 32, 1024] 0
Conv1d-63 [-1, 1, 1024] 33
================================================================
Total params: 7,855,809
Trainable params: 7,855,809
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 16370.74
Params size (MB): 29.97
Estimated Total Size (MB): 16400.71
----------------------------------------------------------------
I know that the input size is different, but why is it that the batch normalization returns a different nr of parameters for the tensorflow implementation and it has more parameters as well as non trainable parameters?