Usage issue with torchvision built-in MViT model

I am having trouble with the video MViT models built into torchvision 0.14: Video MViT — Torchvision main documentation (pytorch.org)

Here is my code:

import torch
from torchvision.models.video import mvit_v2_s, MViT_V2_S_Weights

model = mvit_v2_s(weights="DEFAULT")
model.eval()
transforms = MViT_V2_S_Weights.KINETICS400_V1.transforms()
input = transforms(torch.rand(1, 32, 3, 800, 600))
output = model(input)

But I get the following error:

Traceback (most recent call last):
  File "c:\Users\clai\projects\DL_VA_Prediction\src\lib\models\mvit_model.py", line 8, in <module>
    output = model(input)
  File "C:\Users\clai\miniconda3\envs\torch_HCM_py39\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\clai\miniconda3\envs\torch_HCM_py39\lib\site-packages\torchvision\models\video\mvit.py", line 558, in forward
    x, thw = block(x, thw)
  File "C:\Users\clai\miniconda3\envs\torch_HCM_py39\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\clai\miniconda3\envs\torch_HCM_py39\lib\site-packages\torchvision\models\video\mvit.py", line 384, in forward
    x_attn, thw_new = self.attn(x_norm1, thw)
  File "C:\Users\clai\miniconda3\envs\torch_HCM_py39\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\clai\miniconda3\envs\torch_HCM_py39\lib\site-packages\torchvision\models\video\mvit.py", line 293, in forward
    k, k_thw = self.pool_k(k, thw)
  File "C:\Users\clai\miniconda3\envs\torch_HCM_py39\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\clai\miniconda3\envs\torch_HCM_py39\lib\site-packages\torchvision\models\video\mvit.py", line 89, in forward
    x = x.reshape((B * N, C) + thw).contiguous()
RuntimeError: shape '[1, 96, 8, 56, 56]' is invalid for input of size 4816896

Any idea how I can make it work?

Thanks

The KINETICS400_V1 weights were trained on 16-frame clips, so the model is initialized with its temporal length fixed at 16; the model does not handle a dynamic temporal length at inference. Your 32-frame clip becomes 16 steps along time after the stride-2 temporal patch embedding, while the model expects 8 (16 / 2), which matches the error: 1 × 96 × 16 × 56 × 56 = 4,816,896 elements cannot be reshaped to [1, 96, 8, 56, 56]. Feed 16-frame clips instead, or retrain the model with a larger temporal length if your dataset supports it.