Freezing API: frozen LSTM model returns a different hidden state

I wish to freeze the following model, Recognizer_Net, which is structured as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple


class LayerNorm_LSTM(nn.Module):
    def __init__(self, in_dim=512, hidden_dim=512, bidirectional=False):
        super(LayerNorm_LSTM, self).__init__()
        self.layernorm = nn.LayerNorm(in_dim)
        self.lstm = nn.LSTM(in_dim, hidden_dim, batch_first=True, bidirectional=bidirectional)
        
    @torch.jit.export
    def forward(
            self, input_: torch.Tensor, hidden: Tuple[torch.Tensor, torch.Tensor]
        ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        input_ = self.layernorm(input_)
        lstm_out, hidden = self.lstm(input_, hidden)
        return lstm_out, hidden


class Recognizer_Net(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_hidden, BN_dim, mel_mean, mel_std, bidirectional=False):
        super(Recognizer_Net, self).__init__()
        self.linear_pre1 = nn.Linear(input_dim, hidden_dim)
        self.linear_pre2 = nn.Linear(hidden_dim, hidden_dim)

        self.lstm = nn.ModuleList([
            LayerNorm_LSTM(in_dim=hidden_dim, hidden_dim=hidden_dim, 
                           bidirectional=bidirectional)
            for i in range(num_hidden)
        ])       
        self.BN_linear = nn.Linear(hidden_dim, BN_dim)
        self.tanh = nn.Tanh()

        self.num_hidden = num_hidden
        self.hidden_dim = hidden_dim

        self.mel_mean = mel_mean
        self.mel_std = mel_std

    @torch.jit.export
    def init(self) -> Tuple[torch.Tensor, torch.Tensor]:
        # One (h, c) slice per LSTM layer: (num_hidden, batch=1, hidden_dim)
        hidden_states = torch.zeros(self.num_hidden, 1, self.hidden_dim)
        cell_states = torch.zeros(self.num_hidden, 1, self.hidden_dim)
        return hidden_states, cell_states

    @torch.jit.export
    def forward(
            self, x: torch.Tensor, hidden: Tuple[torch.Tensor, torch.Tensor]
        ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        hidden_states, cell_states = hidden

        x = (x - self.mel_mean) / self.mel_std
        pre_linear1 = F.relu(self.linear_pre1(x))
        lstm_out = F.relu(self.linear_pre2(pre_linear1))

        for l, lstm_layer in enumerate(self.lstm):
            # Slice assignment writes each layer's new (h, c) back into the state tensors in place
            lstm_out, (hidden_states[l:l+1], cell_states[l:l+1]) = lstm_layer(lstm_out, (hidden_states[l:l+1], cell_states[l:l+1]))

        BN_out = self.tanh(self.BN_linear(lstm_out))
        return BN_out, (hidden_states, cell_states)
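
For reference, the state layout is one (h, c) pair per LSTM layer; here is a minimal shape check (a sketch only, with mel_mean/mel_std stubbed as scalar tensors purely for illustration):

# Hypothetical shape check; the mel statistics here are stubs, not the real values.
net = Recognizer_Net(input_dim=80, hidden_dim=512, num_hidden=3, BN_dim=256,
                     mel_mean=torch.tensor(0.0), mel_std=torch.tensor(1.0))
h, c = net.init()
print(h.shape, c.shape)  # torch.Size([3, 1, 512]) each: (num_hidden, batch=1, hidden_dim)
out, (h, c) = net(torch.zeros(1, 300, 80), (h, c))
print(out.shape)         # torch.Size([1, 300, 256]): (batch, time, BN_dim)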

I script and freeze the JIT model as follows:

recognizer = Recognizer_Net(input_dim=80, hidden_dim=512, num_hidden=3, BN_dim=256,
                            mel_mean=mel_mean, mel_std=mel_std, bidirectional=False).cpu()
recognizer.load_state_dict(checkpoint)
recognizer.eval()
net_jit = torch.jit.script(recognizer)
net_jit = torch.jit.freeze(net_jit, preserved_attrs=["init"])
torch.jit.save(net_jit, frozen_jit_model_path)  # saved here so it can be reloaded below

Now, when I compare the non-JIT model with this newly frozen JIT model, I find that the returned BN_out and cell_states are consistent, but the returned hidden_states are not. For example:

input = torch.zeros((1, 300, 80), dtype=torch.float)

# Frozen JIT model
recognizer_jit = torch.jit.load(frozen_jit_model_path, map_location='cpu')
recognizer_hidden_jit = recognizer_jit.init()
BN_out_jit, recognizer_hidden_jit = recognizer_jit.forward(input, recognizer_hidden_jit)

# Non-JIT model
recognizer = Recognizer_Net(input_dim=80, hidden_dim=512, num_hidden=3, BN_dim=256,
                            mel_mean=mel_mean, mel_std=mel_std, bidirectional=False).cpu()
recognizer.load_state_dict(checkpoint)
recognizer.eval()
recognizer_hidden = recognizer.init()
BN_out, recognizer_hidden = recognizer.forward(input, recognizer_hidden)

I get that BN_out_jit == BN_out and recognizer_hidden_jit[1] == recognizer_hidden[1] (the cell states), but recognizer_hidden_jit[0] != recognizer_hidden[0] (the hidden states).
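
Concretely, the checks (expressed here with torch.allclose at default tolerances) come out as:

print(torch.allclose(BN_out_jit, BN_out))                              # True
print(torch.allclose(recognizer_hidden_jit[1], recognizer_hidden[1]))  # True  (cell states match)
print(torch.allclose(recognizer_hidden_jit[0], recognizer_hidden[0]))  # False (hidden states differ)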

Note that I do not get this problem if I skip the Freezing API when compiling the JIT model. Is this a bug in the Freezing API, or am I missing a step when freezing the JIT model?
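
For completeness, this is roughly the control I mean, scripting the same recognizer but skipping torch.jit.freeze (and reusing input and the post-forward recognizer_hidden from above); without freezing, both state tensors match:

# Control sketch: same scripted module, no freezing step.
net_jit_nofreeze = torch.jit.script(recognizer)
BN_out_nf, hidden_nf = net_jit_nofreeze(input, net_jit_nofreeze.init())
print(torch.allclose(hidden_nf[0], recognizer_hidden[0]))  # True when freezing is skipped
print(torch.allclose(hidden_nf[1], recognizer_hidden[1]))  # True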