Hello,
First of all I appreciate that the two frameworks are different and cannot be expected to replicate results. I feel my issue here is that I am not correctly training the model or wiring up the nodes correctly, because according to me they use similar building blocks that are not extremely different and the Pytorch code should learn something regardless of weight initialization. I have tried a bunch of things and will try to elaborate on what I’ve tried below:
-
I find that the Pytorch model starts off with a similar loss and initial accuracy for both the train set and the validation sets, but whereas the Keras model begins to improve in validation and training accuracy after 25-30 epochs, the Pytorch model seems to not improve more than fractionally even for 100 epochs. The initial losses and accuracies give me some hope that the model definition is somewhat correct and maybe there is an issue with the training loop
-
I have manually computed the paddings as “same” padding is not available in Pytorch yet (fingers crossed for 1.9 ) and used integer labels instead of 1-Hot encoding as in the reference
-
The shapes seem to match Keras shapes at all layers (of course the filter dimension precedes the timesteps x features dimension where the convention differs)
-
The pytorch LSTM module seems to have less learnable parameters than the keras LSTM, (About 30k for T=50 timsteps) does this mean the keras LSTM is a larger abstraction of a basic LSTM with additional layers? I do believe the implementations of the LSTM formulae cannot differ by that much.
-
The losses change fractionally and the gradients do change but it appears that there may be a vanishing gradients problem somewhere as they are especially small for some of the conv layers
-
Do I need to detach certain parts of the LSTM if I am only using the last feature map in the output?
Trend on Reference (T=50)
Reference to Keras Model :-
# Params : T : Number of time steps
# NF : Number features
# number_of_lstm : LSTM Features
input_lmd = Input(shape=(T, NF, 1))
# build the convolutional block
conv_first1 = Conv2D(32, (1, 2), strides=(1, 2))(input_lmd)
conv_first1 = keras.layers.LeakyReLU(alpha=0.01)(conv_first1)
conv_first1 = Conv2D(32, (4, 1), padding='same')(conv_first1)
conv_first1 = keras.layers.LeakyReLU(alpha=0.01)(conv_first1)
conv_first1 = Conv2D(32, (4, 1), padding='same')(conv_first1)
conv_first1 = keras.layers.LeakyReLU(alpha=0.01)(conv_first1)
conv_first1 = Conv2D(32, (1, 2), strides=(1, 2))(conv_first1)
conv_first1 = keras.layers.LeakyReLU(alpha=0.01)(conv_first1)
conv_first1 = Conv2D(32, (4, 1), padding='same')(conv_first1)
conv_first1 = keras.layers.LeakyReLU(alpha=0.01)(conv_first1)
conv_first1 = Conv2D(32, (4, 1), padding='same')(conv_first1)
conv_first1 = keras.layers.LeakyReLU(alpha=0.01)(conv_first1)
conv_first1 = Conv2D(32, (1, 10))(conv_first1)
conv_first1 = keras.layers.LeakyReLU(alpha=0.01)(conv_first1)
conv_first1 = Conv2D(32, (4, 1), padding='same')(conv_first1)
conv_first1 = keras.layers.LeakyReLU(alpha=0.01)(conv_first1)
conv_first1 = Conv2D(32, (4, 1), padding='same')(conv_first1)
conv_first1 = keras.layers.LeakyReLU(alpha=0.01)(conv_first1)
# build the inception module
convsecond_1 = Conv2D(64, (1, 1), padding='same')(conv_first1)
convsecond_1 = keras.layers.LeakyReLU(alpha=0.01)(convsecond_1)
convsecond_1 = Conv2D(64, (3, 1), padding='same')(convsecond_1)
convsecond_1 = keras.layers.LeakyReLU(alpha=0.01)(convsecond_1)
convsecond_2 = Conv2D(64, (1, 1), padding='same')(conv_first1)
convsecond_2 = keras.layers.LeakyReLU(alpha=0.01)(convsecond_2)
convsecond_2 = Conv2D(64, (5, 1), padding='same')(convsecond_2)
convsecond_2 = keras.layers.LeakyReLU(alpha=0.01)(convsecond_2)
convsecond_3 = MaxPooling2D((3, 1), strides=(1, 1), padding='same')(conv_first1)
convsecond_3 = Conv2D(64, (1, 1), padding='same')(convsecond_3)
convsecond_3 = keras.layers.LeakyReLU(alpha=0.01)(convsecond_3)
convsecond_output = keras.layers.concatenate([convsecond_1, convsecond_2, convsecond_3], axis=3)
conv_reshape = Reshape((int(convsecond_output.shape[1]), int(convsecond_output.shape[3])))(convsecond_output)
# build the last LSTM layer
conv_lstm = LSTM(number_of_lstm)(conv_reshape)
# build the output layer
out = Dense(3, activation='softmax')(conv_lstm)
model = Model(inputs=input_lmd, outputs=out)
adam = keras.optimizers.Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1)
model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])
Pytorch Model
def conv_2d(input_filters, output_filters, kernel_size, padding=0, stride=1):
return nn.Sequential(
nn.Conv2d(input_filters, output_filters, kernel_size=kernel_size, padding=padding, stride=stride),
nn.LeakyReLU(inplace=True),
)
def init_weights(m):
if type(m) in [nn.Linear, nn.Conv2d]:
torch.nn.init.xavier_uniform_(m.weight)
if type(m) == nn.Linear:
m.bias.data.fill_(0)
class DeepLOB(nn.Module):
def __init__(self, T, NF, no_lstm, input_filters=1):
super().__init__()
# Initial Convolution Layers
self.conv_first1_1 = conv_2d(input_filters, 32, (1,2), stride=(1,2))
self.conv_first1_2 = conv_2d(32, 32, (4,1))
self.conv_first1_3 = conv_2d(32, 32, (4,1))
self.conv_first1_4 = conv_2d(32, 32, (1,2), stride=(1,2))
self.conv_first1_5 = conv_2d(32, 32, (4,1))
self.conv_first1_6 = conv_2d(32, 32, (4,1))
self.conv_first1_7 = conv_2d(32, 32, (1,10))
self.conv_first1_8 = conv_2d(32, 32, (4,1))
self.conv_first1_9 = conv_2d(32, 32, (4,1))
# "Inception Module" as implemented in reference
self.incept1 = conv_2d(32, 64, (1,1))
self.incept2 = conv_2d(64, 64, (3,1), padding=(1, 0))
self.incept3 = conv_2d(32, 64, (1,1))
self.incept4 = conv_2d(64, 64, (5,1), padding=(2, 0))
self.incept5 = nn.MaxPool2d((3,1), stride=(1,1), padding=(1,0))
self.incept6 = conv_2d(32, 64, (1,1))
# # build the last LSTM layer
self.conv_lstm = nn.LSTM(T, no_lstm, batch_first=True)
self.fc = nn.Linear(no_lstm, 3)
def forward(self, x):
out = self.conv_first1_1(x)
out = self.conv_first1_2(F.pad(out, (0, 0, 1, 2)))
out = self.conv_first1_3(F.pad(out, (0, 0, 1, 2)))
out = self.conv_first1_4(out)
out = self.conv_first1_5(F.pad(out, (0, 0, 1, 2)))
out = self.conv_first1_6(F.pad(out, (0, 0, 1, 2)))
out = self.conv_first1_7(out)
out = self.conv_first1_8(F.pad(out, (0, 0, 1, 2)))
out = self.conv_first1_9(F.pad(out, (0, 0, 1, 2)))
incept1 = self.incept1(out)
incept1 = self.incept2(incept1)
incept2 = self.incept3(out)
incept2 = self.incept4(incept2)
incept3 = self.incept5(out)
incept3 = self.incept6(incept3)
cat_layer = torch.cat([incept1, incept2, incept3], axis=1)
reshape = cat_layer.view(cat_layer.shape[0:3])
lstm_out, __ = self.conv_lstm(reshape)
lstm_out = lstm_out[:, -1, :]
return self.fc(lstm_out)
Training Keras
#Leaving out the data loading specifics
testX_CNN (Shape of [N x T x NF x 1 ])
trainY_CNN [Same as above except in batch dimesion]
trainY_CNN = np_utils.to_categorical(trainY_CNN, 3)
testY_CNN = np_utils.to_categorical(testY_CNN, 3)
model.fit(trainX_CNN, trainY_CNN, epochs=200, batch_size=64, verbose=2, validation_data=(testX_CNN, testY_CNN))
Training Pytorch
# Reshape as Pytorch convolution layers expect filter dimension first
# Avoid 1 hot as Pytorch CrossEntropy loss works with integer labels
# Create TensorDataset and create a loader from it
trainX_CNN = torch.Tensor(trainX_CNN).reshape([-1, trainX_CNN.shape[3], trainX_CNN.shape[1], trainX_CNN.shape[2]]) # transform to torch tensor
trainY_CNN = torch.Tensor(trainY_CNN).long()
testX_CNN = torch.Tensor(testX_CNN).reshape([-1, testX_CNN.shape[3], testX_CNN.shape[1], testX_CNN.shape[2]]) # transform to torch tensor
testY_CNN = torch.Tensor(testY_CNN).long()
dataset = TensorDataset(trainX_CNN, trainY_CNN)
dataset_val = TensorDataset(testX_CNN, testY_CNN)
dataloader = DataLoader(dataset, batch_size=64, num_workers=2, shuffle=True)
validation_loader = DataLoader(dataset_val, batch_size=64, num_workers=2)
model = DeepLOB(T, NF, no_of_lstm)
optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1)
criterion = nn.CrossEntropyLoss()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.apply(init_weights)
epochs = 200
training_losses = []
validation_losses = []
training_accuracies = []
validation_accuracies = []
for e in range(epochs):
print(f'Epoch #{e+1}')
running_loss = 0
running_val_loss = 0
running_train_accuracy = 0
running_val_accuracy = 0
total_train = 0
total_validation = 0
correct_train = 0
correct_validation = 0
grads[e] = {}
for batch_idx, (data, label) in enumerate(dataloader):
with torch.set_grad_enabled(True):
model.train()
optimizer.zero_grad()
data, label = data.to(device), label.to(device)
logits = model(data)
loss = criterion(logits, label)
loss.backward()
optimizer.step()
running_loss += loss.item()
_, predictions = torch.max(logits, 1)
total_train += label.size(0)
correct_train += torch.sum(predictions == label.data)
else:
with torch.no_grad():
model.eval()
for batch_idx, (data, label) in enumerate(validation_loader):
data, label = data.to(device), label.to(device)
logits = model.forward(data)
val_loss = criterion(logits, label)
running_val_loss += val_loss.item()
_, predictions = torch.max(logits, 1)
total_validation += label.size(0)
correct_validation += torch.sum(predictions == label.data)
Really appreciate your time if you went through the entire post