Model doesn't learn much after many epochs [Porting a CNN-LSTM model from Keras to PyTorch]

Hello,

First of all, I appreciate that the two frameworks are different and cannot be expected to replicate results exactly. My suspicion is that I am either not training the model correctly or not wiring up the layers correctly, because both frameworks use very similar building blocks and the PyTorch code should learn something regardless of weight initialization. I have tried a number of things and will elaborate on them below:

  • I find that the PyTorch model starts off with a loss and accuracy similar to the Keras model's on both the training and validation sets. However, while the Keras model begins to improve on both sets after 25-30 epochs, the PyTorch model improves only fractionally even after 100 epochs. The matching initial losses and accuracies give me some hope that the model definition is roughly correct and that the issue lies in the training loop.

  • I have computed the paddings manually, as "same" padding is not available in PyTorch yet (fingers crossed for 1.9 :slight_smile: ), and I use integer labels instead of the one-hot encoding used in the reference.

  • The shapes seem to match the Keras shapes at all layers (with the channel/filter dimension preceding the timesteps x features dimensions, where the conventions differ).

  • The PyTorch LSTM module seems to have fewer learnable parameters than the Keras LSTM (about 30k for T=50 timesteps). Does this mean the Keras LSTM is a larger abstraction of a basic LSTM with additional layers? I don't believe the implementations of the LSTM equations can differ by that much.

  • The losses change only fractionally. The gradients do change, but they are especially small for some of the conv layers, so there may be a vanishing-gradient problem somewhere (a gradient-norm logging sketch follows this list).

  • Do I need to detach certain parts of the LSTM if I am only using the last feature map in the output?
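
For reference, this is roughly how I inspect the gradients: a minimal sketch that logs the gradient norm of every parameter right after loss.backward() in the training loop shown further down (the helper name log_grad_norms is just something I made up for illustration):

    # Sketch: collect the gradient norm of every parameter to spot layers
    # whose gradients are vanishingly small.
    def log_grad_norms(model):
        norms = {}
        for name, param in model.named_parameters():
            if param.grad is not None:
                norms[name] = param.grad.detach().norm().item()
        return norms

    # usage inside the training loop, right after loss.backward():
    # print(log_grad_norms(model))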

[Plot: trend on the reference Keras model (T=50)]

Reference Keras model:

    # Params :        T : Number of time steps
    #                 NF : Number features
    #                 number_of_lstm : LSTM Features
    input_lmd = Input(shape=(T, NF, 1))
    
    # build the convolutional block
    conv_first1 = Conv2D(32, (1, 2), strides=(1, 2))(input_lmd)
    conv_first1 = keras.layers.LeakyReLU(alpha=0.01)(conv_first1)
    conv_first1 = Conv2D(32, (4, 1), padding='same')(conv_first1)
    conv_first1 = keras.layers.LeakyReLU(alpha=0.01)(conv_first1)
    conv_first1 = Conv2D(32, (4, 1), padding='same')(conv_first1)
    conv_first1 = keras.layers.LeakyReLU(alpha=0.01)(conv_first1)

    conv_first1 = Conv2D(32, (1, 2), strides=(1, 2))(conv_first1)
    conv_first1 = keras.layers.LeakyReLU(alpha=0.01)(conv_first1)
    conv_first1 = Conv2D(32, (4, 1), padding='same')(conv_first1)
    conv_first1 = keras.layers.LeakyReLU(alpha=0.01)(conv_first1)
    conv_first1 = Conv2D(32, (4, 1), padding='same')(conv_first1)
    conv_first1 = keras.layers.LeakyReLU(alpha=0.01)(conv_first1)

    conv_first1 = Conv2D(32, (1, 10))(conv_first1)
    conv_first1 = keras.layers.LeakyReLU(alpha=0.01)(conv_first1)
    conv_first1 = Conv2D(32, (4, 1), padding='same')(conv_first1)
    conv_first1 = keras.layers.LeakyReLU(alpha=0.01)(conv_first1)
    conv_first1 = Conv2D(32, (4, 1), padding='same')(conv_first1)
    conv_first1 = keras.layers.LeakyReLU(alpha=0.01)(conv_first1)
    
    # build the inception module
    convsecond_1 = Conv2D(64, (1, 1), padding='same')(conv_first1)
    convsecond_1 = keras.layers.LeakyReLU(alpha=0.01)(convsecond_1)
    convsecond_1 = Conv2D(64, (3, 1), padding='same')(convsecond_1)
    convsecond_1 = keras.layers.LeakyReLU(alpha=0.01)(convsecond_1)

    convsecond_2 = Conv2D(64, (1, 1), padding='same')(conv_first1)
    convsecond_2 = keras.layers.LeakyReLU(alpha=0.01)(convsecond_2)
    convsecond_2 = Conv2D(64, (5, 1), padding='same')(convsecond_2)
    convsecond_2 = keras.layers.LeakyReLU(alpha=0.01)(convsecond_2)

    convsecond_3 = MaxPooling2D((3, 1), strides=(1, 1), padding='same')(conv_first1)
    convsecond_3 = Conv2D(64, (1, 1), padding='same')(convsecond_3)
    convsecond_3 = keras.layers.LeakyReLU(alpha=0.01)(convsecond_3)
    
    convsecond_output = keras.layers.concatenate([convsecond_1, convsecond_2, convsecond_3], axis=3)
    conv_reshape = Reshape((int(convsecond_output.shape[1]), int(convsecond_output.shape[3])))(convsecond_output)

    # build the last LSTM layer
    conv_lstm = LSTM(number_of_lstm)(conv_reshape)

    # build the output layer
    out = Dense(3, activation='softmax')(conv_lstm)
    model = Model(inputs=input_lmd, outputs=out)
    adam = keras.optimizers.Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1)
    model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])

PyTorch model

import torch
import torch.nn as nn
import torch.nn.functional as F

def conv_2d(input_filters, output_filters, kernel_size, padding=0, stride=1):
    return nn.Sequential(
        nn.Conv2d(input_filters, output_filters, kernel_size=kernel_size, padding=padding, stride=stride),
        nn.LeakyReLU(negative_slope=0.01, inplace=True),  # matches Keras LeakyReLU(alpha=0.01)
    )

def init_weights(m):
    if type(m) in [nn.Linear, nn.Conv2d]:
        torch.nn.init.xavier_uniform_(m.weight)
        if type(m) == nn.Linear:
            m.bias.data.fill_(0)

class DeepLOB(nn.Module):
    def __init__(self, T, NF, no_lstm, input_filters=1):
        super().__init__()

        # Initial Convolution Layers
        self.conv_first1_1 = conv_2d(input_filters, 32, (1,2), stride=(1,2))
        self.conv_first1_2 = conv_2d(32, 32, (4,1))
        self.conv_first1_3 = conv_2d(32, 32, (4,1))

        self.conv_first1_4 = conv_2d(32, 32, (1,2), stride=(1,2))
        self.conv_first1_5 = conv_2d(32, 32, (4,1))
        self.conv_first1_6 = conv_2d(32, 32, (4,1))

        self.conv_first1_7 = conv_2d(32, 32, (1,10))
        self.conv_first1_8 = conv_2d(32, 32, (4,1))
        self.conv_first1_9 = conv_2d(32, 32, (4,1))

        # "Inception Module" as implemented in reference
        self.incept1 = conv_2d(32, 64, (1,1))
        self.incept2 = conv_2d(64, 64, (3,1), padding=(1, 0))

        self.incept3 = conv_2d(32, 64, (1,1))
        self.incept4 = conv_2d(64, 64, (5,1), padding=(2, 0))

        self.incept5 = nn.MaxPool2d((3,1), stride=(1,1), padding=(1,0))
        self.incept6 = conv_2d(32, 64, (1,1))
        
        # build the last LSTM layer
        self.conv_lstm = nn.LSTM(T, no_lstm, batch_first=True)
        self.fc = nn.Linear(no_lstm, 3)
        
    def forward(self, x):
        out = self.conv_first1_1(x)
        # manual 'same' padding for the (4,1) kernels: 3 rows total, split 1 top / 2 bottom to match Keras
        out = self.conv_first1_2(F.pad(out, (0, 0, 1, 2)))
        out = self.conv_first1_3(F.pad(out, (0, 0, 1, 2)))
        out = self.conv_first1_4(out)
        out = self.conv_first1_5(F.pad(out, (0, 0, 1, 2)))
        out = self.conv_first1_6(F.pad(out, (0, 0, 1, 2)))
        out = self.conv_first1_7(out)
        out = self.conv_first1_8(F.pad(out, (0, 0, 1, 2)))
        out = self.conv_first1_9(F.pad(out, (0, 0, 1, 2)))

        incept1 = self.incept1(out)
        incept1 = self.incept2(incept1)
        incept2 = self.incept3(out)
        incept2 = self.incept4(incept2)

        incept3 = self.incept5(out)
        incept3 = self.incept6(incept3)

        cat_layer = torch.cat([incept1, incept2, incept3], axis=1)
        reshape = cat_layer.view(cat_layer.shape[0:3])
        lstm_out, __ = self.conv_lstm(reshape)
        lstm_out = lstm_out[:, -1, :]

        return self.fc(lstm_out)
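
To sanity-check the wiring, I push a dummy batch through the PyTorch model and compare shapes and parameter counts against Keras' model.summary(). This is just a rough sketch; the values T=50, NF=40 and no_lstm=64 are placeholders:

    # Sketch: verify the output shape and compare parameter counts with Keras.
    model = DeepLOB(T=50, NF=40, no_lstm=64)    # placeholder hyperparameters
    dummy = torch.randn(2, 1, 50, 40)           # [batch, channels, T, NF]
    print(model(dummy).shape)                   # expect torch.Size([2, 3])

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    lstm_params = sum(p.numel() for p in model.conv_lstm.parameters())
    print(total_params, lstm_params)            # compare against Keras model.summary()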

Training the Keras model

# Leaving out the data loading specifics:
# trainX_CNN and testX_CNN have shape [N x T x NF x 1] (they differ only in the batch dimension N)
trainY_CNN = np_utils.to_categorical(trainY_CNN, 3)
testY_CNN = np_utils.to_categorical(testY_CNN, 3)
model.fit(trainX_CNN, trainY_CNN, epochs=200, batch_size=64, verbose=2, validation_data=(testX_CNN, testY_CNN))

Training the PyTorch model


# Reshape, as PyTorch convolution layers expect the channel dimension first
# Skip one-hot encoding, as PyTorch's CrossEntropyLoss works with integer class labels
# Create TensorDatasets and wrap them in DataLoaders

import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

# [N, T, NF, 1] -> [N, 1, T, NF]; the reshape is equivalent to a permute here
# because the trailing dimension has size 1
trainX_CNN = torch.Tensor(trainX_CNN).reshape([-1, trainX_CNN.shape[3], trainX_CNN.shape[1], trainX_CNN.shape[2]])
trainY_CNN = torch.Tensor(trainY_CNN).long()
testX_CNN = torch.Tensor(testX_CNN).reshape([-1, testX_CNN.shape[3], testX_CNN.shape[1], testX_CNN.shape[2]])
testY_CNN = torch.Tensor(testY_CNN).long()

dataset = TensorDataset(trainX_CNN, trainY_CNN)
dataset_val = TensorDataset(testX_CNN, testY_CNN) 
 
dataloader = DataLoader(dataset, batch_size=64, num_workers=2, shuffle=True)
validation_loader = DataLoader(dataset_val, batch_size=64, num_workers=2)

model = DeepLOB(T, NF, no_of_lstm)
optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1)
criterion = nn.CrossEntropyLoss()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = model.to(device)
model.apply(init_weights)

epochs = 200
training_losses = []
validation_losses = []
training_accuracies = []
validation_accuracies = []

for e in range(epochs):
    print(f'Epoch #{e+1}')
    running_loss = 0
    running_val_loss = 0
    running_train_accuracy = 0
    running_val_accuracy = 0
    total_train = 0
    total_validation = 0
    correct_train = 0
    correct_validation = 0
    grads[e] = {}  # grads is a dict used for per-epoch gradient logging, defined outside this snippet

    for batch_idx, (data, label) in enumerate(dataloader):
        with torch.set_grad_enabled(True):
            model.train()
            
            optimizer.zero_grad()
            data, label = data.to(device), label.to(device)
            logits = model(data)
            loss = criterion(logits, label)
            loss.backward()
            
            optimizer.step()

        running_loss += loss.item()
        _, predictions = torch.max(logits, 1)
        
        total_train += label.size(0) 
        correct_train += torch.sum(predictions == label.data)

    else:
        # for/else: this block runs once per epoch, after all training batches
        with torch.no_grad():
            model.eval()
            for batch_idx, (data, label) in enumerate(validation_loader):
                data, label = data.to(device), label.to(device)
                logits = model(data)
                val_loss = criterion(logits, label)
                running_val_loss += val_loss.item()
                _, predictions = torch.max(logits, 1)
                total_validation += label.size(0)
                correct_validation += torch.sum(predictions == label.data)
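
    # (Sketch, omitted from the snippet above) aggregate the running counters
    # into per-epoch values; roughly how the metric lists defined earlier get filled.
    training_losses.append(running_loss / len(dataloader))
    validation_losses.append(running_val_loss / len(validation_loader))
    training_accuracies.append(correct_train.item() / total_train)
    validation_accuracies.append(correct_validation.item() / total_validation)
    print(f'train loss {training_losses[-1]:.4f} | val loss {validation_losses[-1]:.4f} | '
          f'train acc {training_accuracies[-1]:.4f} | val acc {validation_accuracies[-1]:.4f}')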

I really appreciate your time if you made it through the entire post :slight_smile:

That’s an interesting finding, and I would probably start by comparing this layer in isolation to narrow down where the difference is coming from.

If you are passing the states to the nn.LSTM in PyTorch, you could detach them to make sure Autograd won’t backpropagate through the previous time steps. Based on your code, you are only passing the inputs to this layer, so the states should be initialized to zeros internally.
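
To illustrate, a rough sketch of the two cases, using the names from your forward method; h and c would be states carried over from a previous call (which your code does not do):

    # Case 1: states are passed in explicitly -> detach them so Autograd
    # does not backpropagate into the graph that produced them.
    lstm_out, (h, c) = self.conv_lstm(reshape, (h.detach(), c.detach()))

    # Case 2 (your current code): only the input is passed, so h and c are
    # initialized to zeros internally and there is nothing to detach.
    lstm_out, _ = self.conv_lstm(reshape)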


Wow, thanks! I had made this observation before but didn't think to debug the layers in isolation. While experimenting with a Keras model and a PyTorch model containing only a single LSTM unit, I noticed that I had erroneously passed the number of timesteps as the input size of the torch LSTM, not realizing that it is meant to be the feature dimension.

I definitely know what to fix here: I think permuting the dimensions before the LSTM and initializing the LSTM with the correct input size should do the trick (rough sketch below). I will let you know if that works for me.
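
A sketch of the intended fix (not verified yet): build the LSTM with the feature count (3 * 64 = 192 from the inception branches) as input_size, and permute the concatenated feature map to [batch, time, features] before the LSTM:

    # In __init__: input_size is the feature dimension (3 * 64 = 192), not T
    self.conv_lstm = nn.LSTM(192, no_lstm, batch_first=True)

    # In forward: [N, 192, T', 1] -> [N, T', 192] so the batch_first LSTM
    # sees (batch, seq_len, features)
    reshape = cat_layer.squeeze(3).permute(0, 2, 1)
    lstm_out, _ = self.conv_lstm(reshape)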

Thanks for the insight!

Unfortunately that did not help. The number of learnable parameters roughly matches now, but the loss and accuracy still do not improve.


Hey Gaurav,

I’ve been struggling with the same issues as well. Have you figured out any other way to get the model to learn the way it does in the DeepLOB paper?

Not yet, unfortunately. I’m using the Keras code for my experiments at the moment and will revisit this when I have more bandwidth.