I have a model developed in Keras that I want to port to PyTorch. The model is as follows:
import keras
from keras.optimizers import SGD

s = SGD(lr=learning['rate'], decay=0, momentum=0.5, nesterov=True)
m = keras.models.Sequential([
    keras.layers.LSTM(256, input_shape=(70, 256), activation='tanh',
                      return_sequences=True),
    keras.layers.LSTM(64, activation='tanh', return_sequences=True),
    keras.layers.LSTM(16, activation='tanh'),
    keras.layers.Dense(8, activation='softmax')
])
m.compile(loss='binary_crossentropy', optimizer=s, metrics=['accuracy'])
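The labels are one-hot encoded before training, since binary_crossentropy compares them elementwise against the 8-unit softmax output. The fit call looks roughly like this (the epoch count and batch size are placeholders, not my exact settings):
from keras.utils import to_categorical

y_train_oh = to_categorical(y_train, num_classes=8)   # one-hot targets for binary_crossentropy
y_val_oh = to_categorical(y_val, num_classes=8)
m.fit(X_train, y_train_oh, validation_data=(X_val, y_val_oh),
      epochs=20, batch_size=32, verbose=2)            # placeholder epochs/batch_size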
It is a simple 3-layer LSTM with an output layer for 8 classes. This is not a multi-label classification problem; each sample belongs to exactly one of the 8 classes. I use a binary crossentropy loss function paired with an SGD optimizer. I have tried to reproduce this model in PyTorch like so:
import torch.nn as nn
import torch.nn.functional as F

class LSTMModel(nn.Module):
    def __init__(self):
        super(LSTMModel, self).__init__()
        # nn.LSTM takes input_size (the per-timestep feature count, 256),
        # not the full (timesteps, features) shape as in Keras
        self.lstm_1 = nn.LSTM(256, 256, batch_first=True)
        self.lstm_2 = nn.LSTM(256, 64, batch_first=True)
        self.lstm_3 = nn.LSTM(64, 16, batch_first=True)
        self.output = nn.Linear(16, 8)

    def forward(self, x):
        x = self.lstm_1(x)[0].tanh()
        x = self.lstm_2(x)[0].tanh()
        x = self.lstm_3(x)[0].tanh()[:, -1, :]  # keep only the last timestep
        return F.softmax(self.output(x), dim=1)
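As a sanity check on shapes, the forward pass produces what I expect (the batch size of 4 and the random input are placeholder values):
import torch

m = LSTMModel()
x = torch.randn(4, 70, 256)   # (batch, timesteps, features), matching Keras input_shape=(70, 256)
out = m(x)
print(out.shape)              # torch.Size([4, 8])
print(out.sum(dim=1))         # each row sums to ~1 because of the softmax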
I still want to use the same optimizer and loss function:
from torch.nn import BCELoss
from torch.optim import SGD

m = LSTMModel()
s = SGD(m.parameters(), lr=learning['rate'], weight_decay=0, momentum=0.5, nesterov=True)
loss_fx = BCELoss()
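For completeness, a single step of my training loop looks roughly like this (the train_step name is just for illustration; the one-hot conversion is there because BCELoss expects float targets with the same shape as the model output):
import torch.nn.functional as F

def train_step(x, y):
    # x: (batch, 70, 256) float inputs; y: (batch,) integer labels in [0, 8)
    s.zero_grad()
    preds = m(x)                                   # (batch, 8) softmax probabilities
    targets = F.one_hot(y, num_classes=8).float()  # same shape as preds, as BCELoss requires
    loss = loss_fx(preds, targets)
    loss.backward()
    s.step()
    return loss.item()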
The PyTorch model performs significantly worse than the Keras implementation: its training accuracy plateaus around 0.73, while Keras reaches about 0.95. Per-epoch logs from both runs:
PyTorch:
- 1565s - loss: 0.1637 - acc: 0.7035 - val_loss: 0.1451 - val_acc: 0.7441
- 1672s - loss: 0.1472 - acc: 0.7288 - val_loss: 0.1437 - val_acc: 0.7430
- 1851s - loss: 0.1467 - acc: 0.7288 - val_loss: 0.1432 - val_acc: 0.7430
- 1612s - loss: 0.1462 - acc: 0.7288 - val_loss: 0.1422 - val_acc: 0.7430
- 1650s - loss: 0.1409 - acc: 0.7288 - val_loss: 0.1326 - val_acc: 0.7430
- 1460s - loss: 0.1307 - acc: 0.7288 - val_loss: 0.1313 - val_acc: 0.7430
- 1455s - loss: 0.1286 - acc: 0.7288 - val_loss: 0.1301 - val_acc: 0.7430
- 1458s - loss: 0.1280 - acc: 0.7288 - val_loss: 0.1291 - val_acc: 0.7430
- 1458s - loss: 0.1278 - acc: 0.7288 - val_loss: 0.1289 - val_acc: 0.7430
- 1452s - loss: 0.1279 - acc: 0.7288 - val_loss: 0.1287 - val_acc: 0.7430
- 1439s - loss: 0.1279 - acc: 0.7288 - val_loss: 0.1286 - val_acc: 0.7430
- 1473s - loss: 0.1280 - acc: 0.7288 - val_loss: 0.1286 - val_acc: 0.7430
- 1601s - loss: 0.1280 - acc: 0.7288 - val_loss: 0.1285 - val_acc: 0.7430
- 1442s - loss: 0.1280 - acc: 0.7288 - val_loss: 0.1285 - val_acc: 0.7430
- 1487s - loss: 0.1280 - acc: 0.7288 - val_loss: 0.1285 - val_acc: 0.7430
- 1444s - loss: 0.1280 - acc: 0.7288 - val_loss: 0.1286 - val_acc: 0.7430
- 1455s - loss: 0.1280 - acc: 0.7288 - val_loss: 0.1286 - val_acc: 0.7430
- 1436s - loss: 0.1280 - acc: 0.7288 - val_loss: 0.1286 - val_acc: 0.7430
- 1448s - loss: 0.1280 - acc: 0.7288 - val_loss: 0.1286 - val_acc: 0.7430
- 1441s - loss: 0.1280 - acc: 0.7288 - val_loss: 0.1286 - val_acc: 0.7430
Keras:
- 1140s - loss: 0.1402 - acc: 0.9363 - val_loss: 0.1382 - val_acc: 0.9407
- 1184s - loss: 0.1185 - acc: 0.9453 - val_loss: 0.1234 - val_acc: 0.9409
- 1121s - loss: 0.1114 - acc: 0.9493 - val_loss: 0.1312 - val_acc: 0.9341
- 1109s - loss: 0.1055 - acc: 0.9533 - val_loss: 0.1138 - val_acc: 0.9475
- 1110s - loss: 0.1032 - acc: 0.9547 - val_loss: 0.1158 - val_acc: 0.9480
- 1104s - loss: 0.1029 - acc: 0.9549 - val_loss: 0.1134 - val_acc: 0.9485
- 1120s - loss: 0.1030 - acc: 0.9549 - val_loss: 0.1098 - val_acc: 0.9497
- 1134s - loss: 0.1032 - acc: 0.9548 - val_loss: 0.1077 - val_acc: 0.9509
- 1173s - loss: 0.1033 - acc: 0.9549 - val_loss: 0.1067 - val_acc: 0.9515
- 1124s - loss: 0.1033 - acc: 0.9549 - val_loss: 0.1062 - val_acc: 0.9518
- 1125s - loss: 0.1033 - acc: 0.9549 - val_loss: 0.1061 - val_acc: 0.9519
- 1128s - loss: 0.1033 - acc: 0.9550 - val_loss: 0.1061 - val_acc: 0.9519
- 1112s - loss: 0.1033 - acc: 0.9550 - val_loss: 0.1061 - val_acc: 0.9519
- 1134s - loss: 0.1033 - acc: 0.9550 - val_loss: 0.1061 - val_acc: 0.9519
- 1179s - loss: 0.1032 - acc: 0.9550 - val_loss: 0.1061 - val_acc: 0.9519
- 1144s - loss: 0.1032 - acc: 0.9550 - val_loss: 0.1061 - val_acc: 0.9519
- 1130s - loss: 0.1032 - acc: 0.9550 - val_loss: 0.1061 - val_acc: 0.9519
- 1183s - loss: 0.1032 - acc: 0.9550 - val_loss: 0.1061 - val_acc: 0.9519
- 1121s - loss: 0.1032 - acc: 0.9550 - val_loss: 0.1061 - val_acc: 0.9519
- 1106s - loss: 0.1032 - acc: 0.9550 - val_loss: 0.1061 - val_acc: 0.9519
- 1109s - loss: 0.1032 - acc: 0.9550 - val_loss: 0.1061 - val_acc: 0.9519
A bit of discrepancy between the two libraries is to be expected, but this difference looks far larger than that: the PyTorch run gets stuck at acc 0.7288 from the second epoch onward, as if the model stops learning. Is something wrong with my PyTorch implementation?
Thanks.