Here is the code I run:
import torch
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models import resnet50, vit_b_16
from torch.nn import DataParallel
return_nodes = {"heads": "0"}
# return_nodes = {"fc": "0"}
device = torch.device("cuda")
backbone = vit_b_16()
backbone = create_feature_extractor(backbone, return_nodes=return_nodes)
model1 = DataParallel(backbone).to(device)
x = torch.rand(2, 3, 256, 256).to(device)
out = model1(x)
And the error message is:
File "/home/ubuntu/models/test_param_error.py", line 60, in <module>
out = model1(x)
File "/home/ubuntu/anaconda3/envs/torchenv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/home/ubuntu/anaconda3/envs/torchenv/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 167, in forward
replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
File "/home/ubuntu/anaconda3/envs/torchenv/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 172, in replicate
return replicate(module, device_ids, not torch.is_grad_enabled())
File "/home/ubuntu/anaconda3/envs/torchenv/lib/python3.9/site-packages/torch/nn/parallel/replicate.py", line 148, in replicate
setattr(replica, key, param)
File "/home/ubuntu/anaconda3/envs/torchenv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1206, in __setattr__
raise TypeError("cannot assign '{}' as parameter '{}' "
TypeError: cannot assign 'torch.cuda.FloatTensor' as parameter 'class_token' (torch.nn.Parameter or None expected)
My torch and torchvision versions:
pytorch 1.11.0 py3.9_cuda11.5_cudnn8.3.2_0
torchvision 0.12.0 py39_cu115
Here is what I found in a multi-GPU setting:
- If the model is resnet50 (and the return node is some resnet50 layer), this code runs without error. ViTs, however, define some parameters directly with nn.Parameter (self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))). That is where the error occurs.
- create_feature_extractor, nn.Parameter, DataParallel: any two of them work together, but not all three.
- From the error message, we can see that the program tries to assign a plain tensor to the parameter class_token. If we don't use create_feature_extractor, this works. Why does create_feature_extractor change the behavior?
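For reference, the TypeError itself can be reproduced on CPU without DataParallel or create_feature_extractor. Below is a minimal sketch, assuming a hypothetical Tiny module that stands in for the ViT (the name and the shape are made up for illustration). It only shows what the message means: nn.Module refuses to overwrite a registered parameter with a plain tensor. It does not explain why the traced module ends up in that state during replication.

```python
import torch
from torch import nn


class Tiny(nn.Module):
    """Hypothetical stand-in for the ViT: one parameter registered on the root module."""

    def __init__(self):
        super().__init__()
        self.class_token = nn.Parameter(torch.zeros(1, 1, 4))


m = Tiny()

# Assigning a plain tensor to a name that is still registered as a parameter
# raises the same TypeError as in the traceback.
try:
    m.class_token = torch.zeros(1, 1, 4)
    msg = ""
except TypeError as e:
    msg = str(e)
print(msg)

# Once the name is no longer registered as a parameter, the same assignment
# succeeds and the tensor is stored as an ordinary attribute.
del m.class_token
m.class_token = torch.zeros(1, 1, 4)
print(type(m.class_token), list(dict(m.named_parameters())))
```

So the setattr(replica, key, param) call in replicate.py seems to hit a replica that still has class_token registered as a parameter, which would be consistent with the message, but that part is my guess.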
Thanks for your help.