I am having trouble with the video MViT models that ship with torchvision 0.14 (see "Video MViT" in the Torchvision main documentation on pytorch.org).
Here is my code:
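```python
import torch
from torchvision.models.video import mvit_v2_s, MViT_V2_S_Weights

model = mvit_v2_s(weights="DEFAULT")
model.eval()

# Preset for the Kinetics-400 weights; takes a (N, T, C, H, W) video batch
transforms = MViT_V2_S_Weights.KINETICS400_V1.transforms()
# 1 clip of 32 frames, 3 channels, 800x600 pixels
input = transforms(torch.rand(1, 32, 3, 800, 600))
output = model(input)
```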
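As a sanity check on what reaches the model, the sketch below tracks the tensor shape through the preset transform. It assumes the preset uses `resize_size=(256,)` (shorter edge resized to 256) and `crop_size=(224, 224)`, and that it permutes `(N, T, C, H, W)` to `(N, C, T, H, W)`; the helper names are hypothetical and only do the shape bookkeeping, not the actual image processing.

```python
# Hypothetical helpers: shape bookkeeping only, assuming the Kinetics-400
# preset resizes the shorter edge to 256, center-crops to 224x224, and
# permutes (N, T, C, H, W) -> (N, C, T, H, W).

def resized_hw(h, w, short_side=256):
    """Resize so the shorter edge equals short_side, keeping aspect ratio."""
    if h <= w:
        return short_side, int(short_side * w / h)
    return int(short_side * h / w), short_side

def transform_output_shape(n, t, c, h, w, crop=(224, 224)):
    """Input (N, T, C, H, W) -> output (N, C, T, crop_h, crop_w)."""
    # Resize and center crop only change H and W; the frame count T is kept,
    # and the channel and time axes are swapped.
    return (n, c, t, crop[0], crop[1])

print(resized_hw(800, 600))                        # (341, 256)
print(transform_output_shape(1, 32, 3, 800, 600))  # (1, 3, 32, 224, 224)
```

Under these assumptions the spatial size is already what the model wants; only the 32-frame time axis passes through unchanged.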
But I get the following error:
```
Traceback (most recent call last):
  File "c:\Users\clai\projects\DL_VA_Prediction\src\lib\models\mvit_model.py", line 8, in <module>
    output = model(input)
  File "C:\Users\clai\miniconda3\envs\torch_HCM_py39\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\clai\miniconda3\envs\torch_HCM_py39\lib\site-packages\torchvision\models\video\mvit.py", line 558, in forward
    x, thw = block(x, thw)
  File "C:\Users\clai\miniconda3\envs\torch_HCM_py39\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\clai\miniconda3\envs\torch_HCM_py39\lib\site-packages\torchvision\models\video\mvit.py", line 384, in forward
    x_attn, thw_new = self.attn(x_norm1, thw)
  File "C:\Users\clai\miniconda3\envs\torch_HCM_py39\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\clai\miniconda3\envs\torch_HCM_py39\lib\site-packages\torchvision\models\video\mvit.py", line 293, in forward
    k, k_thw = self.pool_k(k, thw)
  File "C:\Users\clai\miniconda3\envs\torch_HCM_py39\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\clai\miniconda3\envs\torch_HCM_py39\lib\site-packages\torchvision\models\video\mvit.py", line 89, in forward
    x = x.reshape((B * N, C) + thw).contiguous()
RuntimeError: shape '[1, 96, 8, 56, 56]' is invalid for input of size 4816896
```
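The numbers in the error can be reproduced by hand: 4816896 = 96 × 16 × 56 × 56, while the target shape [1, 96, 8, 56, 56] holds only 96 × 8 × 56 × 56 = 2408448 elements, so the clip produces twice as many temporal tokens as the reshape expects. The sketch below redoes that arithmetic. It assumes (from a reading of torchvision's `mvit.py`, not verified here) a patch-embedding `Conv3d` with kernel (3, 7, 7), stride (2, 4, 4), padding (1, 3, 3), embedding dim 96, and a model built for 16 frames at 224x224 whose token grid is fixed as each size divided by the stride.

```python
# Sketch of the token-count arithmetic behind the RuntimeError.
# Assumed values (not confirmed against the library source):
KERNEL = (3, 7, 7)
STRIDE = (2, 4, 4)
PADDING = (1, 3, 3)
EMBED_DIM = 96

def conv_out(size, k, s, p):
    """Standard convolution output length: floor((size + 2p - k) / s) + 1."""
    return (size + 2 * p - k) // s + 1

def patch_grid(t, h, w):
    """(T', H', W') token grid actually produced by the patch embedding."""
    return tuple(
        conv_out(n, k, s, p)
        for n, k, s, p in zip((t, h, w), KERNEL, STRIDE, PADDING)
    )

def expected_grid(temporal_size=16, spatial_size=(224, 224)):
    """Grid the model assumes: each configured size divided by the stride."""
    return tuple(n // s for n, s in zip((temporal_size,) + spatial_size, STRIDE))

actual = patch_grid(32, 224, 224)  # grid produced by the 32-frame clip
expected = expected_grid()         # grid the reshape in the traceback uses
actual_elems = EMBED_DIM * actual[0] * actual[1] * actual[2]
expected_elems = EMBED_DIM * expected[0] * expected[1] * expected[2]

print(actual, actual_elems)      # (16, 56, 56) 4816896
print(expected, expected_elems)  # (8, 56, 56) 2408448
print(patch_grid(16, 224, 224))  # (8, 56, 56)
```

If these assumptions hold, a 16-frame clip (for example `torch.rand(1, 16, 3, 800, 600)` before the transform) would give a token grid that matches the one the model reshapes to.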
Any idea how I can make it work?
Thanks