What are the differences between these pytorch and the tensorflow implementation?

Hello, I am trying to implement a encoder decoder network in Pytorch, like the Tensorflow implementation here:

def EncoderMiniBlock(inputs, n_filters=32, dropout_prob=0.3, max_pooling=True):
    """
    This block uses multiple convolution layers, max pool, relu activation to create an architecture for learning. 
    Dropout can be added for regularization to prevent overfitting. 
    The block returns the activation values for next layer along with a skip connection which will be used in the decoder
    """
    # Add 2 Conv Layers with relu activation and HeNormal initialization using TensorFlow 
    # Proper initialization prevents from the problem of exploding and vanishing gradients 
    # 'Same' padding will pad the input to conv layer such that the output has the same height and width (hence, is not reduced in size) 
    conv = Conv1D(n_filters, 
                  9,   # Kernel size   
                  activation='relu',
                  padding='same',
                  kernel_initializer='HeNormal')(inputs)
    conv = Conv1D(n_filters, 
                  9,   # Kernel size
                  activation='relu',
                  padding='same',
                  kernel_initializer='HeNormal')(conv)
    
    # Batch Normalization will normalize the output of the last layer based on the batch's mean and standard deviation
    conv = BatchNormalization()(conv, training=False)

    # In case of overfitting, dropout will regularize the loss and gradient computation to shrink the influence of weights on output
    if dropout_prob > 0:     
        conv = tf.keras.layers.Dropout(dropout_prob)(conv)

    # Pooling reduces the size of the image while keeping the number of channels same
    # Pooling has been kept as optional as the last encoder layer does not use pooling (hence, makes the encoder block flexible to use)
    # Below, Max pooling considers the maximum of the input slice for output computation and uses stride of 2 to traverse across input image
    if max_pooling:
        next_layer = tf.keras.layers.MaxPooling1D(pool_size = 2)(conv)    
    else:
        next_layer = conv

    # skip connection (without max pooling) will be input to the decoder layer to prevent information loss during transpose convolutions      
    skip_connection = conv
    
    return next_layer, skip_connection

def DecoderMiniBlock(prev_layer_input, skip_layer_input, n_filters=32):
    """
    Decoder Block first uses transpose convolution to upscale the image to a bigger size and then,
    merges the result with skip layer results from encoder block
    Adding 2 convolutions with 'same' padding helps further increase the depth of the network for better predictions
    The function returns the decoded layer output
    """
    # Start with a transpose convolution layer to first increase the size of the image
    up = Conv1DTranspose(
                 n_filters,
                 9,    # Kernel size
                 strides=2,
                 padding='same')(prev_layer_input)

    # Merge the skip connection from previous block to prevent information loss
    merge = concatenate([up, skip_layer_input], axis=2)
    
    # Add 2 Conv Layers with relu activation and HeNormal initialization for further processing
    # The parameters for the function are similar to encoder
    conv = Conv1D(n_filters, 
                 9,     # Kernel size
                 activation='relu',
                 padding='same',
                 kernel_initializer='HeNormal')(merge)
    conv = Conv1D(n_filters,
                 9,   # Kernel size
                 activation='relu',
                 padding='same',
                 kernel_initializer='HeNormal')(conv)
    return conv

def UNetCompiled(input_size=(960, 1), n_filters=32):
    # Input size represent the size of 1 image (the size used for pre-processing) 
    inputs = Input(input_size)

    # Encoder includes multiple convolutional mini blocks with different maxpooling, dropout and filter parameters
    # Observe that the filters are increasing as we go deeper into the network which will increasse the # channels of the image 
    cblock1 = EncoderMiniBlock(inputs, n_filters,dropout_prob=0, max_pooling=True)
    cblock2 = EncoderMiniBlock(cblock1[0],n_filters*2,dropout_prob=0, max_pooling=True)
    cblock3 = EncoderMiniBlock(cblock2[0], n_filters*4,dropout_prob=0, max_pooling=True)
    cblock4 = EncoderMiniBlock(cblock3[0], n_filters*8,dropout_prob=0.3, max_pooling=True)
    cblock5 = EncoderMiniBlock(cblock4[0], n_filters*16, dropout_prob=0.3, max_pooling=False) 

    # Decoder includes multiple mini blocks with decreasing number of filters
    # Observe the skip connections from the encoder are given as input to the decoder
    # Recall the 2nd output of encoder block was skip connection, hence cblockn[1] is used
    ublock6 = DecoderMiniBlock(cblock5[0], cblock4[1],  n_filters * 8)
    ublock7 = DecoderMiniBlock(ublock6, cblock3[1],  n_filters * 4)
    ublock8 = DecoderMiniBlock(ublock7, cblock2[1],  n_filters * 2)
    ublock9 = DecoderMiniBlock(ublock8, cblock1[1],  n_filters)

    # Complete the model with 1 3x3 convolution layer (Same as the prev Conv Layers)
    # Followed by a 1x1 Conv layer to get the image to the desired size. 
    # Observe the number of channels will be equal to number of output classes
    conv9 = Conv1D(n_filters,
                 9,
                 activation='relu',
                 padding='same',
                 kernel_initializer='he_normal')(ublock9)

    conv10 = Conv1D(1, 1, padding='same')(conv9)

    # Define the model
    model = tf.keras.Model(inputs=inputs, outputs=conv10)

    return model

Which gives the following parameters:

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_1 (InputLayer)           [(None, 960, 1)]     0           []                               
                                                                                                  
 conv1d (Conv1D)                (None, 960, 32)      320         ['input_1[0][0]']                
                                                                                                  
 conv1d_1 (Conv1D)              (None, 960, 32)      9248        ['conv1d[0][0]']                 
                                                                                                  
 batch_normalization (BatchNorm  (None, 960, 32)     128         ['conv1d_1[0][0]']               
 alization)                                                                                       
                                                                                                  
 max_pooling1d (MaxPooling1D)   (None, 480, 32)      0           ['batch_normalization[0][0]']    
                                                                                                  
 conv1d_2 (Conv1D)              (None, 480, 64)      18496       ['max_pooling1d[0][0]']          
                                                                                                  
 conv1d_3 (Conv1D)              (None, 480, 64)      36928       ['conv1d_2[0][0]']               
                                                                                                  
 batch_normalization_1 (BatchNo  (None, 480, 64)     256         ['conv1d_3[0][0]']               
 rmalization)                                                                                     
                                                                                                  
 max_pooling1d_1 (MaxPooling1D)  (None, 240, 64)     0           ['batch_normalization_1[0][0]']  
                                                                                                  
 conv1d_4 (Conv1D)              (None, 240, 128)     73856       ['max_pooling1d_1[0][0]']        
                                                                                                  
 conv1d_5 (Conv1D)              (None, 240, 128)     147584      ['conv1d_4[0][0]']               
                                                                                                  
 batch_normalization_2 (BatchNo  (None, 240, 128)    512         ['conv1d_5[0][0]']               
 rmalization)                                                                                     
                                                                                                  
 max_pooling1d_2 (MaxPooling1D)  (None, 120, 128)    0           ['batch_normalization_2[0][0]']  
                                                                                                  
 conv1d_6 (Conv1D)              (None, 120, 256)     295168      ['max_pooling1d_2[0][0]']        
                                                                                                  
 conv1d_7 (Conv1D)              (None, 120, 256)     590080      ['conv1d_6[0][0]']               
                                                                                                  
 batch_normalization_3 (BatchNo  (None, 120, 256)    1024        ['conv1d_7[0][0]']               
 rmalization)                                                                                     
                                                                                                  
 dropout (Dropout)              (None, 120, 256)     0           ['batch_normalization_3[0][0]']  
                                                                                                  
 max_pooling1d_3 (MaxPooling1D)  (None, 60, 256)     0           ['dropout[0][0]']                
                                                                                                  
 conv1d_8 (Conv1D)              (None, 60, 512)      1180160     ['max_pooling1d_3[0][0]']        
                                                                                                  
 conv1d_9 (Conv1D)              (None, 60, 512)      2359808     ['conv1d_8[0][0]']               
                                                                                                  
 batch_normalization_4 (BatchNo  (None, 60, 512)     2048        ['conv1d_9[0][0]']               
 rmalization)                                                                                     
                                                                                                  
 dropout_1 (Dropout)            (None, 60, 512)      0           ['batch_normalization_4[0][0]']  
                                                                                                  
 conv1d_transpose (Conv1DTransp  (None, 120, 256)    1179904     ['dropout_1[0][0]']              
 ose)                                                                                             
                                                                                                  
 concatenate (Concatenate)      (None, 120, 512)     0           ['conv1d_transpose[0][0]',       
                                                                  'dropout[0][0]']                
                                                                                                  
 conv1d_10 (Conv1D)             (None, 120, 256)     1179904     ['concatenate[0][0]']            
                                                                                                  
 conv1d_11 (Conv1D)             (None, 120, 256)     590080      ['conv1d_10[0][0]']              
                                                                                                  
 conv1d_transpose_1 (Conv1DTran  (None, 240, 128)    295040      ['conv1d_11[0][0]']              
 spose)                                                                                           
                                                                                                  
 concatenate_1 (Concatenate)    (None, 240, 256)     0           ['conv1d_transpose_1[0][0]',     
                                                                  'batch_normalization_2[0][0]']  
                                                                                                  
 conv1d_12 (Conv1D)             (None, 240, 128)     295040      ['concatenate_1[0][0]']          
                                                                                                  
 conv1d_13 (Conv1D)             (None, 240, 128)     147584      ['conv1d_12[0][0]']              
                                                                                                  
 conv1d_transpose_2 (Conv1DTran  (None, 480, 64)     73792       ['conv1d_13[0][0]']              
 spose)                                                                                           
                                                                                                  
 concatenate_2 (Concatenate)    (None, 480, 128)     0           ['conv1d_transpose_2[0][0]',     
                                                                  'batch_normalization_1[0][0]']  
                                                                                                  
 conv1d_14 (Conv1D)             (None, 480, 64)      73792       ['concatenate_2[0][0]']          
                                                                                                  
 conv1d_15 (Conv1D)             (None, 480, 64)      36928       ['conv1d_14[0][0]']              
                                                                                                  
 conv1d_transpose_3 (Conv1DTran  (None, 960, 32)     18464       ['conv1d_15[0][0]']              
 spose)                                                                                           
                                                                                                  
 concatenate_3 (Concatenate)    (None, 960, 64)      0           ['conv1d_transpose_3[0][0]',     
                                                                  'batch_normalization[0][0]']    
                                                                                                  
 conv1d_16 (Conv1D)             (None, 960, 32)      18464       ['concatenate_3[0][0]']          
                                                                                                  
 conv1d_17 (Conv1D)             (None, 960, 32)      9248        ['conv1d_16[0][0]']              
                                                                                                  
 conv1d_18 (Conv1D)             (None, 960, 32)      9248        ['conv1d_17[0][0]']              
                                                                                                  
 conv1d_19 (Conv1D)             (None, 960, 1)       33          ['conv1d_18[0][0]']              
                                                                                                  
==================================================================================================
Total params: 8,643,137
Trainable params: 8,641,153
Non-trainable params: 1,984
__________________________________________________________________________________________________

This is what I managed to do in Pytorch:

class EncoderBlock(nn.Module):
    def __init__(self,input,n_filters,dropout_prob=0.3, max_pooling = True,padding="same"):
        super(EncoderBlock,self).__init__()
        self.dropout_prob = dropout_prob
        self.conv1 = nn.Conv1d(input,n_filters,kernel_size=9,padding=padding)
        self.conv2 = nn.Conv1d(n_filters,n_filters,kernel_size=9,padding=padding)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.dropout = nn.Dropout(dropout_prob)
        self.max_pooling = max_pooling
        self.max_pool = nn.MaxPool1d(2)
        self.batch_norm = nn.BatchNorm1d(n_filters,affine=True)
        init.kaiming_normal_(self.conv1.weight, mode='fan_in', nonlinearity='relu')
        init.zeros_(self.conv1.bias)
        init.kaiming_normal_(self.conv2.weight, mode='fan_in', nonlinearity='relu')
        init.zeros_(self.conv2.bias)
    def forward(self,x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.batch_norm(x)
        if self.dropout_prob > 0:
            x = self.dropout(x)
        if self.max_pooling:
            x = self.max_pool(x)
        else:
            x = x
        skip_connection = x
        return x, skip_connection

class DecoderBlock(nn.Module):
    def __init__(self,input,n_filters,padding="same"):
        super(DecoderBlock,self).__init__()
        self.conv1 = nn.Conv1d(n_filters,n_filters,kernel_size=9,padding=padding)
        self.conv2 = nn.Conv1d(n_filters,n_filters,kernel_size=9,padding=padding)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.upsample = nn.ConvTranspose1d(input,n_filters,kernel_size=9,stride=1,padding=4)
        init.kaiming_normal_(self.conv1.weight, mode='fan_in', nonlinearity='relu')
        init.zeros_(self.conv1.bias)
        init.kaiming_normal_(self.conv2.weight, mode='fan_in', nonlinearity='relu')
        init.zeros_(self.conv2.bias)


    def forward(self,x,skip_connection):
        x = self.upsample(x)
        merge = torch.cat([x,skip_connection],dim=2)
        x = self.conv1(merge)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        return x
    

class WaveLike(nn.Module):
    def __init__(self, n_filters=32, dropout_prob=0.3):
        super(WaveLike, self).__init__()
        self.n_filters = n_filters
        self.dropout_prob = dropout_prob
        self.encoder1 = EncoderBlock(1,n_filters,dropout_prob=0,max_pooling=True)
        self.encoder2 = EncoderBlock(n_filters,n_filters*2,dropout_prob=0,max_pooling=True)
        self.encoder3 = EncoderBlock(n_filters*2,n_filters*4,dropout_prob=0,max_pooling=True)
        self.encoder4 = EncoderBlock(n_filters*4,n_filters*8,dropout_prob=0.3,max_pooling=True)
        self.encoder5 = EncoderBlock(n_filters*8,n_filters*16,dropout_prob=0.3,max_pooling=False)
        self.decoder1 = DecoderBlock(n_filters*16,n_filters*8)
        self.decoder2 = DecoderBlock(n_filters*8,n_filters*4)
        self.decoder3 = DecoderBlock(n_filters*4,n_filters*2)
        self.decoder4 = DecoderBlock(n_filters*2,n_filters)

        self.conv = nn.Conv1d(n_filters,n_filters,kernel_size=9,padding="same")
        self.relu = nn.ReLU()
        self.last_conv = nn.Conv1d(n_filters,1,kernel_size=1,padding="same")

        init.kaiming_normal_(self.conv.weight, mode='fan_in', nonlinearity='relu')
        init.zeros_(self.conv.bias)


    def forward(self,x):
        x1= self.encoder1(x)
    
        x2 = self.encoder2(x1[0])

        x3 = self.encoder3(x2[0])
        
        x4 = self.encoder4(x3[0])
        
        x5 = self.encoder5(x4[0])
        
        z6 = self.decoder1(x5[0],x4[1])
        
        z7 = self.decoder2(z6,x3[1])
        
        z8 = self.decoder3(z7,x2[1])
        
        z9 = self.decoder4(z8,x1[1])
        
        conv = self.conv(z9)
        conv = self.relu(conv)
        conv = self.last_conv(conv)
        
        return conv

which gives me the following summary:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv1d-1             [-1, 32, 1024]             320
              ReLU-2             [-1, 32, 1024]               0
            Conv1d-3             [-1, 32, 1024]           9,248
              ReLU-4             [-1, 32, 1024]               0
       BatchNorm1d-5             [-1, 32, 1024]               0
         MaxPool1d-6              [-1, 32, 512]               0
      EncoderBlock-7  [[-1, 32, 512], [-1, 32, 512]]               0
            Conv1d-8              [-1, 64, 512]          18,496
              ReLU-9              [-1, 64, 512]               0
           Conv1d-10              [-1, 64, 512]          36,928
             ReLU-11              [-1, 64, 512]               0
      BatchNorm1d-12              [-1, 64, 512]               0
        MaxPool1d-13              [-1, 64, 256]               0
     EncoderBlock-14  [[-1, 64, 256], [-1, 64, 256]]               0
           Conv1d-15             [-1, 128, 256]          73,856
             ReLU-16             [-1, 128, 256]               0
           Conv1d-17             [-1, 128, 256]         147,584
             ReLU-18             [-1, 128, 256]               0
      BatchNorm1d-19             [-1, 128, 256]               0
        MaxPool1d-20             [-1, 128, 128]               0
     EncoderBlock-21  [[-1, 128, 128], [-1, 128, 128]]               0
           Conv1d-22             [-1, 256, 128]         295,168
             ReLU-23             [-1, 256, 128]               0
           Conv1d-24             [-1, 256, 128]         590,080
             ReLU-25             [-1, 256, 128]               0
      BatchNorm1d-26             [-1, 256, 128]               0
          Dropout-27             [-1, 256, 128]               0
        MaxPool1d-28              [-1, 256, 64]               0
     EncoderBlock-29  [[-1, 256, 64], [-1, 256, 64]]               0
           Conv1d-30              [-1, 512, 64]       1,180,160
             ReLU-31              [-1, 512, 64]               0
           Conv1d-32              [-1, 512, 64]       2,359,808
             ReLU-33              [-1, 512, 64]               0
      BatchNorm1d-34              [-1, 512, 64]               0
          Dropout-35              [-1, 512, 64]               0
     EncoderBlock-36  [[-1, 512, 64], [-1, 512, 64]]               0
  ConvTranspose1d-37              [-1, 256, 64]       1,179,904
           Conv1d-38             [-1, 256, 128]         590,080
             ReLU-39             [-1, 256, 128]               0
           Conv1d-40             [-1, 256, 128]         590,080
             ReLU-41             [-1, 256, 128]               0
     DecoderBlock-42             [-1, 256, 128]               0
  ConvTranspose1d-43             [-1, 128, 128]         295,040
           Conv1d-44             [-1, 128, 256]         147,584
             ReLU-45             [-1, 128, 256]               0
           Conv1d-46             [-1, 128, 256]         147,584
             ReLU-47             [-1, 128, 256]               0
     DecoderBlock-48             [-1, 128, 256]               0
  ConvTranspose1d-49              [-1, 64, 256]          73,792
           Conv1d-50              [-1, 64, 512]          36,928
             ReLU-51              [-1, 64, 512]               0
           Conv1d-52              [-1, 64, 512]          36,928
             ReLU-53              [-1, 64, 512]               0
     DecoderBlock-54              [-1, 64, 512]               0
  ConvTranspose1d-55              [-1, 32, 512]          18,464
           Conv1d-56             [-1, 32, 1024]           9,248
             ReLU-57             [-1, 32, 1024]               0
           Conv1d-58             [-1, 32, 1024]           9,248
             ReLU-59             [-1, 32, 1024]               0
     DecoderBlock-60             [-1, 32, 1024]               0
           Conv1d-61             [-1, 32, 1024]           9,248
             ReLU-62             [-1, 32, 1024]               0
           Conv1d-63              [-1, 1, 1024]              33
================================================================
Total params: 7,855,809
Trainable params: 7,855,809
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 16370.74
Params size (MB): 29.97
Estimated Total Size (MB): 16400.71
----------------------------------------------------------------

I know that the input size is different, but why is it that the batch normalization returns a different nr of parameters for the tensorflow implementation and it has more parameters as well as non trainable parameters?

I don’t know how you are printing the summary, but torchinfo.summary returns:

model = WaveLike()
summary(model)
# =================================================================
# Layer (type:depth-idx)                   Param #
# =================================================================
# WaveLike                                 --
# ├─EncoderBlock: 1-1                      --
# │    └─Conv1d: 2-1                       320
# │    └─Conv1d: 2-2                       9,248
# │    └─ReLU: 2-3                         --
# │    └─ReLU: 2-4                         --
# │    └─Dropout: 2-5                      --
# │    └─MaxPool1d: 2-6                    --
# │    └─BatchNorm1d: 2-7                  64
# ├─EncoderBlock: 1-2                      --
# │    └─Conv1d: 2-8                       18,496
# │    └─Conv1d: 2-9                       36,928
# │    └─ReLU: 2-10                        --
# │    └─ReLU: 2-11                        --
# │    └─Dropout: 2-12                     --
# │    └─MaxPool1d: 2-13                   --
# │    └─BatchNorm1d: 2-14                 128
# ├─EncoderBlock: 1-3                      --
# │    └─Conv1d: 2-15                      73,856
# │    └─Conv1d: 2-16                      147,584
# │    └─ReLU: 2-17                        --
# │    └─ReLU: 2-18                        --
# │    └─Dropout: 2-19                     --
# │    └─MaxPool1d: 2-20                   --
# │    └─BatchNorm1d: 2-21                 256
# ├─EncoderBlock: 1-4                      --
# │    └─Conv1d: 2-22                      295,168
# │    └─Conv1d: 2-23                      590,080
# │    └─ReLU: 2-24                        --
# │    └─ReLU: 2-25                        --
# │    └─Dropout: 2-26                     --
# │    └─MaxPool1d: 2-27                   --
# │    └─BatchNorm1d: 2-28                 512
# ├─EncoderBlock: 1-5                      --
# │    └─Conv1d: 2-29                      1,180,160
# │    └─Conv1d: 2-30                      2,359,808
# │    └─ReLU: 2-31                        --
# │    └─ReLU: 2-32                        --
# │    └─Dropout: 2-33                     --
# │    └─MaxPool1d: 2-34                   --
# │    └─BatchNorm1d: 2-35                 1,024
# ├─DecoderBlock: 1-6                      --
# │    └─Conv1d: 2-36                      590,080
# │    └─Conv1d: 2-37                      590,080
# │    └─ReLU: 2-38                        --
# │    └─ReLU: 2-39                        --
# │    └─ConvTranspose1d: 2-40             1,179,904
# ├─DecoderBlock: 1-7                      --
# │    └─Conv1d: 2-41                      147,584
# │    └─Conv1d: 2-42                      147,584
# │    └─ReLU: 2-43                        --
# │    └─ReLU: 2-44                        --
# │    └─ConvTranspose1d: 2-45             295,040
# ├─DecoderBlock: 1-8                      --
# │    └─Conv1d: 2-46                      36,928
# │    └─Conv1d: 2-47                      36,928
# │    └─ReLU: 2-48                        --
# │    └─ReLU: 2-49                        --
# │    └─ConvTranspose1d: 2-50             73,792
# ├─DecoderBlock: 1-9                      --
# │    └─Conv1d: 2-51                      9,248
# │    └─Conv1d: 2-52                      9,248
# │    └─ReLU: 2-53                        --
# │    └─ReLU: 2-54                        --
# │    └─ConvTranspose1d: 2-55             18,464
# ├─Conv1d: 1-10                           9,248
# ├─ReLU: 1-11                             --
# ├─Conv1d: 1-12                           33
# =================================================================
# Total params: 7,857,793
# Trainable params: 7,857,793
# Non-trainable params: 0
# =================================================================

which shows the batchnorm parameters properly.

Thanks for the answer!

I am using
import torchsummary which is something I installed in addition to pytorch
and then I am doing this to get the output

model = WaveLike().to('cuda')

torchsummary.summary(model, (1, 1024))

torchsummary is deprecated as stated on the GitHub repository:

Use the new and updated torchinfo.

as it didn’t get any updates for a few years.

1 Like