RuntimeError: The size of tensor a (32769) must match the size of tensor b (512) at non-singleton dimension 3

I am trying to pretrain the ViViT model from the HuggingFace library (link) on my custom VideoDataset.
Each video in my dataset has shape (10, 64, 512, 512): essentially 10 blocks of 64 grayscale frames of 512x512.
Since my dataset is large, as was the dataset used for the original ViViT model, I wanted to process each of these blocks separately and use a smaller ViViT model. This is the code I am running now:

import torch
import torch.nn as nn
import numpy as np
from transformers import VivitModel, VivitConfig
from torch.utils.data import DataLoader, Dataset
import os
import torchvision.transforms as transforms

class VideoDataset(Dataset):
    def __init__(self, video_dir, transform=None):
        # Load all .npy files from the specified directory
        self.videos = [np.load(os.path.join(video_dir, f)) for f in os.listdir(video_dir) if f.endswith('.npy')]
        self.transform = transform
        # One sample per (video, block_index) pair, so each 64-frame block is treated separately
        self.blocks = [(video, i) for video in self.videos for i in range(video.shape[0])]

    def __len__(self):
        return len(self.blocks)

    def __getitem__(self, idx):
        video, block_idx = self.blocks[idx]
        frames = video[block_idx]  # Shape: (frames_per_block, height, width)
        print(f"Original frames shape: {frames.shape}")

        # frames is already (frames_per_block, height, width); the reshape just makes that explicit
        frames = frames.reshape(64, 512, 512)  # Shape: (64, 512, 512)

        # Convert to tensor and apply transformations
        frames = [torch.tensor(frame, dtype=torch.float32).unsqueeze(0) for frame in frames]
        if self.transform:
            frames = [self.transform(frame) for frame in frames]  # Apply normalization

        return torch.stack(frames)  # Shape: (frames_per_block, channels, height, width) = (64, 1, 512, 512)

# Transformation settings
transform = transforms.Compose([
    transforms.Normalize(mean=[0.5], std=[0.5])  # Normalizes the tensor directly
])

# Load the dataset
video_dir = '../subset-video-numpy'
dataset = VideoDataset(video_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# Configure the ViViT model
config = VivitConfig(
    num_frames=64,         # Number of frames per video
    num_channels=1,       # Number of channels (1 for grayscale)
    image_size=512,        # Size of the frame (512x512)
    hidden_size=512,
    num_hidden_layers=2,
    num_attention_heads=2
)
model = VivitModel(config)
model = nn.DataParallel(model)

# Loss function for Masked Autoencoder
class MaskedAutoencoderLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.criterion = nn.MSELoss()

    def forward(self, outputs, targets, mask):
        loss = self.criterion(outputs * mask, targets * mask)
        return loss

# Random masking (note: applied element-wise to the raw pixels, not per patch)
def random_mask(video, mask_ratio=0.15):
    # 1 = keep, 0 = masked; each element is masked independently with probability mask_ratio
    mask = torch.rand(video.shape) > mask_ratio
    return mask.float()

# Pretraining loop
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
criterion = MaskedAutoencoderLoss()

epochs = 10

for epoch in range(epochs):
    model.train()
    total_loss = 0
    for video in dataloader:
        video = video.to(device)

        # Apply random masking
        mask = random_mask(video).to(device)
        masked_video = video * mask

        print(f"Masked video shape: {masked_video.shape}")

        # Forward pass
        outputs = model(pixel_values=masked_video).last_hidden_state
        print(f"Outputs shape: {outputs.shape}")
        print(f"Video shape: {video.shape}")
        print(f"Mask shape: {mask.shape}")
        loss = criterion(outputs, video, mask)

        # Backward pass and weight update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")

The error I get is the one in the title, and it is raised at loss = criterion(outputs, video, mask).
I don't know if the problem comes from my changes to the config options. The original defaults were:

image_size (int, optional, defaults to 224) — The size (resolution) of each image.
num_frames (int, optional, defaults to 32) — The number of frames in each video.
tubelet_size (List[int], optional, defaults to [2, 16, 16]) — The size (resolution) of each tubelet.
num_channels (int, optional, defaults to 3) — The number of input channels.
hidden_size (int, optional, defaults to 768) — Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (int, optional, defaults to 12) — Number of hidden layers in the Transformer encoder.
num_attention_heads (int, optional, defaults to 12) — Number of attention heads for each attention layer in the Transformer encoder.
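
My guess (which I have not been able to confirm) is that the 32769 in the error is the length of the encoder's token sequence rather than anything about my video shape. Assuming the default tubelet_size=[2, 16, 16] still applies (I did not override it) and that a [CLS] token is prepended, a rough shape check would be:

# Rough shape check -- assumes the default tubelet_size=[2, 16, 16] and a prepended [CLS] token
num_frames, image_size, hidden_size = 64, 512, 512
tubelet_t, tubelet_h, tubelet_w = 2, 16, 16

num_patches = (num_frames // tubelet_t) * (image_size // tubelet_h) * (image_size // tubelet_w)
print(num_patches)      # 32 * 32 * 32 = 32768
print(num_patches + 1)  # + [CLS] token = 32769, the number in the error

# outputs (last_hidden_state): (batch, 32769, 512)
# video / mask:                (batch, 64, 1, 512, 512)
# outputs * mask inside the loss tries to broadcast these two shapes and
# fails at dimension 3 (32769 vs 512), which is exactly the error in the title.

If this calculation is right, the mismatch would come from comparing the token sequence with the raw video rather than from the changed config values, but I have not verified this.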