[solved] Concatenate time distributed CNN with LSTM

I’m working on building a time-distributed CNN. My code was originally implemented in Keras, and now I want to port it to PyTorch. Could someone give me an example of how to implement a CNN + LSTM structure in PyTorch?

The network structure will be like:

time1: image --cnn--|
time2: image --cnn--|---> (timesteps, flattened cnn output) --> LSTM --> (1, output unit)
time3: image --cnn--|

I also found some related code, but failed to build my model with it:
TimeDistributed

I assume that your input data is of shape (batch_size, timesteps, C, H, W)

Instead of TimeDistributed, you can use .view() to combine the batch and time dimensions before running the convolutions, then you can use .view() to separate the batch and time dimensions and to flatten the features before running the LSTM. Something like this…

class Combine(nn.Module):
    def __init__(self):
        super(Combine, self).__init__()
        self.cnn = CNN()
        self.rnn = nn.LSTM(320, 10, 2, batch_first=True)

    def forward(self, x):
        batch_size, timesteps, C, H, W = x.size()
        # fold time into the batch dimension so the CNN sees each frame independently
        c_in = x.view(batch_size * timesteps, C, H, W)
        c_out = self.cnn(c_in)
        # unfold back to (batch, time, features) for the LSTM
        r_in = c_out.view(batch_size, timesteps, -1)
        r_out, (h_n, c_n) = self.rnn(r_in)  # nn.LSTM returns (output, (h_n, c_n))
        return F.log_softmax(r_out, dim=2)  # log-probabilities over the feature dimension

This will require the LSTM to be initialised with the batch_first=True option.
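
A quick shape check for the sketch above, assuming a CNN() module that flattens each frame to 320 features (as in the MNIST example further down in the thread):

x = torch.randn(4, 3, 1, 28, 28)  # (batch, timesteps, C, H, W)
model = Combine()                  # assumes CNN() outputs 320 features per frame
out = model(x)
print(out.shape)  # torch.Size([4, 3, 10]): per-timestep class log-probabilities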

TimeDistributed wasn’t designed for that type of input. You should reshape your input explicitly.

@jpeg729 @miguelvr Thanks! I built the structure I needed.

For anyone who might need it
(this is sample code for constructing a CNN + LSTM, using MNIST for demonstration):

# pytorch mnist cnn + lstm

from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

# Training settings (hard-coded in the Args class below instead of argparse)


class Args:
    def __init__(self):
        self.cuda = True
        self.no_cuda = False
        self.seed = 1
        self.batch_size = 50
        self.test_batch_size = 1000
        self.epochs = 10
        self.lr = 0.01
        self.momentum = 0.5
        self.log_interval = 10


args = Args()

args.cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        '../data',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])),
    batch_size=args.batch_size,
    shuffle=True,
    **kwargs)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        '../data',
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])),
    batch_size=args.test_batch_size,
    shuffle=True,
    **kwargs)

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # return the flattened 320-dim features; the classifier head lives in Combine
        x = x.view(-1, 320)
        return x


class Combine(nn.Module):
    def __init__(self):
        super(Combine, self).__init__()
        self.cnn = CNN()
        self.rnn = nn.LSTM(
            input_size=320,
            hidden_size=64,
            num_layers=1,
            batch_first=True)
        self.linear = nn.Linear(64, 10)

    def forward(self, x):
        batch_size, timesteps, C, H, W = x.size()
        # fold time into the batch dimension so the CNN processes each frame independently
        c_in = x.view(batch_size * timesteps, C, H, W)
        c_out = self.cnn(c_in)
        # unfold back to (batch, time, features) for the LSTM
        r_in = c_out.view(batch_size, timesteps, -1)
        r_out, (h_n, c_n) = self.rnn(r_in)
        # classify from the last timestep's output
        r_out2 = self.linear(r_out[:, -1, :])

        return F.log_softmax(r_out2, dim=1)


model = Combine()
if args.cuda:
    model.cuda()

optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)


def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # add a timesteps dimension of length 1: (B, 1, C, H, W)
        data = data.unsqueeze(1)
        if args.cuda:
            data, target = data.cuda(), target.cuda()

        optimizer.zero_grad()
        output = model(data)

        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test():
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            # add a timesteps dimension of length 1: (B, 1, C, H, W)
            data = data.unsqueeze(1)
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            test_loss += F.nll_loss(
                output, target, reduction='sum').item()  # sum up batch loss
            pred = output.max(
                1, keepdim=True)[1]  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).long().cpu().sum().item()

    test_loss /= len(test_loader.dataset)
    print(
        '\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(test_loader.dataset),
            100. * correct / len(test_loader.dataset)))


for epoch in range(1, args.epochs + 1):
    train(epoch)
    test()

Why would you multiply the time and batch together instead of time and channels?

Hi, thank you for the code, but I have a problem when I try to implement it. My input is (n_sample, n_channel, n_length) for Conv1d. In the forward function, when I use r_out, hidden = self.rnn(r_in), it always generates an error like this:

File "", line 18, in
    output = model(data)
File "C:\Users\User-J\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
File "C:/Users/User-J/building_framework/idnet_cnn_rnn.py", line 126, in forward
    x, h = self.rnn(x)
File "C:\Users\User-J\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
File "C:\Users\User-J\Anaconda3\lib\site-packages\torch\nn\modules\rnn.py", line 178, in forward
    self.check_forward_args(input, hx, batch_sizes)
File "C:\Users\User-J\Anaconda3\lib\site-packages\torch\nn\modules\rnn.py", line 126, in check_forward_args
    expected_input_dim, input.dim()))

RuntimeError: input must have 3 dimensions, got 2

I guess it’s because during training I call output = model(data) only. Is it supposed to be output, _ = model(data, None)? But that still doesn’t get rid of the error. Maybe you can help, because I need to use 1d conv and LSTM.

Thank you

If you initialized your LSTM with the default settings, your input should have the shape [seq_len, batch, input_size]. Could you check that by printing its shape before passing it to self.rnn?
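
For reference, a minimal sketch of that shape check (the 320/64 sizes here are just placeholders):

rnn = nn.LSTM(input_size=320, hidden_size=64)  # default: batch_first=False
x = torch.randn(5, 8, 320)  # (seq_len, batch, input_size)
print(x.shape)              # check this before calling self.rnn
out, (h_n, c_n) = rnn(x)    # out: (seq_len, batch, hidden_size) = (5, 8, 64)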

Thank you for the reply, but I thought we need to flatten it first before passing it into the RNN?
I thought my problem is in how I fit the data to the model: somehow the configuration also needs to pass the hidden state to the LSTM, meaning it can’t be just output = model(data) but should be output, hidden = model(data, hidden), and this would make forward take two parameters, as in def forward(self, input, hidden)?
And because I am new to this, can you explain more about the difference between seq_len and batch_size?

check this solution

Thanks a lot for the detailed explanation. I have a silly confusion: since batch samples and timesteps are squashed together, won’t that cause problems for the LSTM’s sequential learning? i.e., when the output is reshaped to (samples, timesteps, output_size), will it retain the sequential (timesteps) feature ordering for each sample as it was before squashing?
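
For what it’s worth, .view() keeps the underlying memory order, so folding and unfolding the time dimension round-trips exactly. A quick self-contained check (shapes are arbitrary):

x = torch.randn(4, 3, 1, 28, 28)       # (batch, timesteps, C, H, W)
folded = x.view(4 * 3, 1, 28, 28)      # frame order: sample 0's t0..t2, then sample 1's t0..t2, ...
unfolded = folded.view(4, 3, -1)       # back to (batch, timesteps, features)
print(torch.equal(unfolded, x.view(4, 3, -1)))  # True: per-sample timestep ordering is preserved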

@Jacky_Liu
r_out2 = self.linear(r_out[:, -1, :])
After this, aren’t you losing the original batch size of batch_size * num_of_steps?
So we will have a target/input size mismatch when calculating the loss;
our target size is still batch_size * num_of_steps.

@ptrblck any opinion of yours here… I built it this way. The difference is that in every discussion it is assumed that the dataset input is (seq, bs, c, h, w), but I simply do a view. The last output of the LSTM in every forum I’ve seen is (seq_len, bs, num_features).
My doubt is: with this shape, how can we compute the loss?
Am I using the LSTM well with this architecture?

 (head): Head(
    (fc): Sequential(
      (0): GeM(p={:.4f}
      (1): Flatten()
      (2): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Dropout(p=0.5, inplace=False)
      (4): Linear(in_features=2048, out_features=1024, bias=True)
      (5): Mish()
    )
  )
  (lstm): lstm(
    (lstm): LSTM(1024, 512, num_layers=2, batch_first=True)
    (drop): Dropout(p=0.5, inplace=False)
  )
  (linear): Linear(in_features=512, out_features=2, bias=True)
)
class Rnxt(nn.Module):
    def __init__(self, arch='resnext50_32x4d_ssl', n=1, pre=True, ps=0.5):
        super().__init__()
        m = torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', arch, True)
        #m = torch.load('/home/temp/.cache/torch/checkpoints/semi_supervised_resnext50_32x4-ddb3e555.pth', arch)

        self.layer0 = nn.Sequential(m.conv1, m.bn1, nn.ReLU(inplace=True))
        self.layer1 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False),
            m.layer1)
        self.layer2 = m.layer2
        self.layer3 = m.layer3
        self.layer4 = m.layer4

        nc = 2048
        self.head = Head(nc, n, ps=ps)
        self.lstm = lstm(16, 12)  # seq len and num of frames
        self.linear = nn.Linear(512, 2)

    def forward(self, x):
        shape = x.shape[0]
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.head(x)
        x = self.lstm(x).contiguous()

        x = x.view(x.shape[0] * x.shape[1], 512).contiguous()
        if x.size(0) != shape:  # in case the last few batches have different sizes
            x = x[0:shape, :]
        x = self.linear(x)

        return x

# lstm forward

self.lstm = nn.LSTM(1024, 512, 2, batch_first=True)  # assume a 16*12 x 1024 input

def forward(self, input):
    # pad the last incomplete batch by repeating its leading rows
    if input.size(0) != self.batch * self.num:
        input = torch.cat([input, input[0:self.batch * self.num - input.size(0), :]], dim=0)
    x = input.view(self.batch, self.num, input.size(1)).contiguous()

    x, y = self.lstm(x)

    #x = x[:,-1,:]

    x = self.drop(x)
    return x

If you are using batch_first=True in your LSTM, the input shape should be [batch_size, seq_len, nb_features].
I’m not sure what self.num refers to, but if you want to swap dimensions, you would have to use permute instead of view.
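
To illustrate the difference (a toy example; the sizes are arbitrary):

x = torch.randn(12, 16, 1024)       # (seq_len, batch, features)
swapped = x.permute(1, 0, 2)        # (batch, seq_len, features): moves the data correctly
wrong = x.view(16, 12, 1024)        # same shape, but scrambles which timestep belongs to which sample
print(torch.equal(swapped, wrong))  # False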

Not sure why this (reshaping before and after the CNN) works, but at least TensorFlow 1.x has an identical implementation: https://github.com/tensorflow/tensorflow/blob/05ab6a2afa2959410d48aab2336cea9dc1e2c13e/tensorflow/python/keras/layers/wrappers.py#L237

Hi,
I want to build a CNN (pretrained ResNet) + LSTM.
My input to the CNN is a sequence of 9 video frames with batch size 10.
If I use 1 video frame as input, the CNN sees (batch_size, C, H, W), and I understand the input data to the combined model should be of shape (batch_size, timesteps, C, H, W).
But what should the shape of the input data be if it is a sequence of video frames?
Please help me.
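
A minimal sketch of one way to build that 5-D input, assuming each sample is a list of 9 frame tensors of shape (C, H, W) (the sizes below are made up):

frames = [torch.randn(3, 224, 224) for _ in range(9)]  # one sample: 9 RGB frames
clip = torch.stack(frames, dim=0)                      # (timesteps, C, H, W) = (9, 3, 224, 224)
batch = torch.stack([clip] * 10, dim=0)                # (batch, timesteps, C, H, W) = (10, 9, 3, 224, 224)
# this 5-D batch can then go through the fold/unfold pattern shown earlier in the thread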