LSTM model performance depends on batch size even in eval() mode

I have a model that uses an LSTM and a fully connected layer:

Model(
  (lstm): LSTM(3, 32, num_layers=3, batch_first=True, dropout=0.7)
  (dense): Linear(in_features=32, out_features=2, bias=True)
)

I am training and testing my model using MPS on an Apple M2 chip.

For the loss function I am using cross-entropy loss with computed weights for each class, and for optimisation I am using AdamW. The F-score is measured with classification_report from the sklearn library.
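For reference, a minimal sketch of this setup; the class weight values and the learning rate are made up for illustration, and model is assumed to be the model shown above:

    import torch
    from torch import nn

    device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

    # Hypothetical class weights, e.g. inverse class frequencies (values are illustrative only)
    class_weights = torch.tensor([0.3, 0.7], device=device)

    loss_function = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)  # lr value is an assumption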

The problem is that when the batch size of both the train and test datasets is 64, model performance grows as expected. But when the batch size of the test dataset is 256, performance drops massively and does not grow any further.

On the plot below you can see the performance for both batch sizes. The pink graph represents a test-dataset batch size of 64, the blue one a batch size of 256.

The train F-score is computed on the train dataset with a batch size of 64. It shows the same issue, with performance depending on the batch size.

I also always set my model to eval() or train() mode accordingly:


for epoch in range(epochs):

    model.train()  # enable dropout for training
    print(f'Epoch {epoch}')
    __train_loop(model, train_dataloader, loss_function, optimizer, scheduler, verbose, device=device)

    model.eval()  # disable dropout for evaluation
    train_accuracy, train_f_score = test_model(model, train_dataloader, device=device)

    print(f'Train accuracy: {train_accuracy}')
    print(f'Train F-Score: {train_f_score}')

    accuracy, f_score = test_model(model, test_dataloader, device=device)

The training loop:

for batch_id, (X, y) in enumerate(train_dataloader):

    X, y = X.to(device), y.to(device)

    optimizer.zero_grad()

    y_pred = model(X)
    loss = loss_function(y_pred, y)

    loss.backward()
    optimizer.step()

Testing the model:

import numpy as np
from sklearn.metrics import classification_report

y_pred_all = []
y_all = []

with torch.no_grad():

    for batch_id, (X, y) in enumerate(test_dataloader):
        X, y = X.to(device), y.to(device)

        y_pred = model(X)
        y_pred = torch.argmax(y_pred, dim=1)  # predicted class per sample

        y_pred_all.append(y_pred.cpu().numpy())
        y_all.append(y.cpu().numpy())
        bar.update(batch_id)  # bar is a progress bar defined elsewhere

y_pred_all = np.hstack(y_pred_all).flatten()
y_all = np.hstack(y_all).flatten()

cr = classification_report(y_all, y_pred_all, output_dict=True)
f_score = cr['macro avg']['f1-score']
accuracy = cr['accuracy']

I expect the performance not to change drastically when changing the batch size of the dataset. The influence of the batch size on the model's output is odd here, because I am not using batch norm. It seems to me that, for some reason, the LSTM's output depends on the batch size.
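One way to check this directly is to feed the same data through the model in eval() mode with two different batch sizes and compare the outputs element-wise; a minimal sketch, assuming the model above and a made-up sequence length of 50:

    import torch

    model.eval()
    X = torch.randn(256, 50, 3, device=device)  # (batch, seq_len, input_size); seq_len is assumed

    with torch.no_grad():
        out_full = model(X)                                             # one batch of 256
        out_split = torch.cat([model(chunk) for chunk in X.split(64)])  # four batches of 64

    # If the forward pass is truly batch-size independent, these should agree
    # up to floating-point tolerance
    print(torch.allclose(out_full, out_split, atol=1e-5))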

There are several blog posts that explain the impact of batch size on model training. They suggest that a smaller batch size gives better performance. This also seems to be verified in your case.
https://medium.com/analytics-vidhya/when-and-why-are-batches-used-in-machine-learning-acda4eb00763

https://medium.com/geekculture/how-does-batch-size-impact-your-model-learning-2dd34d9fb1fa

Thanks for the answer, but I am already training my model with a small batch size. In my case, batch_size influences the testing performance of the model: the model is trained with a batch size of 64, while tested with batch sizes of 64 and 256.

I am wondering if you can check the batch_first parameter of the model. It may affect the output format when calling the model for prediction. Maybe under the batch_first=True configuration, calling the model with a different batch_size does not do what you expect.

I checked the model's behaviour with batch_first=False; it is the same as with batch_first=True.

Can you post the complete code of your model class, incl. the forward() method?

Hello,

here is the forward() code, thanks for the answer

    def forward(self, X: torch.Tensor):

        self.lstm.flatten_parameters()

        output, (hidden_state, cell_state) = self.lstm(X)
        # hidden_state has shape (num_layers, batch, hidden_size);
        # hidden_state[-1] is the final hidden state of the last layer
        return self.dense(hidden_state[-1])
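For context, with a unidirectional LSTM and batch_first=True, hidden_state[-1] equals the last time step of output; a small standalone sketch with made-up sizes:

    import torch
    from torch import nn

    lstm = nn.LSTM(3, 32, num_layers=3, batch_first=True)
    X = torch.randn(8, 20, 3)  # (batch, seq_len, input_size)

    output, (hidden_state, cell_state) = lstm(X)
    # output: (batch, seq_len, hidden_size); hidden_state: (num_layers, batch, hidden_size)
    print(torch.allclose(hidden_state[-1], output[:, -1, :]))  # True for a unidirectional LSTM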

A bit off-topic, but is

self.lstm.flatten_parameters()

really needed? And if so, shouldn't it suffice to have it once at the end of __init__()?
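A sketch of the variant this suggests; note that flatten_parameters() only has an effect with cuDNN on CUDA and is effectively a no-op elsewhere, so this is illustrative only:

    import torch
    from torch import nn

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.lstm = nn.LSTM(3, 32, num_layers=3, batch_first=True, dropout=0.7)
            self.dense = nn.Linear(32, 2)
            self.lstm.flatten_parameters()  # compact the weights once, instead of on every forward()

        def forward(self, X: torch.Tensor):
            output, (hidden_state, cell_state) = self.lstm(X)
            return self.dense(hidden_state[-1])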

You mentioned that you tried with batch_first=True and batch_first=False. Did you adjust the shape of X accordingly? What's the shape of X anyway? Do you perform any reshape() and/or view() operation on X?
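To make the shape question concrete: with batch_first=True the LSTM expects input of shape (batch, seq_len, features), while with batch_first=False it expects (seq_len, batch, features); a small sketch with made-up sizes:

    import torch
    from torch import nn

    X = torch.randn(64, 20, 3)  # (batch=64, seq_len=20, input_size=3)

    lstm_bf = nn.LSTM(3, 32, num_layers=3, batch_first=True)
    out_bf, _ = lstm_bf(X)              # input: (batch, seq_len, features)

    lstm = nn.LSTM(3, 32, num_layers=3)
    out, _ = lstm(X.transpose(0, 1))    # input: (seq_len, batch, features)

    print(out_bf.shape)  # torch.Size([64, 20, 32])
    print(out.shape)     # torch.Size([20, 64, 32])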