I am trying to convert a ConvLSTM model from Keras to PyTorch. The Keras model summary looks like this:

[screenshot of the Keras model summary]
My code looks like this. Can a simple Conv2d layer replace the Keras TimeDistributed layer?

import torch
import torch.nn as nn
import torchvision
class TimeDistributed(nn.Module):
    def __init__(self, module, batch_first=False):
        super(TimeDistributed, self).__init__()
        self.module = module
        self.batch_first = batch_first

    def forward(self, x):

        if len(x.size()) <= 2:
            return self.module(x)

        # Squash samples and timesteps into a single axis, keeping any trailing
        # dimensions (channels, height, width) intact so conv modules also work
        x_reshape = x.contiguous().view(-1, *x.shape[2:])  # (samples * timesteps, ...)

        y = self.module(x_reshape)

        # Restore the separate sample and timestep axes
        if self.batch_first:
            y = y.contiguous().view(x.size(0), -1, *y.shape[1:])  # (samples, timesteps, ...)
        else:
            y = y.view(-1, x.size(1), *y.shape[1:])  # (timesteps, samples, ...)

        return y

# Residual block: 2x2 max-pool, then a three-conv main branch and a single-conv
# shortcut, averaged and passed through a ReLU
class Residual(nn.Module):
  def __init__(self, in_channels, filters):
    super(Residual, self).__init__()

    self.in_channels = in_channels
    self.f1 = filters[0]
    self.f2 = filters[1]
    self.f3 = filters[2]

    self.maxpool = nn.MaxPool2d(kernel_size=(2,2))

    self.residual = nn.Sequential(
      nn.Conv2d(self.in_channels, self.f1, kernel_size=(3, 3), padding='same'),
      nn.BatchNorm2d(self.f1),
      nn.ReLU(),

      nn.Conv2d(self.f1, self.f2, kernel_size=(1, 1), padding='same'),
      nn.BatchNorm2d(self.f2),
      nn.ReLU(),

      nn.Conv2d(self.f2, self.f3, kernel_size=(3, 3), padding='same'),
      nn.BatchNorm2d(self.f3),
    )

    self.shortcut = nn.Sequential(
      nn.Conv2d(self.in_channels, self.f3, kernel_size=(3, 3), padding='same', stride=(1, 1)),
      nn.BatchNorm2d(self.f3),
    )


    self.activation = nn.ReLU(inplace=True)
  def forward(self, x):
    pool = self.maxpool(x)
    x = self.residual(pool)   # main branch
    y = self.shortcut(pool)   # projection shortcut
    x = (x + y) / 2.0         # average the two branches
    x = self.activation(x)

    return x

# SFTT model
class SFTT(nn.Module):
  def __init__(self):
    super(SFTT, self).__init__()

    
    self.conv2d1 = nn.Conv2d(30, 32, kernel_size=(1, 1), stride=(1, 1), padding='same')
    self.btch1 = nn.BatchNorm2d(32)
    self.act1 = nn.ReLU()
    self.res1 = Residual(in_channels=32, filters=[32, 32, 128])
    self.res2 = Residual(in_channels=128, filters=[64, 64, 256])
    self.res3 = Residual(in_channels=256, filters=[128, 128, 512])
    self.sf_output = nn.MaxPool2d(kernel_size=2)
    self.sf_outputfl = nn.Flatten()

    self.lstms = nn.LSTM(512, 512, num_layers=2, batch_first=True)

    self.flatten = nn.Flatten()

    self.fc1 = nn.Linear(4 * 512, 32 * 32)

  def forward(self, x):
    # PyTorch tensors are channels-first: (batch, channels, height, width)
    bs, feature_maps, map_height, map_width = x.shape
    x = self.conv2d1(x)
    x = self.btch1(x)
    x = self.act1(x)
    x = self.res1(x)
    x = self.res2(x)
    x = self.res3(x)
    sf_output = self.sf_output(x)
    sf_output = self.sf_outputfl(sf_output)  # spatial branch output (currently unused)

    # nn.LSTM expects (batch, seq_len, features); treat the remaining spatial
    # positions as the sequence -> (bs, seq, 512); fc1 assumes seq == 4
    x = x.flatten(2).permute(0, 2, 1)
    x, _ = self.lstms(x)
    x = self.flatten(x)
    x = self.fc1(x)
    return x.view(-1, 32, 32)
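
For reference, here is a minimal, self-contained sketch of what I understand Keras TimeDistributed(Conv2D) to do: the same Conv2d applied independently to every timestep, by folding the time axis into the batch axis. The shapes below are made up for illustration only and are not taken from the model above.

import torch
import torch.nn as nn

# Made-up example shapes: 8 samples, 30 timesteps, 3-channel 32x32 frames
x = torch.randn(8, 30, 3, 32, 32)                 # (batch, time, channels, height, width)
conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)

b, t, c, h, w = x.shape
y = conv(x.reshape(b * t, c, h, w))               # apply the same conv to every timestep
y = y.reshape(b, t, *y.shape[1:])                 # restore the time axis
print(y.shape)                                    # torch.Size([8, 30, 16, 32, 32])

If that is really all TimeDistributed does, then a reshape followed by a plain Conv2d should be equivalent, but I am not sure whether that carries over to the rest of the model above.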