Moving a model from Keras to PyTorch – large performance difference!?

Hi folks,
I'm trying to "translate" a Keras NN to PyTorch. Unfortunately, there is a performance gap of around 5% (PyTorch is worse). After weeks of work and browsing through older posts, I could not solve the problem. I would be super thankful for any help!

Keras Model:

    from keras.layers import (Input, Embedding, TimeDistributed, Dense, LSTM,
                              Bidirectional, Conv1D, GlobalAvgPool1D, Add)
    from keras.models import Model

    word_input = Input(shape=(maxlen,), dtype='int32', name='word_input')
    type_input = Input(shape=(2,), dtype='int32', name='type_input')
    word_input_dim = word_embedding_matrix.shape[0]   
    word_output_dim = word_embedding_matrix.shape[1] 
    
    word_embeddings_layer = Embedding(input_dim=word_input_dim, output_dim=word_output_dim, weights=[word_embedding_matrix], trainable=False, mask_zero=True)
    word_embeddings = word_embeddings_layer(word_input)

    type_embeddings_layer = Embedding(input_dim=num_types, output_dim=10)
    type_embeddings = type_embeddings_layer(type_input)

    time_distributed_layer = TimeDistributed(Dense(size_factor))  
    time_distributed = time_distributed_layer(word_embeddings)     

    lstm_layer = Bidirectional(LSTM(units=size_factor))   
    hidden = lstm_layer(time_distributed)

    type_conv_layer = Conv1D(filters=size_factor*2, kernel_size=2)  
    type_conv = type_conv_layer(type_embeddings)  
    type_pooled = GlobalAvgPool1D()(type_conv)                 

    global_repr = Add()([hidden, type_pooled])

    prediction_layer = Dense(num_classes, activation='softmax')   
    prediction = prediction_layer(global_repr)

    model = Model(inputs=[word_input, type_input], outputs=prediction)
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

    model.fit([train_word_matrix, train_type_matrix], train_labels,
              validation_data=([dev_word_matrix, dev_type_matrix], dev_labels), class_weight=weights, epochs=epochs)
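
For context: both branches merged by `Add()` must have the same width. `Bidirectional` concatenates the two directions by default (merge_mode='concat'), giving `(batch, 2 * size_factor)`, which is why the `Conv1D` branch uses `size_factor * 2` filters. A quick way to confirm this:

    model.summary()  # both inputs to the Add() layer should show shape (None, size_factor * 2)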

PyTorch Model:

    import torch
    import torch.nn as nn
    from torch.nn.utils.rnn import pack_padded_sequence

    class NeuralNet(nn.Module):
        def __init__(self, word_input_dim, word_output_dim, word_embedding_matrix,
                     num_types, size_factor, num_classes):
            super(NeuralNet, self).__init__()

            self.word_input_dim = word_input_dim
            self.word_output_dim = word_output_dim
            self.word_embedding_matrix = word_embedding_matrix
            self.num_types = num_types
            self.size_factor = size_factor
            self.num_classes = num_classes

            # frozen pre-trained embeddings; index 0 is the pad token (Keras: mask_zero=True)
            self.word_embedding = nn.Embedding(word_input_dim, word_output_dim, padding_idx=0)
            self.word_embedding.weight = nn.Parameter(torch.tensor(word_embedding_matrix, dtype=torch.float32))
            self.word_embedding.weight.requires_grad = False

            self.type_embedding = nn.Embedding(num_types, 10)

            # Keras TimeDistributed(Dense(...)); nn.Linear is applied per timestep anyway
            self.td_dense = nn.Linear(word_output_dim, size_factor)
            self.biLSTM = nn.LSTM(size_factor, size_factor, bidirectional=True, batch_first=True)

            self.conv1D = nn.Conv1d(10, size_factor * 2, kernel_size=2)
            self.pooling = nn.AdaptiveAvgPool1d(1)

            self.predict = nn.Linear(size_factor * 2, num_classes)


        def forward(self, x, x_rev, y, lens):
            # x_rev is unused; it is only kept because the DataLoader yields it
            word_embeddings = self.word_embedding(x)

            td_dense = self.td_dense(word_embeddings)
            # packing replaces Keras's mask_zero: the LSTM skips the padded timesteps
            x_packed = pack_padded_sequence(td_dense, lens, batch_first=True, enforce_sorted=False)

            biLSTM, (h_n, c_n) = self.biLSTM(x_packed)

            forward = h_n[0]   # final hidden state of the forward LSTM
            backward = h_n[1]  # final hidden state of the backward LSTM
            concat = torch.cat((forward, backward), dim=1)  # (batch, 2 * size_factor)

            # type embeddings + convolution + pooling
            type_embeddings = self.type_embedding(y)
            type_embeddings = type_embeddings.transpose(1, 2)  # Conv1d expects (batch, channels, length)
            conv1D = self.conv1D(type_embeddings)
            types_pooled = self.pooling(conv1D)
            types_pooled = types_pooled.squeeze(-1)  # (batch, size_factor * 2)

            global_repr = concat.add(types_pooled)
            # no softmax here: nn.CrossEntropyLoss applies log-softmax itself
            final = self.predict(global_repr)

            return final
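
To rule out wiring mistakes, here is a minimal smoke test of the module (a sketch; all dimensions are made up for illustration):

    import numpy as np

    emb = np.random.randn(1000, 300).astype('float32')  # hypothetical embedding matrix
    net = NeuralNet(word_input_dim=1000, word_output_dim=300, word_embedding_matrix=emb,
                    num_types=7, size_factor=64, num_classes=5)

    x = torch.randint(1, 1000, (4, 50))    # batch of 4 sequences, maxlen 50
    y = torch.randint(0, 7, (4, 2))        # two type ids per sample
    lens = torch.tensor([50, 40, 30, 20])  # true sequence lengths (CPU tensor)
    out = net(x, None, y, lens)            # x_rev is unused, so None is fine
    print(out.shape)                       # torch.Size([4, 5])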

and the training loop:

    ### initialize model
    net = NetBen.NeuralNet(word_input_dim, word_output_dim, word_embedding_matrix, num_types, size_factor, num_classes)
    net.cuda()

    learning_rate = 0.001
    criterion = nn.CrossEntropyLoss(weight=class_weights, reduction='sum')
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

    net.apply(initialize_weights)

    net.train()

    counter = 0
    for e in range(epochs):

        # batch loop
        for inputs, rev_inputs, type_inputs, labels, lens in train_loader:
            counter += 1

            # lens stays on the CPU: recent PyTorch versions require the lengths
            # passed to pack_padded_sequence to be a CPU tensor
            inputs, rev_inputs, type_inputs, labels = inputs.cuda(), rev_inputs.cuda(), type_inputs.cuda(), labels.cuda()

            # zero accumulated gradients
            net.zero_grad()

            # get the output from the model
            output = net(inputs, rev_inputs, type_inputs, lens)

            # calculate the loss and perform backprop
            loss = criterion(output, labels)
            loss.backward()

            nn.utils.clip_grad_norm_(net.parameters(), clip)
            optimizer.step()
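
For comparison: Keras's categorical_crossentropy averages the (class-weighted) per-sample losses over the batch, whereas `reduction='sum'` does not normalize at all, and `reduction='mean'` with a `weight` argument divides by the summed weights of the batch rather than the batch size. So neither PyTorch reduction matches Keras exactly, but the mean is closer:

    # averages over the batch, normalized by the summed class weights
    criterion = nn.CrossEntropyLoss(weight=class_weights, reduction='mean')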

Initialization:

    def initialize_weights(model):
        if isinstance(model, nn.Linear):
            nn.init.xavier_uniform_(model.weight)
            nn.init.zeros_(model.bias)
        elif isinstance(model, (nn.LSTM, nn.RNN, nn.GRU)):
            # iterate over all parameters so that the reverse-direction weights of the
            # bidirectional LSTM (weight_ih_l0_reverse, ...) are initialized as well
            for name, param in model.named_parameters():
                if 'weight' in name:
                    nn.init.xavier_uniform_(param)
                elif 'bias' in name:
                    nn.init.zeros_(param)

I have already initialized the PyTorch tensors in the same way Keras does (see the sketch below for what I am matching against).
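
For reference, a sketch of what matching the Keras LSTM defaults could look like (Keras uses glorot_uniform for the input kernel, orthogonal for the recurrent kernel, and zero biases with unit_forget_bias=True; PyTorch orders the gates i, f, g, o):

    def keras_like_lstm_init(lstm):
        for name, param in lstm.named_parameters():
            if 'weight_ih' in name:
                nn.init.xavier_uniform_(param)  # = glorot_uniform
            elif 'weight_hh' in name:
                nn.init.orthogonal_(param)      # Keras recurrent_initializer
            elif 'bias' in name:
                nn.init.zeros_(param)
                if 'bias_ih' in name:
                    # forget-gate chunk set to 1, like Keras's unit_forget_bias=True
                    h = lstm.hidden_size
                    param.data[h:2 * h].fill_(1.0)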
I know that Keras uses different activations for the LSTM (older Keras versions default to hard_sigmoid for the recurrent gates) and found a PyTorch reimplementation at https://github.com/huggingface/torchMoji/blob/master/torchmoji/lstm.py (unfortunately without GPU support).
But even when I used the huggingface implementation, I got very similar results and could not reach the Keras performance.
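
The activation in question is Keras's hard_sigmoid, which is what the torchMoji code reimplements; its PyTorch equivalent is a one-liner (a sketch, matching the Keras definition):

    def hard_sigmoid(x):
        # piecewise-linear approximation of the sigmoid used by (older) Keras LSTMs
        return torch.clamp(0.2 * x + 0.5, min=0.0, max=1.0)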

I'm looking forward to your input and any helpful tips.

I have no answer, but I can confirm this issue is not an isolated one. Another Embedding + Bidirectional LSTM model showed the same effect: https://www.kaggle.com/c/stanford-covid-vaccine/discussion/186485
I hope someone can figure out why that happens.