Inconsistent batched vs. unbatched LSTM outputs during inference

Hi

I’m trying to debug some inconsistencies I get between batched and unbatched inference with an LSTM.
In short, I get different outputs depending on whether I compute them in one batch or pass the elements one by one to the LSTM. The differences are fairly small, but my ultimate application runs several of these LSTMs in succession, so the errors accumulate.

Additionally, I’ve noticed that this error changes even with completely unrelated changes to the code! For example, if I add an extra model initialization, the batched-vs-unbatched error changes significantly.

I use a specific set of pretrained weights, which you can find here. You can reproduce the issue with the following code snippet:

import torch
import random
import numpy as np


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = torch.nn.LSTM(32,
                                  64,
                                  1,
                                  batch_first=True,
                                  bidirectional=True)
        # bidirectional LSTM with hidden size 64 -> 2 * 64 = 128 features per step
        self.linear = torch.nn.Linear(128, 32)
    
    def forward(self, x):
        lstm_out, _ = self.lstm(x)
        linear_out = self.linear(lstm_out)
        return linear_out

if __name__ == "__main__":
    torch.use_deterministic_algorithms(True)
    torch.manual_seed(0)
    random.seed(0)
    np.random.seed(0)

    # UNCOMMENT THIS, AND ERROR BECOMES 7.0512e-05
    # my_model = MyModel()

    # UNCOMMENT THIS AS WELL, AND ERROR BECOMES 0.0009
    # my_model = MyModel()

    my_model = MyModel()
    my_model.load_state_dict(torch.load('my_state_dict.pt'))
    my_model.eval()

    X = torch.randn(2, 288, 32)  # (batch, seq_len, input_size)
    with torch.no_grad():
        Y_batched = my_model(X)
        
        # collect per-element outputs; same shape as the batched result
        Y_individual = torch.zeros_like(Y_batched)
        for i in range(X.shape[0]):
            Y_individual[i:i+1] = my_model(X[i:i+1])

    # ERROR IS 0.0001
    print(torch.abs(Y_batched - Y_individual).max())

I know there is some tolerance with floating-point arithmetic and such, but I don’t see why processing an element together with other elements in a batch should change its output, and I certainly don’t understand why changing an unrelated line of code would change any behavior.
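
To be concrete about what I mean by floating-point tolerance, here is a tiny standalone example (not from my model, just an illustration): float32 addition is not associative, so a kernel that accumulates values in a different order for different batch sizes could produce slightly different numbers.

import torch

# float32 addition is not associative: a different summation order gives a different result
a = torch.tensor(1e8, dtype=torch.float32)
b = torch.tensor(-1e8, dtype=torch.float32)
c = torch.tensor(1.0, dtype=torch.float32)
print((a + b) + c)  # tensor(1.)
print(a + (b + c))  # tensor(0.), since -1e8 + 1 rounds back to -1e8 in float32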

Any help would be appreciated. Thanks!

Hello,

I am still looking for help on this issue.
Any ideas? Thanks.

After some debugging, I’ve concluded that the error is small enough to be expected floating-point noise, and the real problem was in another part of my code.
As for the inconsistent error when uncommenting the “unrelated” lines: I think those lines were advancing the global random state, so the tensor X initialized afterwards was different on every run, and the measured error changed with it.
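
Here is a minimal sketch of what I think was going on (separate from the snippet above, just an illustration): constructing a module draws from the global RNG to initialize its parameters, so a tensor sampled afterwards comes out different.

import torch

torch.manual_seed(0)
x1 = torch.randn(2, 288, 32)

torch.manual_seed(0)
_ = torch.nn.LSTM(32, 64, 1, batch_first=True, bidirectional=True)  # parameter init consumes RNG draws
x2 = torch.randn(2, 288, 32)

print(torch.equal(x1, x2))  # False: the extra initialization shifted the random state

Re-seeding (or generating X) immediately before the tensor is created removes this dependence on the surrounding code.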