Can the same NN converge in Keras but not in PyTorch?

I am trying to implement a CNN for regression on images in PyTorch. I have a working model already implemented in Keras and I would like to translate it to PyTorch, but I am facing many issues.

Essentially, the model converges in Keras, whereas in PyTorch it doesn't: the training loss stays constant and the trained model outputs the same value for every image. I initialized all layers as in the working Keras model, added L2 regularization as in Keras, and implemented the same learning rate decay. Everything looks exactly the same, but in PyTorch my model doesn't converge.

Is it possible that, with everything initialized the same way, a model does not converge in PyTorch while converging in Keras? If so, what would you suggest I do in my specific case (constant training loss after one epoch and constant predictions)?

I have already tried clipping gradients and changing the learning rate: at first I used Adam with lr = 0.001, then I tried 0.1 and 0.0001, always emulating Keras' time-based decay.
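For reference, the clipping was done in the usual place, between the backward pass and the optimizer step, along these lines (the max_norm value here is only illustrative):

loss.backward()
# Clip the global gradient norm before the update step
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # illustrative value
optimizer.step()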

Thank you in advance

Could you post some information about:

  • model architecture
  • input, output, target shapes
  • loss function used
  • pre-processing steps?

This would make it a bit easier to spot any potential bugs in your code.

The task consists of regressing steering angles from this Udacity dataset.

Preprocessing:

  • convert images to grayscale in the [0,1] interval and crop them to 200x200 with the following procedure:

half_the_width = int(img.shape[1] / 2)
img = img[img.shape[0] - crop_height : img.shape[0],
          half_the_width - int(crop_width / 2) : half_the_width + int(crop_width / 2)]

Thus I have input images of size 200x200 and a target of dimension 1 (the steering angle).

The loss used is MSELoss.
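One thing I made sure to check (a minimal sketch, not my exact loader code): F.mse_loss silently broadcasts mismatched shapes, so an (N, 1) output against an (N,) target becomes an N×N comparison and produces a misleading loss:

import torch
import torch.nn.functional as F

outputs = torch.randn(8, 1)  # model output: (batch, 1)
targets = torch.randn(8)     # targets often load as (batch,)

# F.mse_loss(outputs, targets) would broadcast to an (8, 8) comparison;
# match the shapes explicitly before computing the loss
loss = F.mse_loss(outputs, targets.unsqueeze(1))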

This is the Keras model I want to emulate; it uses Adam with a learning rate decay of 1e-5 and is taken from the DroNet repo:

def resnet8(img_width, img_height, img_channels, output_dim):
    """
    Define model architecture.
    
    # Arguments
       img_width: Target image width.
       img_height: Target image height.
       img_channels: Target image channels.
       output_dim: Dimension of model output.
       
    # Returns
       model: A Model instance.
    """

    # Input
    img_input = Input(shape=(img_height, img_width, img_channels))

    x1 = Conv2D(32, (5, 5), strides=[2,2], padding='same')(img_input)
    x1 = MaxPooling2D(pool_size=(3, 3), strides=[2,2])(x1)

    # First residual block
    x2 = keras.layers.normalization.BatchNormalization()(x1)
    x2 = Activation('relu')(x2)
    x2 = Conv2D(32, (3, 3), strides=[2,2], padding='same',
                kernel_initializer="he_normal",
                kernel_regularizer=regularizers.l2(1e-4))(x2)

    x2 = keras.layers.normalization.BatchNormalization()(x2)
    x2 = Activation('relu')(x2)
    x2 = Conv2D(32, (3, 3), padding='same',
                kernel_initializer="he_normal",
                kernel_regularizer=regularizers.l2(1e-4))(x2)

    x1 = Conv2D(32, (1, 1), strides=[2,2], padding='same')(x1)
    x3 = add([x1, x2])

    # Second residual block
    x4 = keras.layers.normalization.BatchNormalization()(x3)
    x4 = Activation('relu')(x4)
    x4 = Conv2D(64, (3, 3), strides=[2,2], padding='same',
                kernel_initializer="he_normal",
                kernel_regularizer=regularizers.l2(1e-4))(x4)

    x4 = keras.layers.normalization.BatchNormalization()(x4)
    x4 = Activation('relu')(x4)
    x4 = Conv2D(64, (3, 3), padding='same',
                kernel_initializer="he_normal",
                kernel_regularizer=regularizers.l2(1e-4))(x4)

    x3 = Conv2D(64, (1, 1), strides=[2,2], padding='same')(x3)
    x5 = add([x3, x4])

    # Third residual block
    x6 = keras.layers.normalization.BatchNormalization()(x5)
    x6 = Activation('relu')(x6)
    x6 = Conv2D(128, (3, 3), strides=[2,2], padding='same',
                kernel_initializer="he_normal",
                kernel_regularizer=regularizers.l2(1e-4))(x6)

    x6 = keras.layers.normalization.BatchNormalization()(x6)
    x6 = Activation('relu')(x6)
    x6 = Conv2D(128, (3, 3), padding='same',
                kernel_initializer="he_normal",
                kernel_regularizer=regularizers.l2(1e-4))(x6)

    x5 = Conv2D(128, (1, 1), strides=[2,2], padding='same')(x5)
    x7 = add([x5, x6])

    x = Flatten()(x7)
    x = Activation('relu')(x)
    x = Dropout(0.5)(x)

    # Steering channel
    steer = Dense(output_dim)(x)

    # Collision channel
    coll = Dense(output_dim)(x)
    coll = Activation('sigmoid')(coll)

    # Define steering-collision model
    model = Model(inputs=[img_input], outputs=[steer, coll])
    print(model.summary())

    return model

In PyTorch I am implementing only the steering angle prediction; the paper mentions that steering angle and collision predictions are uncorrelated.

This is my PyTorch implementation:
MODEL


import torch
import torch.nn as nn


def init_kernel(m):
    if isinstance(m, nn.Conv2d):
        # Initialize kernels of Conv2d layers as Kaiming normal
        nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
        # Initialize biases of Conv2d layers to 0
        nn.init.zeros_(m.bias)


class resnet8(nn.Module):

    def __init__(self, img_channels, in_height, in_width, output_dim):
        super(resnet8, self).__init__()
        
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=img_channels,out_channels=32, 
                      kernel_size=[5,5], stride=[2,2], padding=[5//2,5//2]),
            nn.MaxPool2d(kernel_size=[3,3], stride=[2,2]))
        
        self.residual_block_1a = nn.Sequential(
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(in_channels=32,out_channels=32, kernel_size=[3,3], 
                      stride=[2,2], padding=[3//2,3//2]), 
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(in_channels=32,out_channels=32, kernel_size=[3,3], 
                      padding=[3//2,3//2]))
        
        self.parallel_conv_1 = nn.Conv2d(in_channels=32,out_channels=32, 
                                         kernel_size=[1,1], stride=[2,2], 
                                         padding=[1//2,1//2])
        
        self.residual_block_2a = nn.Sequential(
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(in_channels=32,out_channels=64, kernel_size=[3,3], 
                      stride=[2,2], padding=[3//2,3//2]), 
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64,out_channels=64, kernel_size=[3,3], 
                      padding=[3//2,3//2]))

        self.parallel_conv_2 = nn.Conv2d(in_channels=32,out_channels=64, 
                                         kernel_size=[1,1], stride=[2,2], 
                                         padding=[1//2,1//2])
        
        self.residual_block_3a = nn.Sequential(
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64,out_channels=128, kernel_size=[3,3], 
                      stride=[2,2], padding=[3//2,3//2]), 
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(in_channels=128,out_channels=128, kernel_size=[3,3], 
                      padding=[3//2,3//2]))

        self.parallel_conv_3 = nn.Conv2d(in_channels=64,out_channels=128, 
                                         kernel_size=[1,1], stride=[2,2], 
                                         padding=[1//2,1//2])
        
        self.output_dim = output_dim

        self.last_block = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(0.5),  # plain Dropout on the flattened features, as in Keras
            nn.Linear(6272, self.output_dim))
        
        # Initialize layers exactly as in Keras
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
                nn.init.zeros_(m.bias)    
            elif isinstance(m, nn.BatchNorm2d):
                # Initialize BatchNorm weights to 1 and biases to 0 (Keras default)
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
                
        self.residual_block_1a.apply(init_kernel)
        self.residual_block_2a.apply(init_kernel)
        self.residual_block_3a.apply(init_kernel)

    def forward(self, x):
        x1 = self.layer1(x)
        # First residual block
        x2 = self.residual_block_1a(x1)
        x1 = self.parallel_conv_1(x1)
        x3 = x1.add(x2)
        # Second residual block
        x4 = self.residual_block_2a(x3)
        x3 = self.parallel_conv_2(x3)
        x5 = x3.add(x4)
        # Third residual block
        x6 = self.residual_block_3a(x5)
        x5 = self.parallel_conv_3(x5)
        x7 = x5.add(x6)
        
        out = x7.view(x7.size(0), -1) # Flatten
        out = self.last_block(out)
        
        return out
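
A quick sanity check I run on the untrained network (a sketch; the 200x200 grayscale input matches my preprocessing, and the flattened feature size works out to 128*7*7 = 6272):

model = resnet8(img_channels=1, in_height=200, in_width=200, output_dim=1)
model.eval()
with torch.no_grad():
    dummy = torch.randn(4, 1, 200, 200)  # fake batch of grayscale crops
    out = model(dummy)
print(out.shape)  # expected: torch.Size([4, 1])
print(out.std())  # should be clearly non-zero before training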


TRAINING LOOP


import numpy as np
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR

# EarlyStopping, FLAGS and device come from my own utility/setup code

def compute_l2_reg(model, model_name):
    # Build an L2 penalty only for the weights (not the biases), and only
    # for the conv layers inside the residual blocks, as in the Keras
    # kernel_regularizer
    lambda_ = FLAGS.weight_decay
    params_dict = dict(model.named_parameters())
    l2_reg = []
    if model_name == 'resnet8':
        for key, value in params_dict.items():
            if ((key[-8:] == '2.weight' or key[-8:] == '5.weight')
                    and key[0:8] == 'residual'):
                # Keras regularizers.l2 penalizes the squared L2 norm
                l2_reg += [lambda_ * value.pow(2).sum()]

    l2_reg = sum(l2_reg)
    return l2_reg
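
An alternative I considered (shown only as a sketch, not what the training runs used) is to hand the regularized weights to Adam as a separate parameter group. Note that Keras' l2(1e-4) adds 1e-4 * sum(w^2) to the loss, whose gradient is 2e-4 * w, while torch.optim.Adam's weight_decay adds weight_decay * w directly to the gradient, so matching it requires weight_decay = 2e-4:

# Split parameters into regularized / unregularized groups
decay_params, no_decay_params = [], []
for name, p in model.named_parameters():
    if (name.startswith('residual')
            and (name.endswith('2.weight') or name.endswith('5.weight'))):
        decay_params.append(p)
    else:
        no_decay_params.append(p)

optimizer = torch.optim.Adam([
    {'params': decay_params, 'weight_decay': 2e-4},  # matches Keras l2(1e-4)
    {'params': no_decay_params, 'weight_decay': 0.0},
], lr=learning_rate)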
    
def train_model(model, num_epochs, learning_rate, train_loader, valid_loader, 
                patience, model_name):
    
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    
    # To track the training loss as the model trains
    train_losses = []
    # To track the validation loss as the model trains
    valid_losses = []
    # To track the average training loss per epoch as the model trains
    avg_train_losses = []
    # To track the average validation loss per epoch as the model trains
    avg_valid_losses = [] 
    
    # Initialize the early_stopping object
    early_stopping = EarlyStopping(patience=patience, verbose=True)
    
    # Emulate Keras time-based decay: lr = lr0 / (1 + decay * iteration)
    decay = FLAGS.decay # default 1e-5
    fcn = lambda step: 1. / (1. + decay * step)
    scheduler = LambdaLR(optimizer, lr_lambda=fcn)
    
    for epoch in range(1, num_epochs+1):
        ###################
        # TRAIN the model #
        ###################
        model.train() # prep model for training
        for batch, (images, targets) in enumerate(train_loader, 1):
            # Load images and targets to device
            images = images.to(device)
            targets = targets.to(device)
            # Clear gradients
            optimizer.zero_grad()
            # Forward pass
            outputs = model(images)
            # Calculate loss
            l2_reg = compute_l2_reg(model,model_name)
            loss = F.mse_loss(outputs, targets) + l2_reg
            # Backward pass
            loss.backward()
            # Update weights
            optimizer.step()
            # Decay Learning Rate     
            scheduler.step()
            # Record training loss
            train_losses.append(loss.item())
            
        ######################    
        # VALIDATE the model #
        ######################
        model.eval() # prep model for evaluation
        with torch.no_grad(): # no gradients needed during validation
            for images, targets in valid_loader:
                images = images.to(device)
                targets = targets.to(device)
                # Forward pass
                outputs = model(images)
                # Calculate loss
                loss = F.mse_loss(outputs, targets)
                # Record validation loss
                valid_losses.append(loss.item())

        # Print training/validation statistics 
        # Calculate average loss over an epoch
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)
        
        epoch_len = len(str(num_epochs))
        
        print_msg = (f'[{epoch:>{epoch_len}}/{num_epochs:>{epoch_len}}] ' +
                     f'train_loss: {train_loss:.5f} ' +
                     f'valid_loss: {valid_loss:.5f}')
        
        print(print_msg)
        
        # Clear lists to track next epoch
        train_losses = []
        valid_losses = []
        
        # Early stopping needs the validation loss to check whether it has
        # decreased; if it has, it will checkpoint the current model
        early_stopping(valid_loss, model)
        
        if early_stopping.early_stop:
            print("Early stopping")
            break    
    
    # Load the last checkpoint with the best model 
    # (returned by early_stopping call)
    model.load_state_dict(torch.load('checkpoint.pt'))
    print('Training completed and model saved.')

    return  model, avg_train_losses, avg_valid_losses
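
To double-check that the schedule really emulates Keras' time-based decay (lr = lr0 / (1 + decay * iteration), stepped once per batch), a small standalone sketch:

import torch
from torch.optim.lr_scheduler import LambdaLR

decay = 1e-5
opt = torch.optim.Adam([torch.zeros(1, requires_grad=True)], lr=0.001)
sched = LambdaLR(opt, lr_lambda=lambda step: 1. / (1. + decay * step))

for _ in range(3):
    opt.step()
    sched.step()

# After N scheduler steps the lr should equal 0.001 / (1 + 1e-5 * N)
print(opt.param_groups[0]['lr'])  # 0.001 / (1 + 3e-5)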

where weight_decay = 1e-4 and EarlyStopping is a utility that monitors the validation loss and checkpoints the model with the lowest validation error.
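For context, EarlyStopping behaves roughly like this minimal sketch (not the exact class I use):

class EarlyStopping:
    # Stop when the validation loss hasn't improved for `patience` epochs,
    # checkpointing the best model seen so far
    def __init__(self, patience=10, verbose=False):
        self.patience = patience
        self.verbose = verbose
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False

    def __call__(self, valid_loss, model):
        if valid_loss < self.best_loss:
            self.best_loss = valid_loss
            self.counter = 0
            torch.save(model.state_dict(), 'checkpoint.pt')
            if self.verbose:
                print(f'Validation loss improved to {valid_loss:.5f}, checkpoint saved.')
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True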

I tried to emulate in PyTorch the same manual and default initializations as in Keras; you can see this in my model code.

I also tried to reproduce exactly the same regularization as the Keras kernel regularizer.

@ptrblck I posted some code; I hope it helps the discussion.

As mentioned before, I have a constant training loss, constant validation loss, and constant output prediction (it's a regression problem). Maybe it's a dying ReLU problem, but in Keras this doesn't happen.

After some further debugging I found out that the last ReLU of my trained model is dead. This happens only to the last one; all of the ReLUs before it seem to be working. How would you suggest solving it? I was thinking of trying a leaky ReLU, initializing the conv biases to 0.1, or directly removing the last ReLU.
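For reference, this is roughly how I checked which ReLUs are dead (a sketch using forward hooks on the model defined above):

# Measure the fraction of zero activations after each ReLU
activation_stats = {}

def make_hook(name):
    def hook(module, inputs, output):
        activation_stats[name] = (output == 0).float().mean().item()
    return hook

for name, module in model.named_modules():
    if isinstance(module, nn.ReLU):
        module.register_forward_hook(make_hook(name))

images, _ = next(iter(valid_loader))  # one validation batch
model.eval()
with torch.no_grad():
    model(images.to(device))

for name, frac_zero in activation_stats.items():
    print(f'{name}: {frac_zero:.1%} zero activations')
# ~100% zeros at last_block.0 would confirm the dead ReLU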

Hi, how did you deal with this problem? I have run into the same issue. Thanks for your reply.