Any PyTorch function that can work like Keras' TimeDistributed?

(Hongyuan Zhu) #1

Hi! I used to be a Keras user and want to port my code to PyTorch. Recently I have been working on a video classification problem that uses an architecture similar to LRCN, which applies a CNN to extract features from each frame and then uses an LSTM for classification. Keras has a TimeDistributed wrapper that can apply a layer to each temporal slice. Does PyTorch have a similar implementation, or how can I achieve the same thing in this case? Is there any existing PyTorch example for it?

Thanks in advance for your patience and help!!

(Thomas V) #2


Off the top of my head, I think the model in Sean Naren’s deepspeech.pytorch does something very similar to what you want to achieve with its SequenceWise class:
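I don’t have the link handy, but from memory SequenceWise is roughly a thin wrapper along these lines (a sketch, not the exact code from the repository):

```python
import torch
import torch.nn as nn

class SequenceWise(nn.Module):
    """Collapses a (T, N, *) input to (T * N, *), applies the wrapped
    module, then restores the (T, N, *) shape."""
    def __init__(self, module):
        super(SequenceWise, self).__init__()
        self.module = module

    def forward(self, x):
        t, n = x.size(0), x.size(1)
        x = x.view(t * n, -1)
        x = self.module(x)
        return x.view(t, n, -1)

# e.g. apply the same Linear to every time step of a (T=7, N=4, 16) tensor
seq_fc = SequenceWise(nn.Linear(16, 8))
out = seq_fc(torch.randn(7, 4, 16))  # -> shape (7, 4, 8)
```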

Best regards


(Hongyuan Zhu) #3

Hi, Tom. Thanks for sharing! I’ll look into it!



(Miguel Varela Ramos) #4


I developed a PyTorch module that mimics the TimeDistributed wrapper of Keras a few days ago:

import torch.nn as nn

class TimeDistributed(nn.Module):
    def __init__(self, module, batch_first=False):
        super(TimeDistributed, self).__init__()
        self.module = module
        self.batch_first = batch_first

    def forward(self, x):

        if len(x.size()) <= 2:
            return self.module(x)

        # Squash samples and timesteps into a single axis
        x_reshape = x.contiguous().view(-1, x.size(-1))  # (samples * timesteps, input_size)

        y = self.module(x_reshape)

        # We have to reshape Y
        if self.batch_first:
            y = y.contiguous().view(x.size(0), -1, y.size(-1))  # (samples, timesteps, output_size)
        else:
            y = y.view(-1, x.size(1), y.size(-1))  # (timesteps, samples, output_size)

        return y
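For example, wrapping an nn.Linear so the same weights are applied at every time step of a batch-first tensor (the class is restated here so the snippet runs on its own):

```python
import torch
import torch.nn as nn

class TimeDistributed(nn.Module):
    def __init__(self, module, batch_first=False):
        super(TimeDistributed, self).__init__()
        self.module = module
        self.batch_first = batch_first

    def forward(self, x):
        if x.dim() <= 2:
            return self.module(x)
        # merge samples and timesteps into one axis
        y = self.module(x.contiguous().view(-1, x.size(-1)))
        if self.batch_first:
            return y.contiguous().view(x.size(0), -1, y.size(-1))
        return y.view(-1, x.size(1), y.size(-1))

# apply the same Linear to each of 5 time steps in a batch of 32
td = TimeDistributed(nn.Linear(20, 7), batch_first=True)
y = td(torch.randn(32, 5, 20))
print(y.shape)  # torch.Size([32, 5, 7])
```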

(Hongyuan Zhu) #5

Wow, cool! That’s pretty awesome! :grinning:

(Jacky Liu) #6

Could you give me an example of how to use this function to construct a time-distributed CNN + LSTM?

Several images would be processed by the CNN and then fed to the LSTM together.

import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        return x.view(-1, 320)  # return flat conv features for the LSTM

class Combine(nn.Module):
    def __init__(self):
        super(Combine, self).__init__()
        self.cnn = CNN()
        self.rnn = nn.LSTM(320, 10, 2)

    def forward(self, x):
        # x: (timesteps, batch, 1, H, W) -> merge time and batch for the CNN
        t, b = x.size(0), x.size(1)
        x = self.cnn(x.view(t * b, x.size(2), x.size(3), x.size(4)))
        x = x.view(t, b, -1)          # (timesteps, batch, 320) for the LSTM
        out, _ = self.rnn(x)          # nn.LSTM returns (output, (h_n, c_n))
        return F.log_softmax(out[-1], dim=1)  # classify from the last step

(Mac Yeh) #7

Thanks for sharing! I was thinking of looping over the function; your implementation reminded me that we are in an OO environment. Thanks a lot!

(Miguel Varela Ramos) #8

For most cases, this wrapper is no longer needed. For example, the Linear layer (Keras’ Dense) now supports inputs with more than two dimensions.
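For instance, nn.Linear applies to the last dimension and broadcasts over any leading dimensions, so it can be used directly on a (batch, timesteps, features) tensor:

```python
import torch
import torch.nn as nn

fc = nn.Linear(20, 7)
x = torch.randn(32, 5, 20)   # (batch, timesteps, input_size)
y = fc(x)                    # applied independently to each time step
print(y.shape)               # torch.Size([32, 5, 7])
```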

(Mac Yeh) #9

@miguelvr you are right, the Linear layer now supports 3-dimensional inputs. Thanks!

(Ken Fehling) #10

Is putting a Dense layer after an RNN the same as applying a Dense layer to each time step, though? In the first case, don’t the time steps connect and mix together?

(Kota Mori) #11

@miguelvr Isn’t this still useful for layers other than Linear, though? For example, if the input tensor has shape [sample, frame, image], like a video, you may want to apply a convnet module to each time frame. Please kindly correct me if I get this wrong.
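Right, a convnet expects 4D input, so it won’t broadcast over a time axis the way Linear does. The same reshape trick works, though: merge the sample and frame axes, run the convnet, then split them back. A sketch with illustrative shapes (8 clips of 10 RGB frames at 32×32):

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)

x = torch.randn(8, 10, 3, 32, 32)      # (sample, frame, C, H, W)
b, t = x.size(0), x.size(1)
y = conv(x.view(b * t, 3, 32, 32))     # apply the convnet to every frame
y = y.view(b, t, 16, 32, 32)           # restore the frame axis
print(y.shape)                         # torch.Size([8, 10, 16, 32, 32])
```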

(Miguel Varela Ramos) #12

Yes, definitely, it can still be useful in other cases.

(Satheesh) #13

Thanks. I was looking for the TimeDistributed equivalent in PyTorch and found your code.