Moving MNIST: CNN encoder, decoder and transformer

The objective is:
Using the model defined below, predict the last frame of each video in the Moving MNIST dataset from its first 19 frames.
Model
The model to be used for the problem consists of 3 main parts: a CNN encoder architecture in which 2-dimensional features are learned, a transformer decoder architecture in which motion information is learned, and a CNN decoder architecture in which the learned motion is translated to the image. These architectures are detailed below.
CNN encoder architecture:
It consists of 3 convolutional blocks. Each block contains a convolution layer, group normalization and a ReLU activation function. The convolutional layers apply a padded 3x3 convolution with stride 2. The numbers of output channels of the blocks are 8, 16 and 16, in order.
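
A minimal sketch of the encoder as described above, assuming single-channel 64x64 frames and one group in the group normalization (the text does not give the number of groups). The names EncoderBlock and CNNEncoder are only illustrative. The second-layer features are returned together with the last-layer features because the second transformer decoder needs them.

```python
import torch
import torch.nn as nn


class EncoderBlock(nn.Module):
    """Padded 3x3 convolution with stride 2, then GroupNorm and ReLU."""

    def __init__(self, in_channels, out_channels, num_groups=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size=3, stride=2, padding=1)
        # num_groups=1 is an assumption; the description does not specify it
        self.norm = nn.GroupNorm(num_groups, out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.norm(self.conv(x)))


class CNNEncoder(nn.Module):
    def __init__(self, in_channels=1):
        super().__init__()
        self.block1 = EncoderBlock(in_channels, 8)
        self.block2 = EncoderBlock(8, 16)
        self.block3 = EncoderBlock(16, 16)

    def forward(self, x):
        f1 = self.block1(x)   # (N, 8, 32, 32)
        f2 = self.block2(f1)  # (N, 16, 16, 16), second-layer features
        f3 = self.block3(f2)  # (N, 16, 8, 8), last-layer features
        return f2, f3


frames = torch.rand(4, 1, 64, 64)   # 4 grayscale frames
f2, f3 = CNNEncoder()(frames)
print(f2.shape, f3.shape)
```

Each stride-2 block halves the width and height, so a 64x64 frame becomes 32x32, 16x16 and finally 8x8.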
Transformer decoder architecture:
Contains 2 transformer decoders. The first transformer decoder takes the features generated in the last layer of the CNN encoder as input and learns an embedding vector for the frame to be predicted. The second transformer decoder takes the features generated in the second layer of the CNN encoder as input and learns another embedding vector for the frame to be predicted. To facilitate the training of the model, the output of each transformer decoder is summed with the CNN features of the last frame given as input to this decoder to obtain the embedding vector.
There is 1 transformer layer in each decoder. The number of neurons in the fully connected layers inside the transformer layers is 512. Use 8 as the number of attention heads.
Note: Use ready-made modules for this architecture. Do not forget positional encoding for the transformer decoder inputs.
Example code
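
One possible reading of this part, sketched with the ready-made nn.TransformerDecoderLayer. Several things here are assumptions and not stated in the task: each frame's feature map is flattened into one token, the positional encoding is a learned parameter, and the last input frame is used as the query while all frames serve as memory. The class name FrameTransformerDecoder is hypothetical.

```python
import torch
import torch.nn as nn


class FrameTransformerDecoder(nn.Module):
    """One transformer decoder layer plus a residual sum with the last frame."""

    def __init__(self, d_model, num_frames, nhead=8, dim_feedforward=512):
        super().__init__()
        # Learned positional encoding, one vector per frame position (assumption)
        self.pos = nn.Parameter(torch.zeros(1, num_frames, d_model))
        self.layer = nn.TransformerDecoderLayer(
            d_model=d_model, nhead=nhead,
            dim_feedforward=dim_feedforward, batch_first=True,
        )

    def forward(self, feats):
        # feats: (B, T, C, H, W) CNN features of the T input frames
        b, t, c, h, w = feats.shape
        tokens = feats.reshape(b, t, c * h * w)   # one token per frame (assumption)
        memory = tokens + self.pos[:, :t]         # add the positional encoding
        query = memory[:, -1:]                    # last input frame as the query
        out = self.layer(query, memory)           # (B, 1, d_model)
        emb = out[:, 0] + tokens[:, -1]           # sum with last frame's CNN features
        return emb.reshape(b, c, h, w)            # back to a 2-dimensional feature map


# Last-layer encoder features of 19 frames for 2 videos: 16 channels, 8x8
feats = torch.rand(2, 19, 16, 8, 8)
decoder = FrameTransformerDecoder(d_model=16 * 8 * 8, num_frames=19)
emb = decoder(feats)
print(emb.shape)
```

With this tokenization the second decoder, which works on 16x16x16 second-layer features, would need d_model=4096, which is large. Treating spatial positions as tokens instead is another option the task text would also allow.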
CNN decoder architecture:
It consists of 3 convolutional blocks. In each block, the feature size (width and height) is doubled using bilinear interpolation. Then a padded 3x3 convolution is applied to these features. Every convolutional layer except the one in the final block is followed by group normalization and the ReLU function. In the last block, use the sigmoid activation function. The numbers of output channels of the convolutional layers are 16, 8 and 1, respectively.
The transformer decoder outputs are given as input to this architecture. The input to the first block is the embedding vector learned with the first transformer decoder (to be made 2-dimensional!). The input to the second block is both the output of the first block and the embedding vector learned with the second transformer decoder (U-Net architecture). In the third block, only the output of the second block is given as input.
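
A minimal sketch of the decoder as described above. It assumes the first embedding has shape (16, 8, 8) and the second (16, 16, 16), which is what the encoder's last and second layers produce for 64x64 frames, and it assumes one group in the group normalization. The names DecoderBlock and CNNDecoder are illustrative.

```python
import torch
import torch.nn as nn


class DecoderBlock(nn.Module):
    """Bilinear x2 upsampling, padded 3x3 convolution, then the activation."""

    def __init__(self, in_channels, out_channels, last=False):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        # Last block: sigmoid only. Other blocks: GroupNorm (1 group, assumed) + ReLU.
        if last:
            self.act = nn.Sigmoid()
        else:
            self.act = nn.Sequential(nn.GroupNorm(1, out_channels), nn.ReLU())

    def forward(self, x):
        return self.act(self.conv(self.up(x)))


class CNNDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = DecoderBlock(16, 16)
        self.block2 = DecoderBlock(16 + 16, 8)   # block1 output + second embedding
        self.block3 = DecoderBlock(8, 1, last=True)

    def forward(self, emb1, emb2):
        x = self.block1(emb1)              # (B, 16, 8, 8)   -> (B, 16, 16, 16)
        x = torch.cat([x, emb2], dim=1)    # U-Net style skip: (B, 32, 16, 16)
        x = self.block2(x)                 # (B, 8, 32, 32)
        return self.block3(x)              # (B, 1, 64, 64), values in [0, 1]


emb1 = torch.rand(2, 16, 8, 8)
emb2 = torch.rand(2, 16, 16, 16)
pred = CNNDecoder()(emb1, emb2)
print(pred.shape)
```

The skip connection works without any resizing because the first block's output and the second embedding have the same 16x16 spatial size.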
Training Parameters:
1- Split the Moving MNIST dataset into 10% test, 10% validation and 80% training.
2- Mini-batch size: 8
3- Train the model for up to 50 epochs using the ADAM optimization algorithm with a learning rate of 0.01. Use binary cross entropy as the error function.
4- For each training epoch, print the training and validation errors. Show the actual and the predicted frame.
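
A small sketch of how these settings could be wired up, assuming 10000 videos in the dataset. The helper split_indices and the stand-in model are hypothetical and exist only so that the snippet runs on its own; the real network would replace the stand-in.

```python
import numpy as np
import torch
import torch.nn as nn


def split_indices(num_videos, seed=0):
    """Shuffle video indices and split them 80% / 10% / 10%."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(num_videos)
    n_train = num_videos * 8 // 10
    n_val = num_videos // 10
    return idx[:n_train], idx[n_train:n_train + n_val], idx[n_train + n_val:]


train_idx, val_idx, test_idx = split_indices(10000)

# Stand-in model (hypothetical), only here to make the snippet self-contained
model = nn.Sequential(nn.Conv2d(1, 1, kernel_size=3, padding=1), nn.Sigmoid())
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.BCELoss()   # binary cross entropy on sigmoid outputs

# One training step on a random mini-batch of 8 frames
inputs = torch.rand(8, 1, 64, 64)
targets = torch.rand(8, 1, 64, 64)
optimizer.zero_grad()
loss = criterion(model(inputs), targets)
loss.backward()
optimizer.step()
print(len(train_idx), len(val_idx), len(test_idx), loss.item())
```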

And my code is:

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.optim import Adam
#from Seq2Seq import Seq2Seq
from torch.utils.data import DataLoader

import io
import imageio
from ipywidgets import widgets, HBox

class CNNBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(CNNBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.group_norm = nn.GroupNorm(1, out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.group_norm(x)
        x = self.relu(x)
        return x

class TransformerDecoderBlock(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(TransformerDecoderBlock, self).__init__()
        self.transformer = nn.TransformerDecoderLayer(
            d_model=embed_dim, nhead=num_heads, dim_feedforward=512
        )

    def forward(self, x, memory):
        x = self.transformer(x, memory)
        return x

class CNNDecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(CNNDecoderBlock, self).__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.group_norm = nn.GroupNorm(1, out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.upsample(x)
        x = self.conv(x)
        x = self.group_norm(x)
        x = self.relu(x)
        return x

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()

        # CNN encoder architecture
        self.cnn_encoder = nn.Sequential(
            CNNBlock(3, 8),
            CNNBlock(8, 16),
            CNNBlock(16, 16)
        )

        # Transformer decoder architecture
        self.transformer_decoder_1 = TransformerDecoderBlock(embed_dim=16, num_heads=8)
        self.transformer_decoder_2 = TransformerDecoderBlock(embed_dim=16, num_heads=8)

        # CNN decoder architecture
        self.cnn_decoder_1 = CNNDecoderBlock(in_channels=16, out_channels=8)
        self.cnn_decoder_2 = CNNDecoderBlock(in_channels=16, out_channels=1)

    def forward(self, x):
        # CNN encoder
        cnn_encoder_output = self.cnn_encoder(x)

        # Transformer decoder
        transformer_decoder_1_output = self.transformer_decoder_1(cnn_encoder_output, cnn_encoder_output)
        transformer_decoder_2_output = self.transformer_decoder_2(cnn_encoder_output, cnn_encoder_output)

        # Embedding vectors
        embedding_1 = transformer_decoder_1_output[:, -1, :, :].view(x.size(0), -1)
        embedding_2 = transformer_decoder_2_output[:, -1, :, :].view(x.size(0), -1)

        # CNN decoder
        cnn_decoder_1_output = self.cnn_decoder_1(embedding_1.view(x.size(0), 16, 8, 8))
        cnn_decoder_2_output = self.cnn_decoder_2(torch.cat([cnn_decoder_1_output, embedding_2.view(x.size(0), 16, 8, 8)], dim=1))

        return cnn_decoder_2_output

# Create the model
#model = MyModel()

# Print the model
#print(model)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load Data as Numpy Array
MovingMNIST = np.load('./datasets/mnist_test_seq.npy').transpose(1, 0, 2, 3)

# Shuffle Data
np.random.shuffle(MovingMNIST)

# Train, Test, Validation splits
train_data = MovingMNIST[:8000]         
val_data = MovingMNIST[8000:9000]       
test_data = MovingMNIST[9000:10000]     

def collate(batch):

    # Add channel dim, scale pixels between 0 and 1, send to GPU
    batch = torch.tensor(batch).unsqueeze(1)     
    batch = batch / 255.0                        
    batch = batch.to(device)                     

    # Use the 10 frames before index 19 as input and frame 19 as the target
    # (the random choice of the target index is disabled)
    rand = 19  # np.random.randint(10, 20)
    return batch[:,:,rand-10:rand], batch[:,:,rand]     


# Training Data Loader
train_loader = DataLoader(train_data, shuffle=True, 
                        batch_size=16, collate_fn=collate)

# Validation Data Loader
val_loader = DataLoader(val_data, shuffle=True, 
                        batch_size=16, collate_fn=collate)


# Get a batch
input, _ = next(iter(val_loader))

# Reverse process before displaying
input = input.cpu().numpy() * 255.0     

for video in input.squeeze(1)[:3]:          # Loop over videos
    with io.BytesIO() as gif:
        imageio.mimsave(gif,video.astype(np.uint8),"GIF",fps=5)
        display(HBox([widgets.Image(value=gif.getvalue())]))


# The input video frames are grayscale, thus single channel
model =  MyModel() # Seq2Seq(num_channels=1, num_kernels=64, 
#kernel_size=(3, 3), padding=(1, 1), activation="relu", 
#frame_size=(64, 64), num_layers=3).to(device)

optim = Adam(model.parameters(), lr=1e-4)

# Binary Cross Entropy, target pixel values either 0 or 1
criterion = nn.BCELoss(reduction='sum')

num_epochs = 1

for epoch in range(1, num_epochs+1):
    
    train_loss = 0                                                 
    model.train()                                                  
    for batch_num, (input, target) in enumerate(train_loader, 1):  
        output = model(input)                                     
        loss = criterion(output.flatten(), target.flatten())       
        loss.backward()                                            
        optim.step()                                               
        optim.zero_grad()                                           
        train_loss += loss.item()                                 
    train_loss /= len(train_loader.dataset)                       

    val_loss = 0                                                 
    model.eval()                                                   
    with torch.no_grad():                                          
        for input, target in val_loader:                          
            output = model(input)                                   
            loss = criterion(output.flatten(), target.flatten())   
            val_loss += loss.item()                                
    val_loss /= len(val_loader.dataset)                            

    print("Epoch:{} Training Loss:{:.2f} Validation Loss:{:.2f}\n".format(
        epoch, train_loss, val_loss))
    
def collate_test(batch):

    # All frames after the first one are the target
    target = np.array(batch)[:,1:]                     
    
    # Add channel dim, scale pixels between 0 and 1, send to GPU
    batch = torch.tensor(batch).unsqueeze(1)          
    batch = batch / 255.0                             
    batch = batch.to(device)                          
    return batch, target

# Test Data Loader
test_loader = DataLoader(test_data,shuffle=True, 
                         batch_size=3, collate_fn=collate_test)

# Get a batch
batch, target = next(iter(test_loader))

# Initialize output sequence
output = np.zeros(target.shape, dtype=np.uint8)

# Loop over timesteps
for timestep in range(target.shape[1]):
  input = batch[:,:,timestep:timestep+10]   
  output[:,timestep]=(model(input).squeeze(1).cpu()>0.5)*255.0

for tgt, out in zip(target, output):       # Loop over samples
    
    # Write target video as gif
    with io.BytesIO() as gif:
        imageio.mimsave(gif, tgt, "GIF", fps = 5)    
        target_gif = gif.getvalue()

    # Write output video as gif
    with io.BytesIO() as gif:
        imageio.mimsave(gif, out, "GIF", fps = 5)    
        output_gif = gif.getvalue()

    display(HBox([widgets.Image(value=target_gif), 
                  widgets.Image(value=output_gif)]))

I am getting this error:
[157] output = model(input)
[75] cnn_encoder_output = self.cnn_encoder(x)

RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [16, 1, 10, 64, 64]

Any ideas?
Thanks
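
Note on the shape in the error: [16, 1, 10, 64, 64] is (batch, channel, time, height, width), while nn.Conv2d accepts at most a 4D (batch, channel, height, width) tensor. A minimal sketch of one common way around this, assuming it is acceptable to run the CNN on every frame independently: merge the time axis into the batch axis before the convolution and split it again afterwards. The first block above is also declared as CNNBlock(3, 8), whereas the frames have a single channel, so the sketch uses 1 input channel.

```python
import torch
import torch.nn as nn

# Same shape as in the error message: (batch, channel, time, height, width)
x = torch.rand(16, 1, 10, 64, 64)
conv = nn.Conv2d(1, 8, kernel_size=3, stride=2, padding=1)

b, c, t, h, w = x.shape
# Move time next to batch, then merge the two so each frame is one 4D sample
frames = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)   # (160, 1, 64, 64)
feats = conv(frames)                                         # (160, 8, 32, 32)
# Split batch and time apart again for the sequence model
feats = feats.reshape(b, t, *feats.shape[1:])                # (16, 10, 8, 32, 32)
print(feats.shape)
```

After this step the features still have to be turned into a (batch, sequence, embedding) layout before they go into the transformer decoder, which the forward pass above does not do yet.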