Neural network architecture suggestion(s) for video prediction (image sequence)

Hello everybody! Could you please suggest neural network architectures for video prediction (image sequences), where each sequence takes 144 images as input and predicts the next 48 images? For example, I tested the Wavenet and CNNLSTM networks available on Bitbucket (among others) and didn’t get good results. Thanks in advance!

If you have good computational resources, you can try vision transformers. There are multiple papers in the literature.

How I’d approach this is to modify a 2D UNet to use 3D layers throughout, possibly adding a 4D branch inside the model to help with modeling 3D objects/scenes as they change over time (3D + 1D = 4D).

Here is an example of a standard 2d Unet as well as a 2d + 3d Unet hybrid:

You’ll want to ensure your “channels” are still RGB and not the time-sequence dimension. Also, adjust the kernel sizes throughout so the encoder accepts 144 frames in and the decoder gives 48 frames out on the time-sequence dimension.
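
As a rough illustration of that last point, here’s a minimal, untested sketch (module name and layer sizes are my own, not from any repo) of a 3D conv stack that keeps channels separate from time and uses a temporal stride of 3 to map 144 input frames to 48 output frames:

import torch
import torch.nn as nn

class TemporalDownsample(nn.Module):
    # hypothetical sketch: reduce the time dimension 144 -> 48 with a stride-3 temporal conv
    def __init__(self, channels=1, hidden=32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv3d(channels, hidden, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            # stride of 3 on the time axis only: 144 frames -> 48 frames
            nn.Conv3d(hidden, hidden, kernel_size=3, stride=(3, 1, 1), padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(hidden, channels, kernel_size=1),
        )

    def forward(self, x):              # x: (batch, channels, 144, H, W)
        return self.net(x)             # -> (batch, channels, 48, H, W)

x = torch.randn(2, 1, 144, 7, 7)
print(TemporalDownsample()(x).shape)   # torch.Size([2, 1, 48, 7, 7])

In a full UNet you’d do the same kind of temporal striding inside the encoder or decoder rather than in a standalone block.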

Hi. Ok, I’ll look into that, thanks!

Hi. Ok, I took a look at that…

I’m also currently testing a UNet-LSTM network with the code below, but it hasn’t shown good results so far either. It would be quite different from the UNet you suggested, aside from the fact that yours doesn’t have an LSTM, right?

import torch
import torch.nn as nn

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=(1, 3, 3), padding=(0, 1, 1)),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=(1, 3, 3), padding=(0, 1, 1)),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class UnetLSTM(nn.Module):
    def __init__(self, in_channels, out_channels, num_layers, kernel_size, dilation, stride, dropout):
        super(UnetLSTM, self).__init__()

        self.in_channels = in_channels  # Corresponds to input size
        self.out_channels = out_channels  # Corresponds to hidden size
        self.num_layers = num_layers 
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.stride = stride
        self.dropout = dropout

        #self.conv1 = nn.Conv3d(in_channels=1, out_channels=128, kernel_size=3, stride=1, padding=1)
        #self.conv2 = nn.Conv3d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1)

        self.down1 = DoubleConv(in_channels, 64)
        self.pool1 = nn.MaxPool3d(kernel_size=(1, 2, 2))
        self.down2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool3d(kernel_size=(1, 2, 2))
        #self.down3 = DoubleConv(128, 256)
        #self.pool3 = nn.MaxPool3d(kernel_size=(1, 2, 2))
        #self.down4 = DoubleConv(256, 512)
        #self.pool4 = nn.MaxPool3d(kernel_size=(1, 2, 2))

        #self.lstm1 = nn.LSTM(out_channels, hidden_channels, num_layers=num_layers, batch_first=True)
        #self.lstm2 = nn.LSTM(out_channels, hidden_channels, num_layers=num_layers, batch_first=True)
        #self.lstm3 = nn.LSTM(out_channels, hidden_channels, num_layers=num_layers, batch_first=True)
                
        # note: `crnn` here is an external convolutional-RNN package providing Conv2dLSTM (import not shown)
        self.lstm2 = crnn.Conv2dLSTM(in_channels=128,
                                      out_channels=128,  
                                      kernel_size=self.kernel_size,  
                                      num_layers=self.num_layers, 
                                      bidirectional=True,
                                      dilation=self.dilation, 
                                      stride=self.stride, 
                                      dropout=self.dropout, 
                                      batch_first=True) 

        self.conv2 = nn.Conv3d(128*2, 128, kernel_size=1)

        self.up1 = nn.ConvTranspose3d(512, 256, kernel_size=(1, 2, 2), stride=(1, 2, 2))
        self.up_conv1 = DoubleConv(512, 256)
        self.up2 = nn.ConvTranspose3d(256, 128, kernel_size=(1, 2, 2), stride=(1, 2, 2))
        self.up_conv2 = DoubleConv(256, 128)
        self.up3 = nn.ConvTranspose3d(128, 64, kernel_size=(1, 2, 2), stride=(1, 2, 2), output_padding=(0, 1, 1))
        self.up_conv3 = DoubleConv(128, 64)
        self.out_conv = nn.Conv3d(64, out_channels, kernel_size=1)

    def forward(self, x):
        #x = self.conv1(x)
        #print(x.shape)
        #x = self.conv2(x)
         
        # Encoder
        x1 = self.down1(x)
        x2 = self.pool1(x1)
        x2 = self.down2(x2)
        #x3 = self.pool2(x2)
        #x3 = self.down3(x3)
        #x4 = self.pool3(x3)
        #x4 = self.down4(x4)

        # LSTM 

        x2 = x2.permute(0, 2, 1, 3, 4)
        x2, _ = self.lstm2(x2)
        x2 = x2.permute(0, 2, 1, 3, 4)
        x2 = self.conv2(x2)

        # Decoder
        #x = self.up1(x4)
        #x = torch.cat([x, x3], dim=1)
        #x = self.up_conv1(x)

        #x = self.up2(x)
        #x = torch.cat([x, x2], dim=1)
        #x = self.up_conv2(x)

        x = self.up3(x2)
        x = torch.cat([x, x1], dim=1)
        x = self.up_conv3(x)
        
        x = self.out_conv(x)        
        x = x.reshape(x.size(0), x.size(1), -1, x.size(3), x.size(4))
                
        #return x       
        return x[:, :, -48:, :, :] 

My images always have one channel (pixel values range from 0 to 255).
By the way, do you see problems in this code I showed?

Thank you very much!

I’m not at a computer, so I’m unable to test, but I’ll point out a few things I noticed.

  1. It should be either a 2D UNet + LSTM or a 3D UNet. The 3D UNet eliminates the need for an LSTM since you’re feeding in the entire time sequence at once (albeit more computation intensive).
  2. Self-attention has had good success in UNets. It helps the model focus on what’s important, and it especially helps with skip connections, which I see you’ve included (see the sketch at the end of this post).

Have you applied normalization between 0 and 1 for the inputs?
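
For point 2, here is a minimal, untested sketch of what a self-attention block over a skip tensor could look like (the module name and head count are my own; this is just one of several ways to add attention to a UNet):

import torch
import torch.nn as nn

class SkipSelfAttention(nn.Module):
    # hypothetical sketch: self-attention over a 5D skip tensor before concatenation
    def __init__(self, channels, num_heads=4):
        super().__init__()
        self.norm = nn.LayerNorm(channels)
        self.attn = nn.MultiheadAttention(channels, num_heads, batch_first=True)

    def forward(self, x):                        # x: (B, C, T, H, W)
        b, c, t, h, w = x.shape
        tokens = x.flatten(2).transpose(1, 2)    # (B, T*H*W, C): one token per voxel
        tokens = self.norm(tokens)
        out, _ = self.attn(tokens, tokens, tokens)
        out = out.transpose(1, 2).reshape(b, c, t, h, w)
        return x + out                           # residual, so the original skip information is preserved

Attention over all time steps and pixels at once can get expensive for larger frames; restricting it to the spatial positions of each frame is a common compromise.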

Ok, considering I’m using a 3D UNet and my input format is [batch_size, channels=1, time=144, width=7, height=7], would I be able to adapt my code to use a 2D UNet so I can keep using an LSTM?

Yes, I am normalizing the data into 0-1.
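
For illustration, that just means scaling by the maximum pixel value, e.g. (shapes here are illustrative):

import torch

# illustrative only: pixel values 0-255 scaled into [0, 1]
x = torch.randint(0, 256, (2, 1, 144, 7, 7)).float() / 255.0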

About self attention, thanks for the tip, I’ll add that.

Thank you very much!

A 3D UNet or Vision Transformer would be superior to an RNN such as an LSTM.

Consider this: what is the highest accuracy you’ve seen a model attain? The LSTM has to learn what information to pass on and what to discard. Let’s suppose, after much training, it does so correctly 80% of the time. That is an 80% memory accuracy (which for an ML model is pretty good). That means by the 4th time step you’ve lost upwards of 59% (1 - 0.8^4 ≈ 0.59) of the fidelity.

This is why language models before 2017 were not very good, compared to today. Why settle for 80% memory capture between time frames when you can just give a model the data at 100% fidelity?

By passing in all time frames, as with a 3D UNet, you no longer need that middle LSTM layer. You can pass the data directly from encoder to decoder.

Hi! Right, I’m now testing the 3D UNet without the LSTM. I’m using this code; what do you think?

import torch
import torch.nn as nn

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            # (1, 3, 3) kernels convolve each frame spatially; the time dimension is not mixed here
            nn.Conv3d(in_channels, out_channels, kernel_size=(1, 3, 3), padding=(0, 1, 1)),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=(1, 3, 3), padding=(0, 1, 1)),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class Unet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Unet, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.down1 = DoubleConv(in_channels, 64)
        self.pool1 = nn.MaxPool3d(kernel_size=(1, 2, 2))
        self.down2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool3d(kernel_size=(1, 2, 2))
        self.conv2 = nn.Conv3d(128, 128, kernel_size=1)
        self.up3 = nn.ConvTranspose3d(128, 64, kernel_size=(1, 2, 2), stride=(1, 2, 2), output_padding=(0, 1, 1))
        self.up_conv3 = DoubleConv(128, 64)
        self.out_conv = nn.Conv3d(64, out_channels, kernel_size=1)

    def forward(self, x):
        x1 = self.down1(x)
        x2 = self.pool1(x1)
        x2 = self.down2(x2)
        x2 = self.conv2(x2)
        x = self.up3(x2)
        x = torch.cat([x, x1], dim=1)
        x = self.up_conv3(x)
        x = self.out_conv(x)
        x = x.reshape(x.size(0), x.size(1), -1, x.size(3), x.size(4))
        return x[:, :, -48:, :, :]

Thank you so much for all the information and help you have given me!

That looks better.

I noticed you applied a skip connection near the beginning of the decoder but nowhere else. I’m guessing that was due to sizing issues, since your output size is not the same as your input size. One way you can work around that is by running a given skip tensor from the encoder through an nn.AdaptiveMaxPool3d or nn.AdaptiveAvgPool3d layer to get dim=2 (the time dimension) to the size you need. That will still let you pass those skip tensors, which improves fidelity when much of a scene stays the same from input to output.
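
Something like this untested snippet is what I mean (the tensor shapes just follow your 144-in / 48-out setup):

import torch
import torch.nn as nn

skip = torch.randn(2, 64, 144, 7, 7)       # encoder feature map with all 144 frames
decoder = torch.randn(2, 64, 48, 7, 7)     # decoder feature map with 48 frames

# pool only dim=2 (time) down to 48; height and width stay 7x7
pool = nn.AdaptiveAvgPool3d((48, 7, 7))    # nn.AdaptiveMaxPool3d works the same way
skip = pool(skip)                          # -> (2, 64, 48, 7, 7)

x = torch.cat([decoder, skip], dim=1)      # -> (2, 128, 48, 7, 7), ready for a DoubleConv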

Hi. Ok, nice! About the skip connection, I think you’re talking about this line:

x = torch.cat([x, x1], dim=1)

I’m doing this to let low-level information captured in the first layers of the network be preserved and combined with high-level information from later layers, which helps the model capture both fine details and more abstract features. But I’ll try to do it the way you suggested.

I made a version of a Vision Transformer network, shown below; what do you think? I haven’t tested it yet.

import torch
import torch.nn as nn
from torch.nn import Transformer

class VisionTransformer(nn.Module):
    def __init__(self, in_channels, out_channels, seq_length=144, embed_dim=256, num_heads=8, num_layers=6):
        super(VisionTransformer, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.seq_length = seq_length
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_layers = num_layers

        self.embedding = nn.Conv2d(in_channels, embed_dim, kernel_size=1)
        self.pos_embedding = nn.Parameter(torch.randn(1, seq_length, embed_dim))
        self.transformer = Transformer(
            d_model=embed_dim,
            nhead=num_heads,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            batch_first=True  # the forward pass below passes tokens as (batch, sequence, embed_dim)
        )
        self.fc = nn.Linear(embed_dim, out_channels)
        self.output_conv = nn.Conv3d(out_channels, out_channels, kernel_size=1)

    def forward(self, x):
        # x: (batch, in_channels, time, height, width)
        bs, _, t, h, w = x.shape

        # embed every frame independently with the 1x1 conv
        x = x.permute(0, 2, 1, 3, 4).reshape(bs * t, self.in_channels, h, w)
        x = self.embedding(x)                                # (bs*t, embed_dim, h, w)
        x = x.permute(0, 2, 3, 1).reshape(bs, t, h * w, self.embed_dim)

        # add the same learned positional embedding to every pixel of a given frame
        x = x + self.pos_embedding[:, :t].unsqueeze(2)

        # one token per (time step, pixel)
        x = x.reshape(bs, t * h * w, self.embed_dim)
        x = self.transformer(x, x)

        # project tokens back to pixel values and restore the 5D layout
        x = x.reshape(bs, t, h, w, self.embed_dim)
        x = self.fc(x)                                       # (bs, t, h, w, out_channels)
        x = x.permute(0, 4, 1, 2, 3)                         # (bs, out_channels, t, h, w)
        x = self.output_conv(x)

        return x[:, :, -48:, :, :]

Thanks!

Regarding ViTs, I’ve only seen them work well on classification-type problems, not on image generation. If you go that route, you may want to read up on a few papers with the keywords “ViT” and “image generation”. From what I have read, you may need to add a positional encoder.
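
If it helps, this is the standard sinusoidal positional encoding from “Attention Is All You Need”; I’m not saying it’s required here, just one option if a learned embedding isn’t enough:

import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    # standard sinusoidal positional encoding, added to token embeddings
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe.unsqueeze(0))   # (1, max_len, d_model)

    def forward(self, x):                             # x: (batch, seq_len, d_model)
        return x + self.pe[:, :x.size(1)]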