Great news! I think I finally got the result I needed using nn.ModuleList()!
class TimeDistributed(nn.Module):
def __init__(self, layer, time_steps, *args):
super(TimeDistributed, self).__init__()
self.layers = nn.ModuleList([layer(*args) for i in range(time_steps)])
def forward(self, x):
batch_size, time_steps, C, H, W = x.size()
output = torch.tensor([])
for i in range(time_steps):
output_t = self.layers[i](x[:, i, :, :, :])
output_t = y.unsqueeze(1)
output = torch.cat((output, output_t ), 1)
return output
I checked this by counting the number of parameters after applying Conv2d across 100 time steps:
# Sanity check: Conv2d applied independently at each of 100 time steps.
x = torch.rand(20, 100, 1, 5, 9)
# BUG FIX: the original call was a SyntaxError — positional arguments may not
# follow the keyword argument `time_steps=100`. Pass everything positionally:
# Conv2d(in_channels=1, out_channels=8, kernel_size=(3, 3), stride=2,
#        padding=1, bias=True).
model = TimeDistributed(nn.Conv2d, 100, 1, 8, (3, 3), 2, 1, True)
output = model(x)
print(output.size())  ## (20, 100, 8, 3, 5)
print("number of parameters : ", sum([p.numel() for p in model.parameters()]))
## number of parameters : 8000 instead of 80
Or with batch normalization:
# Same check with batch normalization over 8 channels.
x = torch.rand(20, 100, 8, 3, 5)
# BUG FIX: as above, positional args may not follow `time_steps=100`
# (SyntaxError); also dropped a stray trailing markdown fence (```).
model = TimeDistributed(nn.BatchNorm2d, 100, 8)