My code looks like this.
Can a simple Conv2d layer replace the TimeDistributed layer?
import torch
import torch.nn as nn
import torchvision
class TimeDistributed(nn.Module):
    """Run a wrapped module over a sequence tensor in a single call.

    Tensors with more than two dimensions are collapsed to
    ``(-1, x.size(-1))``, pushed through ``module`` once, and the result is
    viewed back as a 3-D tensor. Tensors with two or fewer dimensions are
    handed to ``module`` untouched.

    Args:
        module: The module to apply; it receives a 2-D tensor whose last
            axis is the last axis of the original input.
        batch_first: Selects how the output is viewed. When true the result
            is ``(x.size(0), -1, out_features)``; otherwise it is
            ``(-1, x.size(1), out_features)``.
    """

    def __init__(self, module, batch_first=False):
        super().__init__()
        self.module = module
        self.batch_first = batch_first

    def forward(self, x):
        """Apply ``self.module`` to ``x`` and restore a 3-D layout."""
        # Nothing to fold for vectors / plain 2-D batches.
        if x.dim() <= 2:
            return self.module(x)

        # Merge samples and timesteps: (samples * timesteps, input_size).
        flat_in = x.contiguous().view(-1, x.size(-1))
        flat_out = self.module(flat_in)
        out_features = flat_out.size(-1)

        if not self.batch_first:
            # (timesteps, samples, output_size)
            return flat_out.view(-1, x.size(1), out_features)

        # (samples, timesteps, output_size)
        return flat_out.contiguous().view(x.size(0), -1, out_features)
class Residual(nn.Module):
    """Pooling block whose two convolutional paths are averaged.

    The input is first reduced by a 2x2 max-pool. The pooled tensor then
    goes through a three-convolution main path (3x3 -> 1x1 -> 3x3, each
    followed by batch norm, with ReLU after the first two) and through a
    single 3x3 convolution + batch norm shortcut. The two results are
    averaged and passed through ReLU.

    Args:
        in_channels: Number of channels of the incoming tensor.
        filters: Indexable of at least three ints; elements 0, 1 and 2 are
            the output channels of the three main-path convolutions. Element
            2 is also the shortcut's (and therefore the block's) output
            channel count.
    """

    def __init__(self, in_channels, filters):
        super().__init__()
        self.in_channels = in_channels
        self.f1, self.f2, self.f3 = filters[0], filters[1], filters[2]

        self.maxpool = nn.MaxPool2d(kernel_size=(2, 2))

        # Main path; every convolution keeps the spatial size ("same").
        main_layers = [
            nn.Conv2d(self.in_channels, self.f1, kernel_size=(3, 3), padding="same"),
            nn.BatchNorm2d(self.f1),
            nn.ReLU(),
            nn.Conv2d(self.f1, self.f2, kernel_size=(1, 1), padding="same"),
            nn.BatchNorm2d(self.f2),
            nn.ReLU(),
            nn.Conv2d(self.f2, self.f3, kernel_size=(3, 3), padding="same"),
            nn.BatchNorm2d(self.f3),
        ]
        self.residual = nn.Sequential(*main_layers)

        # Shortcut projects the pooled input to f3 channels so it can be
        # combined element-wise with the main path.
        skip_layers = [
            nn.Conv2d(
                self.in_channels,
                self.f3,
                kernel_size=(3, 3),
                padding="same",
                stride=(1, 1),
            ),
            nn.BatchNorm2d(self.f3),
        ]
        self.shortcut = nn.Sequential(*skip_layers)

        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu((main(pool(x)) + shortcut(pool(x))) / 2)``."""
        pooled = self.maxpool(x)
        main = self.residual(pooled)
        skip = self.shortcut(pooled)
        averaged = (main + skip) / 2.0
        return self.activation(averaged)
# SFTT model
class SFTT(nn.Module):
    """Convolutional feature extractor followed by an LSTM and a linear head.

    Pipeline:
        1. 1x1 convolution (30 -> 32 channels), batch norm, ReLU.
        2. Three ``Residual`` blocks (32 -> 128 -> 256 -> 512 channels); each
           one starts with a 2x2 max-pool, so height and width are halved
           three times.
        3. A further 2x2 max-pool.
        4. The remaining spatial grid is flattened into a sequence of
           512-dimensional vectors and run through a 2-layer LSTM.
        5. The LSTM outputs are flattened and mapped by ``fc1`` to 32*32
           values, returned as a ``(batch, 32, 32)`` tensor.

    ``fc1`` is sized for exactly 4 LSTM steps (``4 * 512`` inputs), so the
    grid left after step 3 must contain 4 positions, e.g. 2x2 for a 32x32
    input.

    NOTE(review): using the spatial positions as the LSTM steps is inferred
    from the declared layer sizes (LSTM input size 512, ``fc1`` expecting
    ``4 * 512``); confirm this matches the intended design.
    """

    def __init__(self):
        super().__init__()
        self.conv2d1 = nn.Conv2d(30, 32, kernel_size=(1, 1), stride=(1, 1), padding='same')
        self.btch1 = nn.BatchNorm2d(32)
        self.act1 = nn.ReLU()
        self.res1 = Residual(in_channels=32, filters=[32, 32, 128])
        self.res2 = Residual(in_channels=128, filters=[64, 64, 256])
        self.res3 = Residual(in_channels=256, filters=[128, 128, 512])
        self.sf_output = nn.MaxPool2d(kernel_size=2)
        # Flatten only the spatial axes so the channel axis survives and can
        # become the LSTM feature axis.
        self.sf_outputfl = nn.Flatten(start_dim=2)
        self.lstms = nn.LSTM(512, 512, num_layers=2, batch_first=True)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(4 * 512, 32 * 32)

    def forward(self, x):
        """Map a ``(batch, 30, height, width)`` tensor to ``(batch, 32, 32)``.

        Args:
            x: 4-D tensor with 30 channels at dim 1 (the layout
                ``nn.Conv2d`` requires).

        Returns:
            Tensor viewed as ``(-1, 32, 32)``.

        Raises:
            ValueError: If ``x`` is not 4-dimensional.
        """
        if x.dim() != 4:
            raise ValueError(
                f"SFTT expects a 4-D (batch, 30, height, width) input, got {x.dim()}-D"
            )

        x = self.conv2d1(x)
        x = self.btch1(x)
        x = self.act1(x)
        x = self.res1(x)
        x = self.res2(x)
        x = self.res3(x)

        # The pooled map, not the pre-pool tensor, is what feeds the LSTM.
        sf_output = self.sf_output(x)            # (batch, 512, h, w)
        sf_output = self.sf_outputfl(sf_output)  # (batch, 512, h * w)

        # nn.LSTM (batch_first=True) needs (batch, steps, features); a raw
        # 4-D feature map is rejected, so move channels to the last axis.
        sequence = sf_output.transpose(1, 2).contiguous()  # (batch, h * w, 512)

        x, _ = self.lstms(sequence)
        x = self.flatten(x)
        x = self.fc1(x)
        return x.view(-1, 32, 32)