I’m trying to use torch’s DistributedDataParallel to parallelize my model:
model = nn.parallel.DistributedDataParallel(model)
In one of the model’s layers I have a complex-valued nn.Parameter:
weights = nn.Parameter(self.scale * torch.rand(dim1, dim2, dtype=torch.cfloat))
When I wrap the model with nn.parallel.DistributedDataParallel, the following runtime error appears:
File "python-3.10.0/lib/python3.10/site-packages/torch/distributed/utils.py", line 131, in _sync_params_and_buffers
dist._broadcast_coalesced(
RuntimeError: Invalid scalar type
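For reference, here is a minimal sketch of the kind of setup that triggers it for me (ComplexLayer, the dimensions, and the single-process gloo init are simplified stand-ins for my actual training script, which is launched with torchrun across several ranks):

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn


class ComplexLayer(nn.Module):
    def __init__(self, dim1=4, dim2=4, scale=0.02):
        super().__init__()
        self.scale = scale
        # Complex-valued parameter, as in my real model
        self.weights = nn.Parameter(
            self.scale * torch.rand(dim1, dim2, dtype=torch.cfloat)
        )

    def forward(self, x):
        return x @ self.weights


if __name__ == "__main__":
    # Single-process group, just enough to trigger DDP's initial
    # parameter broadcast
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    model = ComplexLayer()
    # RuntimeError: Invalid scalar type is raised here, inside
    # _sync_params_and_buffers -> dist._broadcast_coalesced
    model = nn.parallel.DistributedDataParallel(model)

    dist.destroy_process_group()
```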
However, when I create the weights tensor with torch.float instead:
weights = nn.Parameter(self.scale * torch.rand(dim1, dim2, dtype=torch.float))
There are no errors. Does anyone know why the complex dtype triggers this error, and how I can work around it?
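One workaround I’m considering (unverified, just a sketch; ComplexLayerReal and weights_real are names I made up here) is to store the parameter in a real dtype via torch.view_as_real, so DDP only ever has to broadcast float tensors, and rebuild the complex view inside forward:

```python
import torch
import torch.nn as nn


class ComplexLayerReal(nn.Module):
    def __init__(self, dim1=4, dim2=4, scale=0.02):
        super().__init__()
        # Store the weights as a real (dim1, dim2, 2) tensor where the last
        # axis holds (real, imag), so DDP only broadcasts float tensors
        w = scale * torch.rand(dim1, dim2, dtype=torch.cfloat)
        self.weights_real = nn.Parameter(torch.view_as_real(w))

    def forward(self, x):
        # Rebuild the complex weights on the fly; view_as_complex returns a
        # view, so gradients should flow back into the real-valued parameter
        weights = torch.view_as_complex(self.weights_real)
        return x @ weights
```

Since torch.view_as_complex returns a view rather than a copy, I believe the optimizer would still update the underlying real parameter correctly, but I’m not sure whether this is the recommended approach.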