create_feature_extractor, nn.Parameter, and DataParallel are not compatible together

Here is the code I ran:

import torch
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models import resnet50, vit_b_16
from torch.nn import DataParallel

return_nodes = {"heads": "0"}
# return_nodes = {"fc": "0"}
device = torch.device("cuda")
backbone = vit_b_16()
backbone = create_feature_extractor(backbone, return_nodes=return_nodes)
model1 = DataParallel(backbone).to(device)
# vit_b_16 expects 224x224 inputs by default
x = torch.rand(2, 3, 224, 224).to(device)
out = model1(x)

The error message is:

  File "/home/ubuntu/models/test_param_error.py", line 60, in <module>
    out = model1(x)
  File "/home/ubuntu/anaconda3/envs/torchenv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/anaconda3/envs/torchenv/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 167, in forward
    replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
  File "/home/ubuntu/anaconda3/envs/torchenv/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 172, in replicate
    return replicate(module, device_ids, not torch.is_grad_enabled())
  File "/home/ubuntu/anaconda3/envs/torchenv/lib/python3.9/site-packages/torch/nn/parallel/replicate.py", line 148, in replicate
    setattr(replica, key, param)
  File "/home/ubuntu/anaconda3/envs/torchenv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1206, in __setattr__
    raise TypeError("cannot assign '{}' as parameter '{}' "
TypeError: cannot assign 'torch.cuda.FloatTensor' as parameter 'class_token' (torch.nn.Parameter or None expected)

My torch and torchvision versions:

pytorch                   1.11.0          py3.9_cuda11.5_cudnn8.3.2_0
torchvision               0.12.0               py39_cu115

Here is what I found in a multi-GPU setting:

  1. If the model is resnet50 (and the return node is a resnet50 layer such as "fc"), the code runs without error. ViT models, however, define some parameters directly with nn.Parameter (self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))), and that is where the error occurs.
  2. Any two of create_feature_extractor, nn.Parameter, and DataParallel work together, but not all three.
  3. The error message shows that the replicate step tries to assign a plain tensor to the parameter class_token. Without create_feature_extractor, this works. Why does create_feature_extractor change the behavior?

Thanks for any help.