Next frame prediction using CNN-LSTM

I’m doing next frame prediction from static images extracted from video and save into disk. I’m using CNN-LSTM, during training feed the model 5 frames and predict the 6th frame, but during evaluation I want the CNN-LSTM model to take it’s prediction and use it as input to predict the next future frame it should repeat until predict the 6th frame.

How do you pass the predicted frame from then CNN-LSTM as an input in the next step for 6 iterations? Here is the sample code for my CNN-LSTM model:

class CRNN(nn.Module):
def __init__(self, in_channels=3, sample_size=64, num_classes=100,
            hidden_size=512, num_layers=1, rnn_unit='LSTM'):
    super(CRNN, self).__init__()
    self.in_channels=in_channels
    self.sample_size = sample_size
    self.num_classes = num_classes
    self.rnn_unit=rnn_unit

    # network params
    self.ch1, self.ch2= 64, 128
    self.k1, self.k2 = (5, 5), (5, 5)
    self.s1, self.s2 = (1, 1), (1, 1)
    self.p1, self.p2 = (0, 0), (0, 0)
    self.d1, self.d2 = (1, 1), (1, 1)
    self.input_size  = self.ch2
    self.hidden_size = hidden_size
    self.num_layers  = num_layers

    # network architecture
    # in_channels=3 for rgb
    self.conv1 = nn.Sequential(
        nn.Conv2d(in_channels=self.in_channels, out_channels=self.ch1, kernel_size=self.k1, stride=self.s1, padding=self.p1, dilation=self.d1),
        nn.BatchNorm2d(self.ch1, momentum=0.01),
        nn.ReLU(inplace=True),
        nn.Dropout(p=0.5),
        nn.Conv2d(in_channels=self.ch1, out_channels=self.ch1, kernel_size=1, stride=1),
        nn.MaxPool2d(kernel_size=2),
    )
    self.conv2 = nn.Sequential(
        nn.Conv2d(in_channels=self.ch1, out_channels=self.ch2, kernel_size=self.k2, stride=self.s2, padding=self.p2, dilation=self.d2),
        nn.BatchNorm2d(self.ch2, momentum=0.01),
        nn.ReLU(inplace=True),
        nn.Dropout(p=0.5),
        nn.Conv2d(in_channels=self.ch2, out_channels=self.ch2, kernel_size=1, stride=1),
        nn.MaxPool2d(kernel_size=2),
    )

    if self.rnn_unit=='LSTM':
        self.lstm = nn.LSTM(
        input_size=self.input_size,
        hidden_size=self.hidden_size,
        dropout= 0.5 if self.num_layers > 1 else 0,
        num_layers=self.num_layers,
        batch_first=True,
    )
    if self.rnn_unit=='GRU':
        self.lstm = nn.GRU(
        input_size=self.input_size,
        hidden_size=self.hidden_size,
        dropout= 0.5 if self.num_layers > 1 else 0,
        num_layers=self.num_layers,
        batch_first=True,
    )
    if self.rnn_unit=='RNN':
        self.lstm = nn.RNN(
        input_size=self.input_size,
        hidden_size=self.hidden_size,
        dropout= 0.5 if self.num_layers > 1 else 0,
        num_layers=self.num_layers,
        batch_first=True,
    )

  
    self.fc1 = nn.Linear(self.hidden_size, self.num_classes)
    self.act = nn.Sigmoid()

def forward(self, X, n_steps=None):
    # CNN
    # x: (batch_size, channel, t, h, w)
    cnn_embed_seq = []
    if n_steps is None: # Code for training
        for t in range(X.size(2)):
            # Conv
            x = self.conv1(X[:, :, t, :, :])
            x = self.conv2(x)

            x = x.view(x.size(0), -1)
            cnn_embed_seq.append(x)

        cnn_embed_seq = torch.stack(cnn_embed_seq, dim=0)
        
        # batch first
        cnn_embed_seq = cnn_embed_seq.transpose_(0, 1)

        # LSTM
        hidden=None
        # use faster code paths
        self.lstm.flatten_parameters()
        out, hidden = self.lstm(cnn_embed_seq, hidden)
        # MLP
        # out: (batch, seq, feature), choose the last time step
        out = self.fc1(out[:, -1, :])
        out = self.act(out)

    if n_steps is not None: # Code for prediction
        for t in range(n_steps): 
            # Conv
            x = self.conv1(X[:, :, 0, :, :])
            x = self.conv2(x)
  
            x = x.view(x.size(0), -1)
            # cnn_embed_seq.append(x)

            # cnn_embed_seq = torch.stack(cnn_embed_seq, dim=0)
            # batch first
            # x = x.transpose_(0, 1)
            # print('pred:',x.shape, t)

            x = x[ :, np.newaxis, :]
            # LSTM
            hidden=None
            # use faster code paths
            self.lstm.flatten_parameters()
            out, hidden = self.lstm(x, hidden)
            # print('LSTM out', out.shape)
            # MLP
            # out: (batch, seq, feature), choose the last time step
            out = self.fc1(out[:, -1, :])
            # X[:, 1, :, :, :] = out.view(-1, 1, self.sample_size,self.sample_size)
            out = self.act(out)
            X[:, 1, :, :, :] = out.view(-1, 1, self.sample_size,self.sample_size)

    return out