Data Parallel loses ParameterDict

Hey, is it expected for ParameterDicts to become empty under DataParallel?

Minimal Example:

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.mid = nn.ParameterDict()
        self.mid["key1"] = nn.Parameter(torch.tensor(2.0))
        self.mid["key2"] = nn.Parameter(torch.tensor(3.0))
    def forward(self, x):
        return self.mid["key1"] * x + self.mid["key2"]

print("Devices", torch.cuda.device_count())
mod = MyModule()
mod = nn.DataParallel(mod)
mod =

input = torch.randn(2, 100, device=0)
out = mod(input)


Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/dfs/scratch0/lorr1/py3.8/lib/python3.8/site-packages/torch/nn/modules/", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/dfs/scratch0/lorr1/py3.8/lib/python3.8/site-packages/torch/nn/parallel/", line 161, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/dfs/scratch0/lorr1/py3.8/lib/python3.8/site-packages/torch/nn/parallel/", line 171, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/dfs/scratch0/lorr1/py3.8/lib/python3.8/site-packages/torch/nn/parallel/", line 86, in parallel_apply
  File "/dfs/scratch0/lorr1/py3.8/lib/python3.8/site-packages/torch/", line 428, in reraise
    raise self.exc_type(msg)
KeyError: Caught KeyError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/dfs/scratch0/lorr1/py3.8/lib/python3.8/site-packages/torch/nn/parallel/", line 61, in _worker
    output = module(*input, **kwargs)
  File "/dfs/scratch0/lorr1/py3.8/lib/python3.8/site-packages/torch/nn/modules/", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "<stdin>", line 9, in forward
  File "/dfs/scratch0/lorr1/py3.8/lib/python3.8/site-packages/torch/nn/modules/", line 542, in __getitem__
    return self._parameters[key]
KeyError: 'key1'

You misspelled self.mid as self.min.

Thanks for catching that. Even with the spelling corrected, though, the KeyError about key1 remains. As the printouts show, the ParameterDict is empty.
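For what it's worth, the module itself is fine: a single-device sanity check (a sketch of the repro above, with key1 used as a multiplicative scale since a Parameter is not callable) runs without error as long as DataParallel is not involved:

```python
import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.mid = nn.ParameterDict()
        self.mid["key1"] = nn.Parameter(torch.tensor(2.0))
        self.mid["key2"] = nn.Parameter(torch.tensor(3.0))

    def forward(self, x):
        # scale by key1, shift by key2
        return self.mid["key1"] * x + self.mid["key2"]

mod = MyModule()          # no DataParallel wrapper
x = torch.randn(2, 100)
out = mod(x)              # no KeyError on a single device
```

This suggests the dict is only lost on the replicas that DataParallel creates, not on the original module.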

UserWarning: nn.ParameterDict is being used with DataParallel but this is not supported. 
This dict will appear empty for the models replicated on each GPU except the original one.

It seems nn.ParameterDict is not supported with DataParallel.
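One possible workaround, assuming a fixed set of keys fits your use case: register each parameter directly on the module (plain parameters are replicated correctly by DataParallel) and keep an ordinary dict mapping keys to attribute names for dict-style lookup. This is a sketch, not a PyTorch API; the helper names (_key_to_attr, get_param) are made up for illustration.

```python
import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        # plain dict: key -> attribute name (hypothetical helper, not a PyTorch API)
        self._key_to_attr = {}
        for key, value in {"key1": 2.0, "key2": 3.0}.items():
            attr = "param_" + key
            # registered as a direct parameter, so DataParallel replicates it
            self.register_parameter(attr, nn.Parameter(torch.tensor(value)))
            self._key_to_attr[key] = attr

    def get_param(self, key):
        # dict-style lookup routed through getattr
        return getattr(self, self._key_to_attr[key])

    def forward(self, x):
        return self.get_param("key1") * x + self.get_param("key2")

mod = MyModule()
x = torch.randn(2, 100)
out = mod(x)  # works on a single device; should also survive replication
```

The idea is simply to sidestep ParameterDict's custom storage: register_parameter puts each tensor into the module's ordinary _parameters, which DataParallel's replicate step does handle.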