def forward(self, x):
    """Apply ``self.cnn`` to every frame, then ``self.rnn`` over the frame sequence.

    Args:
        x: 5-D tensor, unpacked here as ``(batch_size, timesteps, C, H, W)``.

    Returns:
        ``self.classifier`` applied to the full RNN output sequence. The RNN
        input is built time-major, ``(timesteps, batch_size, features)``, so
        the result's leading dims follow whatever ``self.rnn`` returns for
        that layout.
    """
    batch_size, timesteps, C, H, W = x.size()
    # Fold time into the batch dimension so the CNN processes each frame
    # independently. reshape (unlike view) also accepts non-contiguous input.
    c_in = x.reshape(batch_size * timesteps, C, H, W)
    c_out = self.cnn(c_in)
    # c_out rows are batch-major (row index = b * timesteps + t). Restore
    # (batch, time, features) first, then swap the first two axes. Viewing
    # directly as (-1, batch_size, F) would interleave frames from different
    # samples. -1 also flattens any trailing CNN output dims into features.
    # NOTE(review): assumes self.rnn expects time-major input
    # (batch_first=False), as the original (-1, batch_size, F) layout
    # implied. Confirm against the module's constructor.
    features = c_out.reshape(batch_size, timesteps, -1).transpose(0, 1)
    # The final hidden state is unused. LSTM returns (h_n, c_n) while GRU/RNN
    # return a single tensor, so don't unpack it.
    r_out, _ = self.rnn(features)
    logits = self.classifier(r_out)
    return logits