Creating a coupled LSTM

Hi all,

I am attempting to reproduce the Neural Assimilation model architecture described in this paper: Neural Assimilation - PMC. The goal of the architecture is essentially to be much faster than a Kalman filter by bypassing the correction step. Although the authors provide some Keras code (link here: FINAL CODE - Google Drive), I have had trouble implementing the model in PyTorch because I cannot work out how they have built it.

From my understanding, we need two LSTMs: one takes the model-forecast data as input and the other takes the observational data. The tricky part is that each LSTM also needs the hidden state and cell state from the other LSTM as input, and that is what I am struggling to implement. Does anyone have any ideas on how I could do this? So far I have only used simple LSTMs for many-to-many prediction.
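To make the coupling concrete, the closest I have come to picturing it is two nn.LSTMCell modules that swap their states at every time step, something like the sketch below. This is only my reading of the idea, not the authors' code, and all the class and variable names are my own:

import torch
import torch.nn as nn

class CoupledLSTMCells(nn.Module):
  """Sketch: two LSTM cells that exchange (h, c) at every time step."""
  def __init__(self, n_inputs, n_hidden):
    super().__init__()
    self.M = n_hidden
    self.cell_f = nn.LSTMCell(n_inputs, n_hidden)  # forecast branch
    self.cell_o = nn.LSTMCell(n_inputs, n_hidden)  # observation branch

  def forward(self, X_f, X_o):
    # X_f, X_o: (batch, seq_len, n_inputs)
    B, T, _ = X_f.shape
    h_f = X_f.new_zeros(B, self.M)
    c_f = X_f.new_zeros(B, self.M)
    h_o = X_o.new_zeros(B, self.M)
    c_o = X_o.new_zeros(B, self.M)
    outs = []
    for t in range(T):
      # each cell is seeded with the *other* cell's previous state
      h_f_new, c_f_new = self.cell_f(X_f[:, t], (h_o, c_o))
      h_o_new, c_o_new = self.cell_o(X_o[:, t], (h_f, c_f))
      h_f, c_f = h_f_new, c_f_new
      h_o, c_o = h_o_new, c_o_new
      outs.append(h_f)
    # outputs of the forecast branch at every step, plus the final states
    return torch.stack(outs, dim=1), (h_f, c_f), (h_o, c_o)

The explicit Python loop over time steps feels slow compared to a single nn.LSTM call, though, so I am not sure this is what the authors intended.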

This is a diagram of the architecture:

and here is my implementation so far:

import torch
import torch.nn as nn

class Neural_Assimilation(nn.Module):
  def __init__(self, n_inputs, n_hidden, num_layers, n_outputs):
    super(Neural_Assimilation, self).__init__()
    self.D = n_inputs    # number of input features
    self.M = n_hidden    # hidden size
    self.K = n_outputs   # number of outputs
    self.L = num_layers  # number of stacked LSTM layers

    # LSTM for the model-forecast sequence
    self.rnn_f = nn.LSTM(
        input_size=self.D,
        hidden_size=self.M,
        num_layers=self.L,
        batch_first=True)

    # LSTM for the observation sequence
    self.rnn_o = nn.LSTM(
        input_size=self.D,
        hidden_size=self.M,
        num_layers=self.L,
        batch_first=True)
    self.fc = nn.Linear(self.M, self.K)  # for a BiDir LSTM need to multiply self.M by 2

  def forward(self, X_f, X_o):
    # initial hidden and cell states (for a BiDir LSTM need to multiply self.L by 2)
    h0_f = torch.zeros(self.L, X_f.size(0), self.M, device=X_f.device)
    c0_f = torch.zeros(self.L, X_f.size(0), self.M, device=X_f.device)

    # run each branch over its full sequence, feeding the final (h, c) of one
    # LSTM in as the initial state of the other
    out_f, (h_t_f, c_t_f) = self.rnn_f(X_f, (h0_f, c0_f))
    out_o, (h_t_o, c_t_o) = self.rnn_o(X_o, (h_t_f, c_t_f))
    out_f, (h_t, c_t) = self.rnn_f(X_f, (h_t_o, c_t_o))

    # prediction from the final time step of the forecast branch
    out = self.fc(out_f[:, -1, :])

    # h_t and c_t have shape (num_layers, batch, hidden),
    # so the top layer is h_t[-1], not h_t[:, -1, :]
    hidden_out = self.fc(h_t[-1])
    cell_out = self.fc(c_t[-1])

    return out, hidden_out, cell_out
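
For reference, this is how I am calling it; the shapes are just dummy values I made up for a smoke test:

model = Neural_Assimilation(n_inputs=3, n_hidden=32, num_layers=2, n_outputs=1)

# dummy batch: batch_size=4, seq_len=10, n_inputs=3
X_f = torch.randn(4, 10, 3)
X_o = torch.randn(4, 10, 3)

out, hidden_out, cell_out = model(X_f, X_o)
print(out.shape)         # torch.Size([4, 1])
print(hidden_out.shape)  # torch.Size([4, 1])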

Would really appreciate your expertise!