Possible Bug? Different precision when feeding views of input

This is the architecture of a model that I have trained:

import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNorm_LSTM(nn.Module):
    def __init__(self, in_dim, hidden_dim, bidirectional=False):
        super(LayerNorm_LSTM, self).__init__()
        self.layernorm = nn.LayerNorm(in_dim)
        self.lstm = nn.LSTM(in_dim, hidden_dim, batch_first=True, bidirectional=bidirectional)

    def forward(self, input_, hidden):
        # Normalize across the feature dimension before the LSTM step.
        input_ = self.layernorm(input_)
        lstm_out, hidden = self.lstm(input_, hidden)
        return lstm_out, hidden


class Recognizer_Net(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_hidden, BN_dim, mel_mean, mel_std, bidirectional=False):
        super(Recognizer_Net, self).__init__()
        self.linear_pre1 = nn.Linear(input_dim, hidden_dim)
        self.linear_pre2 = nn.Linear(hidden_dim, hidden_dim)

        self.lstm = nn.ModuleList([
            LayerNorm_LSTM(in_dim=hidden_dim, hidden_dim=hidden_dim, 
                           bidirectional=bidirectional)
            for i in range(num_hidden)
        ])       
        self.BN_linear = nn.Linear(hidden_dim, BN_dim)
        self.tanh = nn.Tanh()

        self.num_hidden = num_hidden
        self.hidden_dim = hidden_dim

        self.mel_mean = mel_mean
        self.mel_std = mel_std

    def init(self):
        # Zero-initialized (h, c) for all layers, batch size 1; each slice
        # [l:l+1] has shape (1, 1, hidden_dim), as nn.LSTM expects.
        hidden_states = torch.zeros(self.num_hidden, 1, self.hidden_dim)
        cell_states = torch.zeros(self.num_hidden, 1, self.hidden_dim)
        return hidden_states, cell_states

    def forward(self, x, hidden):
        hidden_states, cell_states = hidden

        # Normalize the mel features with the dataset statistics.
        x = (x - self.mel_mean) / self.mel_std
        pre_linear1 = F.relu(self.linear_pre1(x))
        lstm_out = F.relu(self.linear_pre2(pre_linear1))

        for l, lstm_layer in enumerate(self.lstm):
            # Each layer gets its own (h, c) slice; the tuple-unpacking
            # assignment writes the new states back into the slices in place.
            lstm_out, (hidden_states[l:l+1], cell_states[l:l+1]) = lstm_layer(
                lstm_out, (hidden_states[l:l+1], cell_states[l:l+1]))

        BN_out = self.tanh(self.BN_linear(lstm_out))
        return BN_out, (hidden_states, cell_states)

And I have the following code to test inference:

# mel_mean and mel_std are placeholders here; in practice they are the
# dataset statistics passed into the model.
mel_mean, mel_std = 0.0, 1.0

input = torch.ones((1, 300, 80), dtype=torch.float)
recognizer = Recognizer_Net(input_dim=80, hidden_dim=512, num_hidden=3, BN_dim=256,
                            mel_mean=mel_mean, mel_std=mel_std, bidirectional=False)
recognizer.eval()

# Full 300-frame sequence.
recognizer_hidden = recognizer.init()
BN_out, recognizer_hidden = recognizer.forward(input, recognizer_hidden)

# First 16 frames only.
recognizer_hidden = recognizer.init()
BN_out_test, recognizer_hidden = recognizer.forward(input[:, 0:16, :], recognizer_hidden)

# First 15 frames only.
recognizer_hidden = recognizer.init()
BN_out_test2, recognizer_hidden = recognizer.forward(input[:, 0:15, :], recognizer_hidden)

I find that BN_out[:,0:16,:] == BN_out_test, but BN_out[:,0:15,:] != BN_out_test2. Here are the printed results for each case:

(Pdb) p BN_out[:,:16,:]
tensor([[[ 0.12310354, -0.06452109,  0.04474466,  ...,  0.06025446,
          -0.02436436,  0.03219125],
         [ 0.16362280, -0.08077041,  0.07871316,  ...,  0.09137958,
          -0.04021315,  0.03032400],
         [ 0.18024816, -0.08485510,  0.09805698,  ...,  0.10307459,
          -0.04448149,  0.02867011],
         ...,
         [ 0.18468492, -0.06737500,  0.14588971,  ...,  0.09201521,
          -0.00267485,  0.00157470],
         [ 0.18401137, -0.06671639,  0.14601991,  ...,  0.09149401,
          -0.00209353, -0.00024496],
         [ 0.18337716, -0.06622121,  0.14607997,  ...,  0.09106418,
          -0.00173651, -0.00173810]]])
(Pdb) p BN_out_test
tensor([[[ 0.12310354, -0.06452109,  0.04474466,  ...,  0.06025446,
          -0.02436436,  0.03219125],
         [ 0.16362280, -0.08077041,  0.07871316,  ...,  0.09137958,
          -0.04021315,  0.03032400],
         [ 0.18024816, -0.08485510,  0.09805698,  ...,  0.10307459,
          -0.04448149,  0.02867011],
         ...,
         [ 0.18468492, -0.06737500,  0.14588971,  ...,  0.09201521,
          -0.00267485,  0.00157470],
         [ 0.18401137, -0.06671639,  0.14601991,  ...,  0.09149401,
          -0.00209353, -0.00024496],
         [ 0.18337716, -0.06622121,  0.14607997,  ...,  0.09106418,
          -0.00173651, -0.00173810]]])
(Pdb) p BN_out_test2
tensor([[[ 0.12310345, -0.06452110,  0.04474467,  ...,  0.06025450,
          -0.02436449,  0.03219126],
         [ 0.16362284, -0.08077045,  0.07871307,  ...,  0.09137955,
          -0.04021320,  0.03032403],
         [ 0.18024816, -0.08485514,  0.09805696,  ...,  0.10307456,
          -0.04448152,  0.02867013],
         ...,
         [ 0.18537879, -0.06823060,  0.14564055,  ...,  0.09265027,
          -0.00357649,  0.00375958],
         [ 0.18468492, -0.06737494,  0.14588960,  ...,  0.09201526,
          -0.00267488,  0.00157468],
         [ 0.18401131, -0.06671640,  0.14601973,  ...,  0.09149403,
          -0.00209353, -0.00024493]]])

As you can see, there are slight differences in the lower decimal places. I am confused about why this is happening, and also about why any view of the input larger than [:,0:15,:] produces exactly the same values as the full-length run. Does anyone know why? Is this a bug in PyTorch?
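
For reference, the mismatch can be quantified directly on the tensors above; a minimal sketch (judging from the printed digits, the expected outcome is that the 16-frame slice is bit-identical while the 15-frame slice differs only at the level of float32 rounding):

print(torch.equal(BN_out[:, 0:16, :], BN_out_test))     # expected: True
print(torch.equal(BN_out[:, 0:15, :], BN_out_test2))    # expected: False
print((BN_out[:, 0:15, :] - BN_out_test2).abs().max())  # tiny, float32-rounding scale
print(torch.allclose(BN_out[:, 0:15, :], BN_out_test2, atol=1e-5))  # expected: True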

This is probably expected rather than a bug: different input shapes can cause different algorithms/kernels to be selected, and those will expectedly show small numerical variations, since floating-point addition is not associative and a different order of operations changes the last few digits. If this is a problem and you are on an Ampere GPU, you can check whether setting export NVIDIA_TF32_OVERRIDE=0 reduces the relative difference, or (as an extreme measure) whether using double precision improves things.
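
For completeness, the in-script equivalents of those two suggestions look roughly like this (a sketch; the TF32 flags only matter on Ampere-or-newer GPUs and have no effect on the CPU run shown above):

# Equivalent of NVIDIA_TF32_OVERRIDE=0: keep float32 matmuls/convolutions in
# full float32 precision instead of TF32 (Ampere-or-newer GPUs only).
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

# Extreme measure: repeat the comparison in float64.
recognizer = recognizer.double()
h, c = recognizer.init()
BN_out64, _ = recognizer(input.double(), (h.double(), c.double()))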