Converting LSTM model from Keras to PyTorch

I have a model developed in Keras that I wish to port over to PyTorch. The model is as follows:

s = SGD(lr=learning['rate'], decay=0, momentum=0.5, nesterov=True)
m = keras.models.Sequential([
        keras.layers.LSTM(256, input_shape=(70, 256), activation='tanh', return_sequences=True),
        keras.layers.LSTM(64, activation='tanh', return_sequences=True),
        keras.layers.LSTM(16, activation='tanh'),
        keras.layers.Dense(8, activation='softmax')
])
m.compile(loss='binary_crossentropy', optimizer=s, metrics=['accuracy'])

It is a simple 3-layer LSTM with an output layer for 8 classes. This is not a multi-label classification problem. I use a binary crossentropy loss function paired with an SGD optimizer. I have tried to reproduce this model in PyTorch as follows:

class LSTMModel(nn.Module):
    def __init__(self):
        super(LSTMModel, self).__init__()
        self.lstm_1 = nn.LSTM((70, 256), 256, batch_first=True)
        self.lstm_2 = nn.LSTM(256, 64, batch_first=True)
        self.lstm_3 = nn.LSTM(64, 16, batch_first=True)
        self.output = nn.Linear(16, 8)

    def forward(self, x):
        x = self.lstm_1(x)[0].tanh()
        x = self.lstm_2(x)[0].tanh()
        x = self.lstm_3(x)[0].tanh()[:, -1, :]
        return F.softmax(self.output(x), dim=1)

I still want to use the same optimizer and loss function:

m = LSTMModel()
s = SGD(m.parameters(), lr=learning['rate'], weight_decay=0, momentum=0.5, nesterov=True)
loss_fx = BCELoss()

The PyTorch model performs significantly worse than the Keras implementation. Training log for the PyTorch model:


 - 1565s - loss: 0.1637 - acc: 0.7035 - val_loss: 0.1451 - val_acc: 0.7441
 - 1672s - loss: 0.1472 - acc: 0.7288 - val_loss: 0.1437 - val_acc: 0.7430
 - 1851s - loss: 0.1467 - acc: 0.7288 - val_loss: 0.1432 - val_acc: 0.7430
 - 1612s - loss: 0.1462 - acc: 0.7288 - val_loss: 0.1422 - val_acc: 0.7430
 - 1650s - loss: 0.1409 - acc: 0.7288 - val_loss: 0.1326 - val_acc: 0.7430
 - 1460s - loss: 0.1307 - acc: 0.7288 - val_loss: 0.1313 - val_acc: 0.7430
 - 1455s - loss: 0.1286 - acc: 0.7288 - val_loss: 0.1301 - val_acc: 0.7430
 - 1458s - loss: 0.1280 - acc: 0.7288 - val_loss: 0.1291 - val_acc: 0.7430
 - 1458s - loss: 0.1278 - acc: 0.7288 - val_loss: 0.1289 - val_acc: 0.7430
 - 1452s - loss: 0.1279 - acc: 0.7288 - val_loss: 0.1287 - val_acc: 0.7430
 - 1439s - loss: 0.1279 - acc: 0.7288 - val_loss: 0.1286 - val_acc: 0.7430
 - 1473s - loss: 0.1280 - acc: 0.7288 - val_loss: 0.1286 - val_acc: 0.7430
 - 1601s - loss: 0.1280 - acc: 0.7288 - val_loss: 0.1285 - val_acc: 0.7430
 - 1442s - loss: 0.1280 - acc: 0.7288 - val_loss: 0.1285 - val_acc: 0.7430
 - 1487s - loss: 0.1280 - acc: 0.7288 - val_loss: 0.1285 - val_acc: 0.7430
 - 1444s - loss: 0.1280 - acc: 0.7288 - val_loss: 0.1286 - val_acc: 0.7430
 - 1455s - loss: 0.1280 - acc: 0.7288 - val_loss: 0.1286 - val_acc: 0.7430
 - 1436s - loss: 0.1280 - acc: 0.7288 - val_loss: 0.1286 - val_acc: 0.7430
 - 1448s - loss: 0.1280 - acc: 0.7288 - val_loss: 0.1286 - val_acc: 0.7430
 - 1441s - loss: 0.1280 - acc: 0.7288 - val_loss: 0.1286 - val_acc: 0.7430

Training log for the Keras model:


 - 1140s - loss: 0.1402 - acc: 0.9363 - val_loss: 0.1382 - val_acc: 0.9407
 - 1184s - loss: 0.1185 - acc: 0.9453 - val_loss: 0.1234 - val_acc: 0.9409
 - 1121s - loss: 0.1114 - acc: 0.9493 - val_loss: 0.1312 - val_acc: 0.9341
 - 1109s - loss: 0.1055 - acc: 0.9533 - val_loss: 0.1138 - val_acc: 0.9475
 - 1110s - loss: 0.1032 - acc: 0.9547 - val_loss: 0.1158 - val_acc: 0.9480
 - 1104s - loss: 0.1029 - acc: 0.9549 - val_loss: 0.1134 - val_acc: 0.9485
 - 1120s - loss: 0.1030 - acc: 0.9549 - val_loss: 0.1098 - val_acc: 0.9497
 - 1134s - loss: 0.1032 - acc: 0.9548 - val_loss: 0.1077 - val_acc: 0.9509
 - 1173s - loss: 0.1033 - acc: 0.9549 - val_loss: 0.1067 - val_acc: 0.9515
 - 1124s - loss: 0.1033 - acc: 0.9549 - val_loss: 0.1062 - val_acc: 0.9518
 - 1125s - loss: 0.1033 - acc: 0.9549 - val_loss: 0.1061 - val_acc: 0.9519
 - 1128s - loss: 0.1033 - acc: 0.9550 - val_loss: 0.1061 - val_acc: 0.9519
 - 1112s - loss: 0.1033 - acc: 0.9550 - val_loss: 0.1061 - val_acc: 0.9519
 - 1134s - loss: 0.1033 - acc: 0.9550 - val_loss: 0.1061 - val_acc: 0.9519
 - 1179s - loss: 0.1032 - acc: 0.9550 - val_loss: 0.1061 - val_acc: 0.9519
 - 1144s - loss: 0.1032 - acc: 0.9550 - val_loss: 0.1061 - val_acc: 0.9519
 - 1130s - loss: 0.1032 - acc: 0.9550 - val_loss: 0.1061 - val_acc: 0.9519
 - 1183s - loss: 0.1032 - acc: 0.9550 - val_loss: 0.1061 - val_acc: 0.9519
 - 1121s - loss: 0.1032 - acc: 0.9550 - val_loss: 0.1061 - val_acc: 0.9519
 - 1106s - loss: 0.1032 - acc: 0.9550 - val_loss: 0.1061 - val_acc: 0.9519
 - 1109s - loss: 0.1032 - acc: 0.9550 - val_loss: 0.1061 - val_acc: 0.9519

A bit of discrepancy between the two libraries is to be expected, but this difference appears bigger than that. Is something wrong with the current implementation?



Just by skimming through the code, it seems you are using the wrong activation function for the criterion.
nn.BCELoss expects a sigmoid function applied on the model outputs.

However, you mentioned that this is not a multi-label classification, so I assume it’s a multi-class classification?
If that’s the case, remove the softmax/sigmoid and use nn.CrossEntropyLoss.
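
As a minimal sketch of that suggestion (the tensor values and shapes below are made up for illustration, not taken from your data): nn.CrossEntropyLoss takes raw logits of shape (batch, classes) together with integer class indices and applies log-softmax internally, so a softmax left in forward would effectively be applied twice.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative values only: 4 samples, 8 classes, confident and correct logits.
labels = torch.tensor([0, 3, 7, 3])       # class indices, not one-hot vectors
logits = torch.zeros(4, 8)
logits[torch.arange(4), labels] = 5.0     # raw scores, no softmax applied

criterion = nn.CrossEntropyLoss()

# Intended usage: raw logits go straight into the criterion.
loss_logits = criterion(logits, labels)

# Equivalent explicit form: log-softmax followed by negative log-likelihood.
loss_manual = F.nll_loss(F.log_softmax(logits, dim=1), labels)

# What happens if forward() still ends with softmax: it is applied twice.
loss_double_softmax = criterion(F.softmax(logits, dim=1), labels)

print(loss_logits.item(), loss_manual.item(), loss_double_softmax.item())
```

The first two values agree, while the third is noticeably larger even though the predictions are identical, which is why the final softmax should be dropped when switching to this criterion.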


Thank you for your response. I changed the loss function to CrossEntropyLoss and removed the softmax activation in the forward method, but I am getting more or less the same results. The train/val accuracy is about the same as before (around 0.74), but the loss is now considerably higher, hovering around 0.55 as opposed to 0.12. I suspect this has something to do with the LSTM layers. Do I have to pass randomized hidden states to the layers, or does PyTorch deal with this automatically? Unfortunately, a lot of the implementation details are abstracted away when using Keras, so I am not entirely sure how it deals with this either.

As explained in the docs, nn.LSTM expects input, (hidden, cell) as its input. Since you are neither passing the hidden and cell states to the first layer nor using the output states, they default to zero tensors.
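
A small check of that default, as a sketch with made-up shapes (a batch of 2, the 70 x 256 input from the question, hidden size 64): calling the layer without states should give the same output as passing explicit zero tensors.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative layer and input; sizes mirror the second LSTM in the question.
lstm = nn.LSTM(input_size=256, hidden_size=64, batch_first=True)
x = torch.randn(2, 70, 256)               # (batch, time, features)

# No states passed: h_0 and c_0 default to zeros.
out_default, (h_n, c_n) = lstm(x)

# Explicit zero states, shaped (num_layers * num_directions, batch, hidden_size).
h0 = torch.zeros(1, 2, 64)
c0 = torch.zeros(1, 2, 64)
out_explicit, _ = lstm(x, (h0, c0))

print(torch.allclose(out_default, out_explicit))
# For a single-layer, unidirectional LSTM the last time step equals h_n.
print(torch.allclose(out_default[:, -1, :], h_n[0]))
```

So there is no need to create or randomize the states manually unless you want to carry them over between batches.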

Do you know how Keras handles this use case?

According to the source code:

    initial_state: List of initial state tensors to be passed to the first
      call of the cell (optional, defaults to `None` which causes creation
      of zero-filled initial state tensors).

So the hidden state is always zero-filled, just like in PyTorch. Am I possibly doing something wrong during the actual training phase? I have another model in PyTorch which I train the same way, so I don’t think this is the issue, but it may be the case here.

# function: train                                                                                                                                                                                                   
# arguments: tr_gen - data loader for training data                                                                                                                                                                 
#            cv_gen - data loader for cross_val data                                                                                                                                                                
#            model - the pytorch model                                                                                                                                                                              
#            opt - the optimizer for the model                                                                                                                                                                      
#            loss_fx - the loss function used to backprop                                                                                                                                                           
#            device - cpu/gpu device used for operations                                                                                                                                                            
# return: history - acc/loss dict for training and cv                                                                                                                                                               
# This method trains on the dataset for one epoch                                                                                                                                                                   
def train(tr_gen, cv_gen, model, opt, loss_fx, device):

    # set data loaders to a dictionary                                                                                                                                                                              
    data_loaders = {TRAIN: tr_gen, EVAL: cv_gen}

    # set losses for each phase to a dictionary                                                                                                                                                                     
    phase_losses = {TRAIN: 0.0, EVAL: 0.0}

    # set accuracies for each phase to a dictionary                                                                                                                                                                 
    phase_acc = {TRAIN: 0.0, EVAL: 0.0}

    # time an epoch                                                                                                                                                                                                 
    start_time = time.time()

    # for each phase (train/cv)                                                                                                                                                                                     
    for phase in [TRAIN, EVAL]:

        # set the appropriate mode for the model
        if phase == TRAIN:
            model.train()
        else:
            model.eval()

        # collect the loss average over batches
        running_loss = 0.0

        # collect the number of correct classifications                                                                                                                                                             
        correct = 0

        # total number of inputs to the model                                                                                                                                                                       
        total = 0

        # collect number of batches                                                                                                                                                                                 
        num_batches = len(data_loaders[phase])

        # get the batch size                                                                                                                                                                                        
        batch_size = data_loaders[phase].dataset.frames_per_minibatch

        # iterate over the dataset                                                                                                                                                                                  
        for index, batch in enumerate(data_loaders[phase]):

            # zero out the optimizer gradient
            opt.zero_grad()

            # collect the data and labels                                                                                                                                                                           
            data = batch[0].squeeze().to(device)
            labels = batch[1].squeeze().max(dim=1)[1].to(device)
            batch_size = data.shape[0]

            # feed the data to the model                                                                                                                                                                            
            output = model(data)

            # calculate the loss                                                                                                                                                                                    
            loss = loss_fx(output, labels)
            running_loss += loss.item()

            # increment the number of input to the model                                                                                                                                                            
            total += labels.shape[0]

            # count the times where the label index == output max index                                                                                                                                             
            correct += torch.sum(labels == torch.max(output, dim=1)[1].to(device))

            # if we are training                                                                                                                                                                                    
            if phase == TRAIN:

                # perform backprop and take a step for the optimizer
                loss.backward()
                opt.step()

        # calculate the average loss for the dataset                                                                                                                                                                
        phase_losses[phase] = running_loss / float(num_batches)

        # calculate the accuracy for all input                                                                                                                                                                      
        phase_acc[phase] = float(correct) / float(total)

    # collect the loss and accuracy                                                                                                                                                                                 
    history = {LOSS: phase_losses, ACC: phase_acc}

    # calculate the time it took to train                                                                                                                                                                           
    total_time = int(time.time() - start_time)

    # collect loss/acc for training and cv                                                                                                                                                                          
    loss = phase_losses[TRAIN]
    acc = phase_acc[TRAIN]
    val_loss = phase_losses[EVAL]
    val_acc = phase_acc[EVAL]

    # print informational message                                                                                                                                                                                   
    print(" - %ds - loss: %.4f - acc: %.4f - val_loss: %.4f - val_acc: %.4f"
          % (total_time, loss, acc, val_loss, val_acc))

    # return the history                                                                                                                                                                                            
    return history
# end of function

It should be noted that the way I train the model is by setting a learning rate initially and manually reducing it by using this function:

# function: update_lr                                                                                                                                                                                               
# arguments: optim - the optimizer
#            lr - the new learning rate                                                                                                                                                                             
# return: none                                                                                                                                                                                                      
# This method updates the learning rate                                                                                                                                                                             
def update_lr(optim, lr):

    # update the learning rate for each param group                                                                                                                                                                 
    for param_group in optim.param_groups:
        param_group['lr'] = lr
# end of function 

The learning rate is halved after each epoch in which the validation error is >= min_val_error, which is a manually set parameter. The same technique is used for the Keras model.
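
For illustration, the schedule could be sketched roughly as follows. halve_lr_if_needed is a hypothetical helper name, the threshold and error values are made up, and a SimpleNamespace stands in for the optimizer so the snippet runs without PyTorch; it assumes the rule is checked once after every epoch.

```python
from types import SimpleNamespace


def update_lr(optim, lr):
    # update the learning rate for each param group
    for param_group in optim.param_groups:
        param_group['lr'] = lr


def halve_lr_if_needed(optim, val_error, min_val_error):
    # hypothetical helper: halve the lr when val_error >= min_val_error
    # and return the learning rate that is in effect afterwards
    current_lr = optim.param_groups[0]['lr']
    if val_error >= min_val_error:
        current_lr *= 0.5
        update_lr(optim, current_lr)
    return current_lr


# stand-in for a torch optimizer: anything exposing .param_groups works here
fake_optim = SimpleNamespace(param_groups=[{'lr': 0.1}])

# made-up validation errors for three consecutive epochs
lrs = [halve_lr_if_needed(fake_optim, err, min_val_error=0.13)
       for err in (0.145, 0.132, 0.128)]
print(lrs)
```

With these made-up numbers the rate is halved after the first two epochs and left unchanged after the third.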

Thanks again for your help!