Differing outputs from LSTM networks with the same number of stacked layers

Hello everyone,

I am new to PyTorch and I have recently started playing around with RNNs, more specifically, LSTMs.

While exploring them, I noticed that there is a difference between using a single LSTM object with num_layers=3 and using three LSTM objects, each with num_layers=1. A snippet of my nn.Module code is as follows (please forgive my coding style):

class My_net(torch.nn.Module):

    # input_dim = 5, hidden_dim and output_dim=2 in my case
    def __init__(self, input_dim, hidden_dim, output_dim, use_gpu=True):
        super(My_net, self).__init__()
        self.use_gpu = use_gpu
        self.h_size = hidden_dim
        layers = 3

        # Three separate single-layer LSTMs, stacked manually in forward()
        self.lstm1 = torch.nn.LSTM(input_size=input_dim, hidden_size=hidden_dim,
                                   num_layers=1, batch_first=True)

        self.lstm2 = torch.nn.LSTM(input_size=hidden_dim, hidden_size=hidden_dim,
                                   num_layers=1, batch_first=True)

        self.lstm3 = torch.nn.LSTM(input_size=hidden_dim, hidden_size=output_dim,
                                   num_layers=1, batch_first=True)

        # A single LSTM with num_layers=3, kept purely for comparison
        self.lstmtest = torch.nn.LSTM(input_size=input_dim, hidden_size=output_dim,
                                      num_layers=layers, batch_first=True)


    def forward(self, x):
        # x has shape (batch_size, seq_length, dim)
        # and is fed into lstm1 and lstmtest
        # Manual stack: feed the output sequence of each LSTM into the next
        op, states = self.lstm1(x)
        op1, states1 = self.lstm2(op)
        op2, states2 = self.lstm3(op1)
       
        optest, statestest = self.lstmtest(x)  # Just for comparison purposes only
        return op2
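
Before comparing the outputs, I also did a quick sanity check that the two setups line up layer for layer (hidden_dim == output_dim == 2 in my case, so every layer should have the same sizes). A minimal sketch of that check, just printing the parameter shapes:

import torch

# Architectural sanity check (sketch): with hidden_dim == output_dim == 2, the
# single num_layers=3 LSTM and the three stacked single-layer LSTMs should have
# parameters of identical shapes, layer for layer.
net = My_net(input_dim=5, hidden_dim=2, output_dim=2, use_gpu=False)

print("lstmtest (num_layers=3):")
for name, param in net.lstmtest.named_parameters():
    print("  ", name, tuple(param.size()))   # e.g. weight_ih_l0 -> (8, 5)

print("manually stacked single-layer LSTMs:")
for idx, lstm in enumerate([net.lstm1, net.lstm2, net.lstm3]):
    for name, param in lstm.named_parameters():
        print("  lstm{}.{}".format(idx + 1, name), tuple(param.size()))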

If I understood the PyTorch documentation correctly, op2 and optest should give similar results, correct? Of course, the weights of the two sets of LSTM objects are initialized independently, so exactly matching outputs are not possible. However, the results I am getting are:

Variable containing:
(0 ,.,.) = 
  4.6599e-16  2.0737e-16
  2.6191e-15  1.1656e-15
  2.1194e-15  9.4315e-16
           ⋮            
  1.4646e-16  6.5178e-17
 -2.9070e-15 -1.2936e-15
 -1.1162e-14 -4.9673e-15

(1 ,.,.) = 
  2.8576e-15  1.2717e-15
  8.6747e-15  3.8604e-15
  5.6182e-15  2.5002e-15
           ⋮            
  8.8361e-15  3.9322e-15
  1.1098e-14  4.9386e-15
  1.3680e-14  6.0877e-15

(2 ,.,.) = 
  9.2943e-16  4.1361e-16
  5.6716e-15  2.5240e-15
  6.0766e-15  2.7042e-15
           ⋮            
  1.4192e-14  6.3158e-15
  1.3274e-14  5.9072e-15
  8.6829e-15  3.8640e-15
...

(29,.,.) = 
 -2.4198e-15 -1.0769e-15
 -6.7929e-15 -3.0230e-15
 -8.0617e-15 -3.5876e-15
           ⋮            
  2.9668e-15  1.3203e-15
  2.6790e-15  1.1922e-15
  7.7745e-16  3.4598e-16

(30,.,.) = 
 -5.9139e-15 -2.6318e-15
 -4.0896e-15 -1.8199e-15
 -8.1629e-16 -3.6326e-16
           ⋮            
 -1.8212e-15 -8.1048e-16
  6.0243e-15  2.6809e-15
  4.8011e-15  2.1366e-15

(31,.,.) = 
  1.7016e-15  7.5724e-16
  1.1792e-14  5.2474e-15
  1.5820e-14  7.0404e-15
           ⋮            
 -4.5726e-15 -2.0349e-15
 -1.0052e-14 -4.4731e-15
 -1.5208e-14 -6.7679e-15

for op2 and as follows for optest:

Variable containing:
(0 ,.,.) = 
  3.9907e-02  8.7602e-08
  5.9702e-02  1.7235e-07
  6.9532e-02  2.5435e-07
           ⋮            
  7.9308e-02  2.4525e-06
  7.9308e-02  2.4604e-06
  7.9308e-02  2.4680e-06

(1 ,.,.) = 
  3.9907e-02  8.7602e-08
  5.9702e-02  1.7235e-07
  6.9532e-02  2.5435e-07
           ⋮            
  7.9308e-02  2.4525e-06
  7.9308e-02  2.4604e-06
  7.9308e-02  2.4680e-06

(2 ,.,.) = 
  3.9907e-02  8.7602e-08
  5.9702e-02  1.7235e-07
  6.9532e-02  2.5435e-07
           ⋮            
  7.9308e-02  2.4525e-06
  7.9308e-02  2.4604e-06
  7.9308e-02  2.4680e-06
...

(29,.,.) = 
  3.9907e-02  8.7602e-08
  5.9702e-02  1.7235e-07
  6.9532e-02  2.5435e-07
           ⋮            
  7.9308e-02  2.4525e-06
  7.9308e-02  2.4604e-06
  7.9308e-02  2.4680e-06

(30,.,.) = 
  3.9907e-02  8.7602e-08
  5.9702e-02  1.7235e-07
  6.9532e-02  2.5435e-07
           ⋮            
  7.9308e-02  2.4525e-06
  7.9308e-02  2.4604e-06
  7.9308e-02  2.4680e-06

(31,.,.) = 
  3.9907e-02  8.7602e-08
  5.9702e-02  1.7235e-07
  6.9532e-02  2.5435e-07
           ⋮            
  7.9308e-02  2.4525e-06
  7.9308e-02  2.4604e-06
  7.9308e-02  2.4680e-06

For optest, every sequence in the batch produces exactly the same output, which seems like very strange behaviour; op2 does not show this. Could anyone please give me some advice as to why this is happening? Or is it perhaps just some silly coding error on my part?
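
In case it helps, this is the kind of check I used to confirm the repetition (a minimal sketch, run inside forward() where both op2 and optest are in scope):

# Difference between consecutive batch elements (sketch, printed inside forward()):
print((optest[1:] - optest[:-1]).abs().max())  # essentially zero -> all batch elements identical
print((op2[1:] - op2[:-1]).abs().max())        # clearly non-zero -> op2 varies across the batch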

If more information is required please do not hesitate to tell me. I will try to provide the necessary details.

Thank you all in advance for your help!

EDIT:
An example of my inputs (in case it is needed) is:

# each sequence in the batch has shape (75, 5)
(0 ,.,.) = 
  8.9588e-01  9.0243e-01  7.5308e-01 -8.7843e-01  3.1841e-01
 -2.5644e-01  2.6785e-01  8.0229e-01  4.5353e-01  7.4789e-01
 -6.7627e-01  1.3303e-02 -2.7766e-01  2.0192e+00 -4.9631e-01
                             ⋮                              
  7.9304e-01 -1.2156e+00  1.0258e+00 -4.7847e-02  2.2743e-01
 -1.4877e+00 -7.5960e-01 -3.0052e-01 -9.7146e-01 -9.6701e-01
  4.9587e-01  3.5543e-01  4.9934e-02  9.5471e-01 -1.9323e+00

I have only provided one example, as the data is of the same nature throughout the batches.
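
And this is roughly how a batch is assembled and fed through the network (a minimal sketch; torch.randn here is only a stand-in with the same shape as my real preprocessed features):

import torch
from torch.autograd import Variable  # I am still on an older PyTorch, hence Variable

net = My_net(input_dim=5, hidden_dim=2, output_dim=2, use_gpu=False)

# 32 sequences per batch, 75 time steps, 5 features each
x = Variable(torch.randn(32, 75, 5))
op2 = net(x)        # optest is computed (and inspected) inside forward()
print(op2.size())   # torch.Size([32, 75, 2])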