Error in DistributedDataParallel when parameters are torch.cfloat

I'm trying to use PyTorch's DistributedDataParallel (DDP) for my model:

model = nn.parallel.DistributedDataParallel(model)

One of the model's layers has a complex-valued nn.Parameter:

weights = nn.Parameter(self.scale * torch.rand(dim1, dim2, dtype=torch.cfloat))

When I wrap the model with nn.parallel.DistributedDataParallel, I get the following runtime error:

File "python-3.10.0/lib/python3.10/site-packages/torch/distributed/utils.py", line 131, in _sync_params_and_buffers
    dist._broadcast_coalesced(
RuntimeError: Invalid scalar type
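The traceback points at _sync_params_and_buffers, which broadcasts every parameter and buffer from rank 0 when DDP is constructed, so any complex tensor in the module tree can trigger it. A quick way to locate the offending tensors before wrapping is to scan the module. This is a minimal sketch; find_complex_params and ToyLayer are hypothetical names, with ToyLayer only mimicking the layer described above:

```python
import torch
import torch.nn as nn


def find_complex_params(model: nn.Module) -> list[str]:
    """Return the names of parameters and buffers that have a complex dtype."""
    names = [n for n, p in model.named_parameters() if p.is_complex()]
    names += [n for n, b in model.named_buffers() if b.is_complex()]
    return names


class ToyLayer(nn.Module):
    # Hypothetical layer mirroring the one in the question.
    def __init__(self, dim1: int, dim2: int, scale: float = 0.02):
        super().__init__()
        self.scale = scale
        self.weights = nn.Parameter(self.scale * torch.rand(dim1, dim2, dtype=torch.cfloat))
        self.bias = nn.Parameter(torch.zeros(dim2))  # real-valued, not reported


model = nn.Sequential(nn.Linear(4, 4), ToyLayer(4, 8))
print(find_complex_params(model))  # ['1.weights']
```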

However, when I create the weights tensor with dtype=torch.float:

weights = nn.Parameter(self.scale * torch.rand(dim1, dim2, dtype=torch.float))

There are no errors. Does anyone know how I can get around this error?

Complex tensor support in DDP is tracked here and is not implemented yet.
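Until that lands, one possible workaround (an assumption, not something DDP documents for this case) is to store the complex weights as a real parameter with a trailing dimension of 2 holding (real, imag) pairs, and to view it as complex only inside forward() with torch.view_as_complex. DDP then only broadcasts and all-reduces torch.float tensors, while the layer still computes in complex arithmetic. RealBackedComplexLayer is a hypothetical name, and the matrix-multiply forward is only a placeholder for whatever the real layer does:

```python
import torch
import torch.nn as nn


class RealBackedComplexLayer(nn.Module):
    """Hypothetical layer whose complex weights are stored as a real parameter.

    self.weights has shape (dim1, dim2, 2): [..., 0] is the real part and
    [..., 1] the imaginary part, so the parameter itself is torch.float32.
    """

    def __init__(self, dim1: int, dim2: int, scale: float = 0.02):
        super().__init__()
        # Like torch.rand(dim1, dim2, dtype=torch.cfloat): real and imaginary
        # parts each drawn uniformly from [0, 1), then scaled.
        self.weights = nn.Parameter(scale * torch.rand(dim1, dim2, 2))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Complex view of the real storage; autograd flows back into self.weights.
        w = torch.view_as_complex(self.weights)  # shape (dim1, dim2), complex64
        # matmul does not mix real and complex dtypes, so cast the input first.
        return x.to(w.dtype) @ w


layer = RealBackedComplexLayer(3, 5)
x = torch.randn(2, 3)  # real input, cast to complex64 inside forward
y = layer(x)
y.abs().sum().backward()
print(layer.weights.dtype, y.dtype, tuple(layer.weights.grad.shape))
```

Because view_as_complex shares memory with the parameter, the gradient arrives as an ordinary real tensor of shape (dim1, dim2, 2), which is what DDP synchronizes.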

Thanks! OK, I did a bit of mathematical manipulation and replaced the complex tensors with real ones. It's working now.
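The exact manipulation isn't given. One common approach is to split every complex quantity into real and imaginary parts and expand the product, (A + iB)(C + iD) = (AC - BD) + i(AD + BC), so that the model only ever holds real tensors. The sketch below applies this to a matrix multiply. RealComplexMatmul is a hypothetical name, and it assumes the layer's core operation is x @ W:

```python
import torch
import torch.nn as nn


class RealComplexMatmul(nn.Module):
    """Hypothetical layer computing (Xr + i*Xi) @ (Wr + i*Wi) using real tensors only.

    Expands (A + iB)(C + iD) = (AC - BD) + i(AD + BC) and returns the real and
    imaginary parts of the result as two separate real tensors.
    """

    def __init__(self, dim1: int, dim2: int, scale: float = 0.02):
        super().__init__()
        # Two real parameters replace one complex parameter.
        self.w_real = nn.Parameter(scale * torch.rand(dim1, dim2))
        self.w_imag = nn.Parameter(scale * torch.rand(dim1, dim2))

    def forward(self, x_real: torch.Tensor, x_imag: torch.Tensor):
        out_real = x_real @ self.w_real - x_imag @ self.w_imag
        out_imag = x_real @ self.w_imag + x_imag @ self.w_real
        return out_real, out_imag


layer = RealComplexMatmul(3, 5)
x = torch.randn(2, 3, dtype=torch.cfloat)
out_r, out_i = layer(x.real, x.imag)

# Cross-check against native complex arithmetic (outside any DDP wrapper).
ref = x @ torch.complex(layer.w_real, layer.w_imag)
print(torch.allclose(out_r, ref.real, atol=1e-6), torch.allclose(out_i, ref.imag, atol=1e-6))
```

Keeping the real and imaginary parts as separate parameters, rather than one stacked tensor, makes each one an ordinary float parameter that any optimizer or DDP setup handles without special cases.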