Converting LSTM model from Keras to PyTorch

I have a model developed in Keras that I wish to port to PyTorch. The model is as follows:

import keras
from keras.optimizers import SGD

s = SGD(lr=learning['rate'], decay=0, momentum=0.5, nesterov=True)
m = keras.models.Sequential([
    keras.layers.LSTM(256, input_shape=(70, 256), activation='tanh',
                      return_sequences=True),
    keras.layers.LSTM(64, activation='tanh', return_sequences=True),
    keras.layers.LSTM(16, activation='tanh'),
    keras.layers.Dense(8, activation='softmax')
])
m.compile(loss='binary_crossentropy', optimizer=s, metrics=['accuracy'])
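One detail of this compile call affects how the two acc columns compare. When the loss is binary_crossentropy, Keras resolves metrics=['accuracy'] to binary accuracy, which rounds each of the 8 outputs at 0.5 and averages over all of them. A categorical (arg-max) accuracy, which is what a typical PyTorch loop computes, only counts a sample as correct when its top class is the label. The sketch below uses PyTorch with made-up tensors; binary_accuracy and categorical_accuracy are hypothetical helpers written for illustration, not Keras APIs:

```python
import torch
import torch.nn.functional as F


def binary_accuracy(probs, one_hot, threshold=0.5):
    # Every one of the C outputs is scored as its own yes/no prediction,
    # then the result is averaged over all batch * C entries
    preds = (probs > threshold).float()
    return (preds == one_hot).float().mean().item()


def categorical_accuracy(probs, one_hot):
    # A sample counts as correct only if its arg-max class is the true class
    return (probs.argmax(dim=1) == one_hot.argmax(dim=1)).float().mean().item()


num_classes = 8
targets = torch.tensor([0, 1, 2, 0])
one_hot = F.one_hot(targets, num_classes).float()

# Hypothetical outputs of a weak model that always favours class 0
logits = torch.zeros(4, num_classes)
logits[:, 0] = 1.0
probs = torch.softmax(logits, dim=1)  # largest probability is ~0.28, all below 0.5

print(binary_accuracy(probs, one_hot))       # 0.875
print(categorical_accuracy(probs, one_hot))  # 0.5
```

With 8 one-hot outputs, predicting all zeros already scores 7/8 = 0.875 binary accuracy, so this may be worth ruling out before comparing the two logs directly.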

It is a simple 3-layer LSTM with an output layer for 8 classes. This is not a multi-label classification problem. I use a binary cross-entropy loss function paired with an SGD optimizer. I have tried to reproduce this model in PyTorch as follows:

import torch.nn as nn
import torch.nn.functional as F

class LSTMModel(nn.Module):
    def __init__(self):
        super(LSTMModel, self).__init__()
        # input_size is the number of features per time step (256); the
        # sequence length (70) comes from the input tensor, not the layer
        self.lstm_1 = nn.LSTM(256, 256, batch_first=True)
        self.lstm_2 = nn.LSTM(256, 64, batch_first=True)
        self.lstm_3 = nn.LSTM(64, 16, batch_first=True)
        self.output = nn.Linear(16, 8)

    def forward(self, x):
        x = self.lstm_1(x)[0].tanh()
        x = self.lstm_2(x)[0].tanh()
        x = self.lstm_3(x)[0].tanh()[:, -1, :]
        return F.softmax(self.output(x), dim=1)

I still want to use the same optimizer and loss function:

from torch.nn import BCELoss
from torch.optim import SGD

m = LSTMModel()
s = SGD(m.parameters(), lr=learning['rate'], weight_decay=0, momentum=0.5, nesterov=True)
loss_fx = BCELoss()
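Here, Keras's decay=0 is carried over as weight_decay=0, but the two arguments are different knobs. In standalone Keras 2.x SGD, decay is a time-based learning-rate decay applied per update, lr / (1 + decay * iterations), whereas PyTorch's weight_decay adds an L2 penalty to the gradients. With both set to 0 the mapping is harmless. If a nonzero Keras decay were ever used, one way to mimic it in PyTorch is a LambdaLR scheduler stepped once per batch. A sketch with a placeholder model and made-up values (base_lr and keras_decay are illustrative):

```python
import torch
from torch import nn
from torch.optim.lr_scheduler import LambdaLR

model = nn.Linear(4, 2)          # placeholder parameters, stands in for LSTMModel
base_lr, keras_decay = 0.1, 0.5  # hypothetical values for illustration

opt = torch.optim.SGD(model.parameters(), lr=base_lr, momentum=0.5, nesterov=True)
# Keras-style time-based decay: lr_t = lr / (1 + decay * t), t = number of updates
scheduler = LambdaLR(opt, lr_lambda=lambda t: 1.0 / (1.0 + keras_decay * t))

for _ in range(2):
    # ... forward pass and loss.backward() on one batch would go here ...
    opt.step()
    scheduler.step()  # stepped per batch, since Keras decays per update

print(opt.param_groups[0]["lr"])  # 0.1 / (1 + 0.5 * 2) = 0.05
```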

The PyTorch model performs significantly worse than the Keras implementation.

PyTorch:

 - 1565s - loss: 0.1637 - acc: 0.7035 - val_loss: 0.1451 - val_acc: 0.7441
 - 1672s - loss: 0.1472 - acc: 0.7288 - val_loss: 0.1437 - val_acc: 0.7430
 - 1851s - loss: 0.1467 - acc: 0.7288 - val_loss: 0.1432 - val_acc: 0.7430
 - 1612s - loss: 0.1462 - acc: 0.7288 - val_loss: 0.1422 - val_acc: 0.7430
 - 1650s - loss: 0.1409 - acc: 0.7288 - val_loss: 0.1326 - val_acc: 0.7430
 - 1460s - loss: 0.1307 - acc: 0.7288 - val_loss: 0.1313 - val_acc: 0.7430
 - 1455s - loss: 0.1286 - acc: 0.7288 - val_loss: 0.1301 - val_acc: 0.7430
 - 1458s - loss: 0.1280 - acc: 0.7288 - val_loss: 0.1291 - val_acc: 0.7430
 - 1458s - loss: 0.1278 - acc: 0.7288 - val_loss: 0.1289 - val_acc: 0.7430
 - 1452s - loss: 0.1279 - acc: 0.7288 - val_loss: 0.1287 - val_acc: 0.7430
 - 1439s - loss: 0.1279 - acc: 0.7288 - val_loss: 0.1286 - val_acc: 0.7430
 - 1473s - loss: 0.1280 - acc: 0.7288 - val_loss: 0.1286 - val_acc: 0.7430
 - 1601s - loss: 0.1280 - acc: 0.7288 - val_loss: 0.1285 - val_acc: 0.7430
 - 1442s - loss: 0.1280 - acc: 0.7288 - val_loss: 0.1285 - val_acc: 0.7430
 - 1487s - loss: 0.1280 - acc: 0.7288 - val_loss: 0.1285 - val_acc: 0.7430
 - 1444s - loss: 0.1280 - acc: 0.7288 - val_loss: 0.1286 - val_acc: 0.7430
 - 1455s - loss: 0.1280 - acc: 0.7288 - val_loss: 0.1286 - val_acc: 0.7430
 - 1436s - loss: 0.1280 - acc: 0.7288 - val_loss: 0.1286 - val_acc: 0.7430
 - 1448s - loss: 0.1280 - acc: 0.7288 - val_loss: 0.1286 - val_acc: 0.7430
 - 1441s - loss: 0.1280 - acc: 0.7288 - val_loss: 0.1286 - val_acc: 0.7430

Keras:

 - 1140s - loss: 0.1402 - acc: 0.9363 - val_loss: 0.1382 - val_acc: 0.9407
 - 1184s - loss: 0.1185 - acc: 0.9453 - val_loss: 0.1234 - val_acc: 0.9409
 - 1121s - loss: 0.1114 - acc: 0.9493 - val_loss: 0.1312 - val_acc: 0.9341
 - 1109s - loss: 0.1055 - acc: 0.9533 - val_loss: 0.1138 - val_acc: 0.9475
 - 1110s - loss: 0.1032 - acc: 0.9547 - val_loss: 0.1158 - val_acc: 0.9480
 - 1104s - loss: 0.1029 - acc: 0.9549 - val_loss: 0.1134 - val_acc: 0.9485
 - 1120s - loss: 0.1030 - acc: 0.9549 - val_loss: 0.1098 - val_acc: 0.9497
 - 1134s - loss: 0.1032 - acc: 0.9548 - val_loss: 0.1077 - val_acc: 0.9509
 - 1173s - loss: 0.1033 - acc: 0.9549 - val_loss: 0.1067 - val_acc: 0.9515
 - 1124s - loss: 0.1033 - acc: 0.9549 - val_loss: 0.1062 - val_acc: 0.9518
 - 1125s - loss: 0.1033 - acc: 0.9549 - val_loss: 0.1061 - val_acc: 0.9519
 - 1128s - loss: 0.1033 - acc: 0.9550 - val_loss: 0.1061 - val_acc: 0.9519
 - 1112s - loss: 0.1033 - acc: 0.9550 - val_loss: 0.1061 - val_acc: 0.9519
 - 1134s - loss: 0.1033 - acc: 0.9550 - val_loss: 0.1061 - val_acc: 0.9519
 - 1179s - loss: 0.1032 - acc: 0.9550 - val_loss: 0.1061 - val_acc: 0.9519
 - 1144s - loss: 0.1032 - acc: 0.9550 - val_loss: 0.1061 - val_acc: 0.9519
 - 1130s - loss: 0.1032 - acc: 0.9550 - val_loss: 0.1061 - val_acc: 0.9519
 - 1183s - loss: 0.1032 - acc: 0.9550 - val_loss: 0.1061 - val_acc: 0.9519
 - 1121s - loss: 0.1032 - acc: 0.9550 - val_loss: 0.1061 - val_acc: 0.9519
 - 1106s - loss: 0.1032 - acc: 0.9550 - val_loss: 0.1061 - val_acc: 0.9519
 - 1109s - loss: 0.1032 - acc: 0.9550 - val_loss: 0.1061 - val_acc: 0.9519

A bit of discrepancy between the two libraries is to be expected, but this difference appears bigger than that. Is something wrong with the current implementation?

Thanks,


Just by skimming through the code, it seems you are using the wrong activation function for the criterion.
nn.BCELoss expects probabilities, i.e. model outputs that have been passed through a sigmoid.

However, you mentioned that this is not a multi-label classification, so I assume it’s a multi-class classification?
If that’s the case, remove the softmax/sigmoid and use nn.CrossEntropyLoss.
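To make the difference between the two criteria concrete: nn.CrossEntropyLoss takes raw logits plus class indices and applies log-softmax internally, so a softmax inside forward() would effectively be applied twice. nn.BCELoss takes probabilities plus float targets of the same shape as the outputs, scoring each class as an independent yes/no decision. A small sketch with random, made-up logits:

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
logits = torch.randn(4, 8)            # raw model outputs: (batch, num_classes)
targets = torch.tensor([3, 0, 7, 1])  # class indices, not one-hot

# CrossEntropyLoss applies log-softmax internally ...
ce = nn.CrossEntropyLoss()(logits, targets)
# ... so it equals log_softmax followed by NLLLoss
nll = nn.NLLLoss()(F.log_softmax(logits, dim=1), targets)
print(torch.allclose(ce, nll))  # True

# BCELoss expects probabilities in [0, 1] and float targets shaped like the outputs
one_hot = F.one_hot(targets, num_classes=8).float()
bce = nn.BCELoss()(torch.sigmoid(logits), one_hot)
# BCEWithLogitsLoss fuses the sigmoid and BCELoss in a numerically stabler way
bce_logits = nn.BCEWithLogitsLoss()(logits, one_hot)
print(torch.allclose(bce, bce_logits))  # True
```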


Thank you for your response. I changed the loss function to CrossEntropyLoss and removed the softmax activation in the forward method, but I am getting more or less the same results. Train/val accuracy is around 0.74, as before, but the loss is now considerably higher, hovering around 0.55 as opposed to 0.12. I suspect this has something to do with the LSTM layers. Do I have to pass randomized hidden states to the layers, or does PyTorch handle this automatically? Unfortunately, a lot of the implementation details are abstracted away in Keras, so I am not entirely sure how it handles this either.
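On the LSTM layers themselves: in Keras, activation='tanh' sets the activation used inside each LSTM cell, and nn.LSTM already applies tanh internally in the same places, so calling .tanh() on its outputs squashes them a second time. (The gate activation can also differ by version: recurrent_activation defaults to sigmoid in recent Keras, as in nn.LSTM, but older Keras 2.x releases defaulted to hard_sigmoid.) A minimal sketch of the forward pass without the extra .tanh() calls, returning raw logits for nn.CrossEntropyLoss; the class name KerasLikeLSTM and its constructor arguments are illustrative, not from the original code:

```python
import torch
from torch import nn


class KerasLikeLSTM(nn.Module):
    # nn.LSTM already uses tanh inside the cell, so its outputs are not passed
    # through .tanh() again; forward() returns logits (no softmax).
    def __init__(self, input_size=256, num_classes=8):
        super().__init__()
        self.lstm_1 = nn.LSTM(input_size, 256, batch_first=True)
        self.lstm_2 = nn.LSTM(256, 64, batch_first=True)
        self.lstm_3 = nn.LSTM(64, 16, batch_first=True)
        self.output = nn.Linear(16, num_classes)

    def forward(self, x):
        # x: (batch, seq_len, input_size), e.g. (batch, 70, 256)
        x, _ = self.lstm_1(x)
        x, _ = self.lstm_2(x)
        x, _ = self.lstm_3(x)
        # Keep only the last time step, like return_sequences=False in Keras
        return self.output(x[:, -1, :])


model = KerasLikeLSTM()
logits = model(torch.randn(4, 70, 256))
print(logits.shape)  # torch.Size([4, 8])
```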

As explained in the docs, nn.LSTM takes input, (h_0, c_0) as its arguments. Since you are neither passing the hidden and cell states to the first layer nor reusing the output states in the following layers, they are initialized to zero tensors.
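This default is easy to check: calling an nn.LSTM without (h_0, c_0) gives the same output as passing zero tensors explicitly. Note that h_0 and c_0 are shaped (num_layers * num_directions, batch, hidden_size) even with batch_first=True. A small sketch with random inputs of the shapes used in this thread:

```python
import torch
from torch import nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=256, hidden_size=64, batch_first=True)
x = torch.randn(4, 70, 256)  # (batch, seq_len, features)

# No (h_0, c_0) passed: PyTorch uses zero tensors for both
out_default, (h_n, c_n) = lstm(x)

# Explicit zero states; batch_first does not change their layout
h0 = torch.zeros(1, 4, 64)
c0 = torch.zeros(1, 4, 64)
out_explicit, _ = lstm(x, (h0, c0))

print(torch.allclose(out_default, out_explicit))  # True
```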

Do you know how Keras handles this use case?

According to the source code:

    initial_state: List of initial state tensors to be passed to the first
      call of the cell (optional, defaults to `None` which causes creation
      of zero-filled initial state tensors).

So, unless initial_state is passed, the states are zero-filled, just like in PyTorch. Am I possibly doing something wrong during the actual training phase? I have another PyTorch model that I train the same way, so I don’t think this is the issue, but it may be the case here.

import time

import torch

# TRAIN, EVAL, LOSS and ACC are module-level key constants defined elsewhere
# in this code.

# function: train
#
# arguments: tr_gen - data loader for training data
#            cv_gen - data loader for cross_val data
#            model - the pytorch model
#            opt - the optimizer for the model
#            loss_fx - the loss function used to backprop
#            device - cpu/gpu device used for operations
#
# return: history - acc/loss dict for training and cv
#
# This method trains on the dataset for one epoch
#
def train(tr_gen, cv_gen, model, opt, loss_fx, device):

    # set data loaders to a dictionary
    #
    data_loaders = {TRAIN: tr_gen, EVAL: cv_gen}

    # set losses for each phase to a dictionary
    #
    phase_losses = {TRAIN: 0.0, EVAL: 0.0}

    # set accuracies for each phase to a dictionary
    #
    phase_acc = {TRAIN: 0.0, EVAL: 0.0}

    # time an epoch
    #
    start_time = time.time()

    # for each phase (train/cv)
    #
    for phase in [TRAIN, EVAL]:

        # set the appropriate mode for the model
        #
        if phase == TRAIN:
            model.train(True)
        else:
            model.eval()

        # collect the loss average over batches
        #
        running_loss = 0.0

        # collect the number of correct classifications
        #
        correct = 0

        # total number of inputs to the model
        #
        total = 0

        # collect number of batches
        #
        num_batches = len(data_loaders[phase])

        # get the batch size
        #
        batch_size = data_loaders[phase].dataset.frames_per_minibatch

        # iterate over the dataset, tracking gradients only while training
        #
        with torch.set_grad_enabled(phase == TRAIN):
            for index, batch in enumerate(data_loaders[phase]):

                # zero out the optimizer gradient
                #
                opt.zero_grad()

                # collect the data and labels (one-hot labels -> class indices)
                #
                data = batch[0].squeeze().to(device)
                labels = batch[1].squeeze().max(dim=1)[1].to(device)
                batch_size = data.shape[0]

                # feed the data to the model
                #
                output = model(data)

                # calculate the loss
                #
                loss = loss_fx(output, labels)
                running_loss += loss.item()

                # increment the number of input to the model
                #
                total += labels.shape[0]

                # count the times where the label index == output max index
                #
                correct += torch.sum(labels == torch.max(output.data, dim=1)[1]).item()

                # if we are training
                #
                if phase == TRAIN:

                    # perform backprop and take a step for the optimizer
                    #
                    loss.backward()
                    opt.step()

        # calculate the average loss for the dataset
        #
        phase_losses[phase] = running_loss / float(num_batches)

        # calculate the accuracy for all input
        #
        phase_acc[phase] = float(correct) / float(total)

    # collect the loss and accuracy
    #
    history = {LOSS: phase_losses, ACC: phase_acc}

    # calculate the time it took to train
    #
    total_time = int(time.time() - start_time)

    # collect loss/acc for training and cv
    #
    loss = phase_losses[TRAIN]
    acc = phase_acc[TRAIN]
    val_loss = phase_losses[EVAL]
    val_acc = phase_acc[EVAL]

    # print informational message
    #
    print(" - %ds - loss: %.4f - acc: %.4f - val_loss: %.4f - val_acc: %.4f"
          % (total_time, loss, acc, val_loss, val_acc))

    # return the history
    #
    return history
#
# end of function

Note that I train the model by setting an initial learning rate and reducing it manually with this function:

# function: update_lr
#
# arguments: optim - the optimizer (SGD here)
#            lr - the new learning rate
#
# return: none
#
# This method updates the learning rate
#
def update_lr(optim, lr):

    # update the learning rate for each param group
    #
    for param_group in optim.param_groups:
        param_group['lr'] = lr
#
# end of function

The learning rate is halved after each epoch whenever the validation error after that epoch is >= min_val_error, which is just a custom parameter. The same technique is used for the Keras model.
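If min_val_error is the best validation error seen so far (an assumption; the post does not say), PyTorch's built-in torch.optim.lr_scheduler.ReduceLROnPlateau can express the same rule without a hand-written update_lr. With factor=0.5 and patience=0 it halves the learning rate after any epoch whose monitored value does not improve on the best so far. A minimal sketch with a placeholder model and made-up validation losses:

```python
import torch
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = nn.Linear(4, 2)  # placeholder model, stands in for LSTMModel
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.5, nesterov=True)

# Halve the learning rate as soon as the monitored value fails to improve
# on the best value seen so far (patience=0, no tolerance band).
scheduler = ReduceLROnPlateau(opt, mode="min", factor=0.5, patience=0, threshold=0.0)

lrs = []
for val_loss in [1.0, 1.0, 0.9]:  # hypothetical per-epoch validation losses
    # ... one epoch of training and validation would run here ...
    scheduler.step(val_loss)
    lrs.append(opt.param_groups[0]["lr"])

print(lrs)  # [0.1, 0.05, 0.05]
```

In the train function above, the value passed to scheduler.step would be history[LOSS][EVAL]. Keras also ships a ReduceLROnPlateau callback, which could keep the two setups aligned if the manual schedule is ever replaced.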

Thanks again for your help!