TimeDistributed CNN

Great news! I think I finally got the result I needed by using `nn.ModuleList()`!

class TimeDistributed(nn.Module):
    def __init__(self, layer, time_steps, *args):        
        super(TimeDistributed, self).__init__()
        
        self.layers = nn.ModuleList([layer(*args) for i in range(time_steps)])

    def forward(self, x):

        batch_size, time_steps, C, H, W = x.size()
        output = torch.tensor([])
        for i in range(time_steps):
          output_t = self.layers[i](x[:, i, :, :, :])
          output_t  = y.unsqueeze(1)
          output = torch.cat((output, output_t ), 1)
        return output

I checked this by counting the number of parameters after applying `nn.Conv2d` over 100 time steps:

x = torch.rand(20, 100, 1, 5, 9)

# One nn.Conv2d(in_channels=1, out_channels=8, kernel_size=(3, 3), stride=2,
# padding=1, bias=True) per time step, 100 in total. ``time_steps`` is passed
# positionally: a keyword argument may not be followed by positional ones.
model = TimeDistributed(nn.Conv2d, 100, 1, 8, (3, 3), 2, 1, True)
output = model(x)

print(output.size())   ## (20, 100, 8, 3, 5)

print("number of parameters : ", sum(p.numel() for p in model.parameters()))

## number of parameters :  8000                           instead of 80
## (each step has its own conv: 8*1*3*3 weights + 8 biases = 80, times 100 steps)

Or with batch normalization:


x = torch.rand(20, 100, 8, 3, 5)
# One independent nn.BatchNorm2d(8) per time step (100 in total); ``time_steps``
# is passed positionally so the remaining argument can follow it.
model = TimeDistributed(nn.BatchNorm2d, 100, 8)
5 Likes