PyTorch model ported from Keras not learning

I am porting a net from Keras to PyTorch, but the PyTorch version doesn't seem to learn anything during training. The net learns sub-word embeddings for a phrase and classifies the phrase into one of three classes.

The architecture of the net (with output shapes) is as follows:

  1. Embedding Layer (batch x 200 x 128)
  2. Convolution Layer (batch x 198 x 128)
  3. Max Pooling Layer (batch x 66 x 128)
  4. LSTM Layer (batch x 66 x 128)
  5. LSTM Layer (batch x 128)
  6. Dense Layer (batch x 3)

I have tried different learning rates and switching the loss from cross entropy to softmax + NLL, but the PyTorch model still doesn't seem to learn.

The input X is an array of shape (num_samples, 200) loaded from a file, with each value an integer in the range [0, 26]; y holds the label in {0, 1, 2} for each sample.
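For completeness, dummy inputs with the same shapes can be generated like this (illustrative only; my real X and y are loaded from a file):

import numpy as np

# Illustrative stand-in for the real data (shapes and value ranges as described above)
num_samples = 1000
X = np.random.randint(0, 27, size=(num_samples, 200))   # sub-word indices in [0, 26]
y = np.random.randint(0, 3, size=(num_samples,))        # class labels in {0, 1, 2}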

Here is the Keras code:

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Activation, Embedding, LSTM, GRU, Convolution1D, MaxPooling1D
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split

# One-hot encode y for categorical_crossentropy
y_cat = to_categorical(y, 3)
# Train & validation data splitting
X_train, X_valid, y_train, y_valid = train_test_split(X, y_cat, test_size=0.2, random_state=42)

embedLayer = Embedding(input_dim=27, output_dim=128, input_length=200)
convLayer = Convolution1D(filters=128, kernel_size=3, activation='relu')
poolLayer = MaxPooling1D(pool_size=3)
lstmLayer1 = LSTM(units=128, dropout=0.2, recurrent_dropout=0.2, return_sequences=True)
lstmLayer2 = LSTM(units=128, dropout=0.2, recurrent_dropout=0.2, return_sequences=False)
denseLayer = Dense(units=3)

model = Sequential()
model.add(embedLayer)
model.add(convLayer)
model.add(poolLayer)
model.add(lstmLayer1)
model.add(lstmLayer2)
model.add(denseLayer)
model.add(Activation('softmax'))

model.compile(loss='categorical_crossentropy',
              optimizer='adamax',
              metrics=['accuracy'])

# Training data
model.fit(x=X_train, y=y_train, 
          batch_size=128, 
          epochs=50,
          validation_data=(X_valid, y_valid))

Here is the PyTorch code:

import torch
import torch.nn as nn
import torch.nn.functional as F

class RNNModel(nn.Module):
    
    def __init__(self):
        
        super(RNNModel, self).__init__()
            
        # Layers
        self.embeddingLayer = nn.Embedding(num_embeddings=27, embedding_dim=128)
        self.convLayer = nn.Conv1d(in_channels=128, out_channels=128, 
                                   kernel_size=3)
        self.lstmLayer1 = nn.LSTM(input_size=128, hidden_size=128)
        self.lstmLayer2 = nn.LSTM(input_size=128, hidden_size=128)
        self.denseLayer = nn.Linear(in_features=128, out_features=3)
        
    def forward(self, x):
        
        # Forming embeddings
        x = self.embeddingLayer(x)
        
        # Convolution and pooling
        x = x.view(-1, 128, 200)
        x = F.relu(self.convLayer(x))
        x = F.max_pool1d(x, kernel_size=3)
        
        # LSTM layers
        x = x.view(x.shape[2], x.shape[0], x.shape[1])
        x, _ = self.lstmLayer1(x)
        x = F.dropout(x, p=0.2)
        _, (x, _) = self.lstmLayer2(x)
        x = F.dropout(x, p=0.2)
        
        # Dense layer
        x = x.view(x.shape[1], x.shape[2])
        x = self.denseLayer(x)
        
        return x
        
device = torch.device('cuda')

model = RNNModel()
model = model.to(device)

lossFunc = nn.CrossEntropyLoss()
optimizer = torch.optim.Adamax(model.parameters())

# trainGen and devGen are generators for training and dev set respectively

for i in range(50):
    for X_batch, y_batch in trainGen:
        X_batch = X_batch.to(device)
        y_batch = y_batch.to(device)
        pred_batch = model(X_batch.long())
        loss = lossFunc(pred_batch, y_batch.long())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print("\nITERATION ", i+1, "\n Train Loss = ", loss.item())
    
    with torch.no_grad():
        for X_dev, y_dev in devGen:
            X_dev = X_dev.to(device)
            y_dev = y_dev.to(device)
            pred_batch_dev = model(X_dev.long())
            loss = lossFunc(pred_batch_dev, y_dev.long())
            print(" Dev Loss = ", loss.item())

The Keras model does learn something, since it brings the accuracy up from around 47% to 63%. The PyTorch model, however, just keeps wavering around its initial loss. What is the problem here?

Could you explain what dropout=0.2 and recurrent_dropout=0.2 do in your Keras model?
It seems that in PyTorch you are just applying an F.dropout to the output of each LSTM layer.

Also, since you are using the functional dropout API, you should pass the training=self.training argument to disable it during validation. :wink:
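For example, a minimal sketch of what I mean:

# Assumption: x is an intermediate activation inside forward()
x = F.dropout(x, p=0.2, training=self.training)  # active in train() mode, disabled in eval() mode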

From the Keras docs, dropout is applied to the linear transformation of the inputs, and recurrent_dropout to the linear transformation of the recurrent state.
I'm guessing the Keras dropout masks the inputs, so it should simply be implemented by adding an F.dropout before each LSTM in PyTorch (I've mistakenly applied it to the outputs, I'll change that right away).
What I can't figure out is how to implement the recurrent_dropout part in PyTorch, which masks the recurrent connections. Is it possible that the absence of this particular dropout is the only reason the PyTorch model is not learning while the Keras one is?

Oh yes, my bad. I’ll change it right away.

Thanks a lot!

So I read up a little and it seems that recurrent dropout is not implemented in PyTorch's nn.LSTM, and there is an open GitHub issue for it. I guess the only solution is to write the LSTM cell loop manually then (a rough sketch of what I mean is below).
However, shouldn't the absence of such a dropout lead to overfitting at most?
I still can't wrap my head around this being the sole reason why the model isn't learning at all (neither the training loss nor the dev loss decreases).
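This is the manual cell I had in mind, as a rough, untested sketch (note: Keras reportedly reuses one recurrent dropout mask for the whole sequence, while this resamples the mask every step, so it is only an approximation):

import torch
import torch.nn as nn
import torch.nn.functional as F

class RecurrentDropoutLSTM(nn.Module):
    """Unrolled LSTM that drops part of the hidden state between timesteps (sketch only)."""
    def __init__(self, input_size, hidden_size, recurrent_dropout=0.2):
        super().__init__()
        self.cell = nn.LSTMCell(input_size, hidden_size)
        self.hidden_size = hidden_size
        self.recurrent_dropout = recurrent_dropout

    def forward(self, x):
        # x: (seq_len, batch, input_size)
        seq_len, batch, _ = x.shape
        h = x.new_zeros(batch, self.hidden_size)
        c = x.new_zeros(batch, self.hidden_size)
        outputs = []
        for t in range(seq_len):
            h, c = self.cell(x[t], (h, c))
            # drop part of the recurrent state before it is fed into the next step
            h = F.dropout(h, p=self.recurrent_dropout, training=self.training)
            outputs.append(h)
        return torch.stack(outputs), (h, c)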

That might be the case, and I agree that something else might be creating the issue.

In the Keras implementation you are using return_sequences=False.
Based on the docs, it seems that this LSTM will only return the output of the last timestep.

In your PyTorch implementation you are reshaping the output of self.lstmLayer2 via:

x = x.view(x.shape[1], x.shape[2])

which is not using the last timestep, but the complete sequence.

That being said, I think the reshaping in general might be wrong.

The pooling layer returns an activation of shape [batch_size, channels, seq_len], while nn.LSTM expects an input of shape [seq_len, batch_size, nb_features] by default (i.e. with batch_first=False).
It seems you would like to permute the activation here:

x = x.view(x.shape[2], x.shape[0], x.shape[1])

which won’t work, since you are interleaving the signal.
Use x = x.permute(2, 0, 1) instead to permute the dimensions.

The same applies to the output of self.lstmLayer2.

Let me know if that helps.

But I'm taking the output of the second LSTM as _, (x, _) = self.lstmLayer2(x); according to the docs this returns output, (h_n, c_n), and h_n is the hidden state for the last timestep.
The x = x.view(x.shape[1], x.shape[2]) after the second LSTM was just to squeeze h_n from (1, batch, hidden) to (batch, hidden) before feeding it into the dense layer.
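A quick sanity check of what I mean (dummy shapes):

lstm = nn.LSTM(input_size=128, hidden_size=128)
x = torch.randn(66, 4, 128)                 # (seq_len, batch, features)
output, (h_n, c_n) = lstm(x)
print(output.shape)                         # torch.Size([66, 4, 128]) -> all timesteps
print(h_n.shape)                            # torch.Size([1, 4, 128])  -> last timestep only
print(torch.allclose(output[-1], h_n[0]))   # True for a single-layer, unidirectional LSTM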

Yes! I changed the view reshaping before convLayer and before lstmLayer1 to permute, and that got it working. Thanks a lot!
This works perfectly now:

    def forward(self, x):
        
        # Forming embeddings
        x = self.embeddingLayer(x)
        
        # Convolution and pooling
        x = x.permute(0, 2, 1)
        x = F.relu(self.convLayer(x))
        x = F.max_pool1d(x, kernel_size=pool_length)  # pool_length = 3, same as the Keras pool_size
        
        # LSTM layers
        x = x.permute(2, 0, 1)
        x = F.dropout(x, p=0.2, training=self.training)
        x, _ = self.lstmLayer1(x)
        x = F.dropout(x, p=0.2, training=self.training)
        _, (x, _) = self.lstmLayer2(x)
        
        # Dense layer
        x = x.view(x.shape[1], x.shape[2])
        x = self.denseLayer(x)
        
        return x
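One more thing I noticed while fixing this: since the dropout now depends on self.training, the training loop should also toggle the mode, otherwise dropout stays active during the dev evaluation. Roughly (loop bodies as in my first post):

for i in range(50):
    model.train()        # enable the F.dropout calls
    for X_batch, y_batch in trainGen:
        ...              # training step as before

    model.eval()         # disable dropout for evaluation
    with torch.no_grad():
        for X_dev, y_dev in devGen:
            ...          # dev loss as before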

But I can’t understand why view doesn’t work here. What do you mean by interleaving the signal?

Thanks for the correction regarding the last state / output, as I’ve clearly mixed up the return values. :slight_smile:

Have a look at this example:

x = torch.arange(2 * 4 * 5).view(2, 4, 5)
print(x)

# Now permute the dimensions
print(x.permute(0, 2, 1)) # rows and columns for each sample in dim0 are swapped

# Wrong permutation!
print(x.view(2, 5, 4)) # result is interleaved

Here you can see that I would like to swap the last two dimensions (basically transpose dim1 and dim2 for each sample in the batch dimension).
.permute does the job correctly, while view just reinterprets the same contiguous memory with a new shape and thus creates a wrong, interleaved result.
By "interleaved" I mean that the original rows are cut and the remaining elements spill over into the next row.
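Concretely, for the first sample of the arange example above, the three results are:

# x[0] - original rows of length 5
#  0  1  2  3  4
#  5  6  7  8  9
# 10 11 12 13 14
# 15 16 17 18 19

# x.permute(0, 2, 1)[0] - a proper transpose: each row is an original column
#  0  5 10 15
#  1  6 11 16
#  2  7 12 17
#  3  8 13 18
#  4  9 14 19

# x.view(2, 5, 4)[0] - the rows of length 5 are simply cut into rows of length 4
#  0  1  2  3
#  4  5  6  7
#  8  9 10 11
# 12 13 14 15
# 16 17 18 19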

Oh alright, got the point!
Thanks a lot for the help! :slight_smile: