I’m trying to build a CNN-GRU model similar to the one in the linked paper (PDF), but with an added decoder. My dataset is a sequence of images that are related over time, and I want the decoder so that I can extrapolate the images forward in time. The model converges during training, but the problem appears when I make predictions recursively, i.e. feed each predicted image back in as the next input to predict multiple time steps ahead: the model predicts the same image over and over again with no changes. Is my PyTorch implementation wrong? Training uses nn.BCELoss().
import torch
import torch.nn as nn
import numpy as np
from torch.nn import init
from torch.autograd import Variable
class ConvGRUCell(nn.Module):
    """Convolutional GRU encoder with a recursive convolutional decoder.

    The encoder consumes ``x`` one time step at a time, updating a
    convolutional GRU hidden state.  The decoder then emits ``seqPred``
    frames.  At each step it maps the current hidden state to a frame
    (conv -> conv -> sigmoid), records it, and feeds that frame back in as
    the next GRU input.

    Args:
        input_size: number of channels of each input frame.  This is also
            the channel count of each predicted frame, which is what allows
            predictions to be fed back as inputs.
        hidden_size: number of channels of the hidden state.
        kernel_size: spatial kernel size of every convolution.
            ``padding = kernel_size // 2`` keeps H and W unchanged for odd
            kernel sizes.
        g: legacy GPU flag, stored as ``self.GPU`` for backward
            compatibility.  The hidden state is now always created on
            ``x.device``, so this flag no longer needs to match the input.
        seqPred: number of future frames the decoder produces.
    """

    def __init__(self, input_size, hidden_size, kernel_size, g, seqPred):
        super(ConvGRUCell, self).__init__()
        self.seqToPredict = seqPred
        self.GPU = g
        padding = kernel_size // 2
        self.input_size = input_size
        self.hidden_size = hidden_size
        gate_in = input_size + hidden_size
        self.reset_gate = nn.Conv2d(gate_in, hidden_size, kernel_size, padding=padding)
        self.update_gate = nn.Conv2d(gate_in, hidden_size, kernel_size, padding=padding)
        self.out_gate = nn.Conv2d(gate_in, hidden_size, kernel_size, padding=padding)
        # Readout that turns a hidden state into a prediction; the
        # intermediate layer adds three channels.
        self.endOutput1 = nn.Conv2d(hidden_size, hidden_size + 3, kernel_size, padding=padding)
        self.endOutput2 = nn.Conv2d(hidden_size + 3, input_size, kernel_size, padding=padding)

        for conv in (self.reset_gate, self.update_gate, self.out_gate,
                     self.endOutput1, self.endOutput2):
            init.orthogonal_(conv.weight)
            # The original zeroed endOutput1.bias twice and left
            # endOutput2.bias at PyTorch's default init; every bias is now
            # zeroed.
            init.constant_(conv.bias, 0.)

    def _gru_step(self, inp, state):
        """Apply one convolutional GRU update to ``state`` given ``inp``."""
        stacked = torch.cat([inp, state], dim=1)
        update = torch.sigmoid(self.update_gate(stacked))
        reset = torch.sigmoid(self.reset_gate(stacked))
        candidate = torch.tanh(self.out_gate(torch.cat([inp, state * reset], dim=1)))
        return state * (1 - update) + candidate * update

    def _decode(self, state):
        """Map a hidden state to one predicted frame with values in [0, 1]."""
        # NOTE(review): there is no nonlinearity between endOutput1 and
        # endOutput2, so together they still form a single affine
        # convolution of the hidden state.  Consider adding an activation
        # if more readout capacity is needed.  It is left unchanged here so
        # that existing checkpoints behave identically.
        return torch.sigmoid(self.endOutput2(self.endOutput1(state)))

    def forward(self, x, prev=None):
        """Encode ``x``, then recursively predict ``self.seqToPredict`` frames.

        Args:
            x: input sequence indexed as (time, batch, channels, H, W); the
                code reads time from dim 0, batch from dim 1 and the
                spatial size from dims 3 onward.
            prev: optional initial hidden state of shape
                (batch, hidden_size, H, W).  Zeros are used when omitted.

        Returns:
            A tuple ``(predictions, hidden)``:

            * ``predictions`` has shape (batch, seqPred, input_size, H, W).
              When ``seqPred`` is 0, it is an empty tensor with a
              zero-length time axis.
            * ``hidden`` is the final hidden state.
        """
        time_steps, batch_size = x.size(0), x.size(1)
        spatial_size = tuple(x.size()[3:])

        if prev is not None:
            state = prev
        else:
            # Follow the input's device and dtype instead of the GPU flag.
            state = x.new_zeros((batch_size, self.hidden_size) + spatial_size)

        # Encoder: consume every observed frame.
        for t in range(time_steps):
            state = self._gru_step(x[t], state)

        # Decoder: predict from the current state, then feed the
        # prediction back in as the next input.
        predictions = []
        for _ in range(self.seqToPredict):
            frame = self._decode(state)
            predictions.append(frame)
            state = self._gru_step(frame, state)

        if not predictions:
            empty = x.new_zeros((batch_size, 0, self.input_size) + spatial_size)
            return empty, state
        # Stacking on dim 1 gives the same (batch, time, ...) layout that
        # the original cat-then-transpose produced.
        return torch.stack(predictions, dim=1), state
def main():
    """Smoke-test ConvGRUCell on a dummy sequence and print the prediction shape."""
    model = ConvGRUCell(1, 3, 3, False, seqPred=3)
    # Input shape: (time steps, batch size, channels, height, width).
    x = torch.ones(5, 10, 1, 50, 50)
    # Ground-truth targets for training would have shape
    # (batch size, predicted time steps, channels, height, width)
    # = (10, 3, 1, 50, 50), the same layout as the predictions.
    pred, hidden = model(x)
    # Function-call print works on both Python 2 and Python 3; the original
    # print statement was a SyntaxError on Python 3.
    print(pred.shape)  # torch.Size([10, 3, 1, 50, 50])


if __name__ == '__main__':
    main()