CNN-GRU sequence prediction problem

I’m trying to build a CNN-GRU model similar to the one in the referenced PDF, but with an added decoder. My dataset is a sequence of images that are related over time, and I want the decoder so that I can extrapolate the images forward in time. The model does converge during training, but the problem appears when I make predictions recursively, i.e. feed each predicted image back in as the next input to predict multiple time steps ahead: the model predicts the same image over and over with no changes. Is my PyTorch implementation wrong? Training is done with nn.BCELoss().

import torch
import torch.nn as nn
import numpy as np
from torch.nn import init
from torch.autograd import Variable


class ConvGRUCell(nn.Module):
    """Convolutional GRU encoder-decoder for image-sequence extrapolation.

    The same gate convolutions are used in two phases:

    * encoder: the first ``T`` steps consume the observed frames ``x[t]``;
    * decoder: ``seqPred`` further steps each map the hidden state to a frame
      (two-conv output head followed by a sigmoid) and then feed that
      predicted frame back in as the next input.

    Args:
        input_size: Channels of each input frame; predicted frames have the
            same number of channels.
        hidden_size: Channels of the hidden state.
        kernel_size: Spatial size of every convolution. Padding is
            ``kernel_size // 2``, so an odd size is needed to keep height and
            width unchanged (required for the concatenations in the GRU step).
        g: Legacy GPU flag, stored as ``self.GPU`` for backward compatibility.
            The hidden state is now placed on the device of the input instead.
        seqPred: Number of frames to predict after the observed sequence.

    Raises:
        ValueError: If ``seqPred`` is negative.
    """

    def __init__(self, input_size, hidden_size, kernel_size, g, seqPred):
        super(ConvGRUCell, self).__init__()
        if seqPred < 0:
            raise ValueError("seqPred must be >= 0, got {}".format(seqPred))

        self.seqToPredict = seqPred
        self.GPU = g  # kept for compatibility; forward() follows x.device
        padding = kernel_size // 2
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.reset_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding)
        self.update_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding)
        self.out_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding)

        # Output head: hidden state -> (hidden_size + 3) channels -> one frame
        # with input_size channels.
        self.endOutput1 = nn.Conv2d(hidden_size, hidden_size + 3, kernel_size, padding=padding)
        self.endOutput2 = nn.Conv2d(hidden_size + 3, input_size, kernel_size, padding=padding)

        # Orthogonal weights and zero biases for every conv, output head
        # included (both endOutput1 and endOutput2 biases are zeroed).
        for conv in (self.reset_gate, self.update_gate, self.out_gate,
                     self.endOutput1, self.endOutput2):
            init.orthogonal_(conv.weight)
            init.constant_(conv.bias, 0.)

    def _gru_step(self, inp, state):
        """Apply one ConvGRU update to ``state`` given the input frame ``inp``."""
        stacked = torch.cat([inp, state], dim=1)
        update = torch.sigmoid(self.update_gate(stacked))
        reset = torch.sigmoid(self.reset_gate(stacked))
        candidate = torch.tanh(self.out_gate(torch.cat([inp, state * reset], dim=1)))
        return state * (1 - update) + candidate * update

    def _predict_frame(self, state):
        """Map a hidden state to one predicted frame with values in [0, 1]."""
        # NOTE(review): there is no activation between the two convs, so the
        # head is a single linear map of ``state`` before the sigmoid. Adding a
        # nonlinearity would change the model, so it is left as is here.
        return torch.sigmoid(self.endOutput2(self.endOutput1(state)))

    def forward(self, x, prev=None):
        """Encode the observed frames, then roll out ``seqToPredict`` predictions.

        Args:
            x: Observed frames indexed as (T, B, C, H, W): time steps are read
                from dim 0, batch from dim 1 and spatial size from dims 3+.
                C must equal ``input_size``.
            prev: Optional initial hidden state (B, hidden_size, H, W). It is
                moved to ``x``'s device. Defaults to zeros.

        Returns:
            ``(predictions, state)``. ``predictions`` has shape
            (B, seqToPredict, input_size, H, W) and holds sigmoid outputs.
            ``state`` is the hidden state after the last step.
        """
        time_steps, batch_size = x.shape[0], x.shape[1]
        spatial_size = tuple(x.shape[3:])

        if prev is None:
            # Match x's dtype and device rather than defaulting to CPU float32.
            state = x.new_zeros((batch_size, self.hidden_size) + spatial_size)
        else:
            state = prev.to(x.device)

        # Encoder: consume the observed frames.
        for t in range(time_steps):
            state = self._gru_step(x[t], state)

        # Decoder: predict a frame from the current state, then feed it back.
        frames = []
        for _ in range(self.seqToPredict):
            frame = self._predict_frame(state)
            frames.append(frame)
            state = self._gru_step(frame, state)

        if not frames:
            # seqToPredict == 0: return an empty prediction tensor instead of
            # failing on an unbound accumulator.
            empty = x.new_zeros((batch_size, 0, self.input_size) + spatial_size)
            return empty, state

        # Stacking on dim 1 yields (B, seqToPredict, C, H, W).
        return torch.stack(frames, dim=1), state




def main():
    """Run ConvGRUCell on dummy data and print the prediction shape and BCE loss."""
    model = ConvGRUCell(1, 3, 3, False, seqPred=3)

    # Shape: time steps, batch size, channels, height, width
    x = torch.ones(5, 10, 1, 50, 50)

    # Targets for the predicted frames: batch size, predicted steps, channels, height, width
    y = torch.ones(10, 3, 1, 50, 50)

    pred, _ = model(x)

    print(pred.shape)  # torch.Size([10, 3, 1, 50, 50])
    # Predictions are sigmoid outputs, so they can be scored against y with BCELoss.
    print(nn.BCELoss()(pred, y).item())


if __name__ == '__main__':
    main()