DistributedDataParallel Sparse Embeddings

Hi everyone,

I’m trying to train a matrix factorization model wrapped in DistributedDataParallel. The model contains embeddings, which are multiplied together to obtain a prediction. After I compute the MSELoss between the target and the prediction, calling backward() gives me the following error:
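The post doesn't include the model code, so here is a minimal sketch of the setup it describes. The class name, sizes, ids and targets are hypothetical. `sparse=True` on the embeddings is also an assumption, suggested by the title. The sketch runs in a single CPU process without DDP, only to show the forward pass, the MSE loss and the backward call.

```python
import torch
import torch.nn as nn


class MatrixFactorization(nn.Module):
    """Hypothetical MF model: one embedding table per side, dot product as prediction."""

    def __init__(self, n_users: int, n_items: int, dim: int = 8):
        super().__init__()
        # sparse=True (assumed) makes backward() produce sparse gradients for the tables
        self.user_emb = nn.Embedding(n_users, dim, sparse=True)
        self.item_emb = nn.Embedding(n_items, dim, sparse=True)

    def forward(self, users: torch.Tensor, items: torch.Tensor) -> torch.Tensor:
        # Element-wise product of user and item vectors, summed per row (a dot product)
        return (self.user_emb(users) * self.item_emb(items)).sum(dim=1)


torch.manual_seed(0)
model = MatrixFactorization(n_users=100, n_items=50)
users = torch.tensor([0, 3, 7])          # hypothetical user ids
items = torch.tensor([1, 1, 4])          # hypothetical item ids
ratings = torch.tensor([4.0, 2.0, 5.0])  # hypothetical targets

loss = nn.MSELoss()(model(users, items), ratings)
loss.backward()  # under DDP with NCCL, this is the call that raised the error
```

Because the embedding gradients are sparse tensors, DDP has to all-reduce sparse tensors during backward(). The "Tensors must be CUDA and dense" check in the traceback appears to reject exactly that.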

Traceback (most recent call last):
  File "src/train_model.py", line 188, in <module>
  File "/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/spawn.py", line 171, in spawn
    while not spawn_context.join():
  File "/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/spawn.py", line 118, in join
    raise Exception(msg)

-- Process 6 terminated with the following error:
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/spawn.py", line 19, in _wrap
    fn(i, *args)
  File "/home/ubuntu/NOMAR/src/train_model.py", line 56, in run_training
    trainer.run(mf_model, train_data_dir, TEST_DATA_DIR, dataset_kwargs, early_stopping_kwargs)
  File "/home/ubuntu/NOMAR/src/trainer/MFTrainer.py", line 73, in run
    avg_training_loss = self.train_model(model, train_loader, optimizer)
  File "/home/ubuntu/NOMAR/src/trainer/MFTrainer.py", line 95, in train_model
  File "/usr/local/lib/python3.6/dist-packages/torch/tensor.py", line 166, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/usr/local/lib/python3.6/dist-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Tensors must be CUDA and dense

I suspect that the sparse gradients cannot be all-reduced across processes, but I read in https://github.com/pytorch/pytorch/issues/17356 that this has been fixed. Is that not the case?
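One way to check the suspicion is to list which parameters receive sparse gradients after backward(). These are the tensors DDP would have to all-reduce in sparse form. The helper `sparse_grad_params` and the toy model below are hypothetical, for illustration only.

```python
import torch
import torch.nn as nn


def sparse_grad_params(model: nn.Module) -> list:
    """Return names of parameters whose .grad is a sparse tensor (call after backward())."""
    return [
        name
        for name, p in model.named_parameters()
        if p.grad is not None and p.grad.is_sparse
    ]


# Hypothetical toy model: a sparse embedding followed by a dense linear layer
model = nn.Sequential(nn.Embedding(10, 4, sparse=True), nn.Linear(4, 1))
out = model(torch.tensor([1, 2, 2]))
out.pow(2).mean().backward()

print(sparse_grad_params(model))  # ['0.weight']
```

Running the same helper on the real model, outside DDP, would show which embedding tables produce sparse gradients.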

Any insights would be appreciated.

Hi, I had the same problem and found a solution. PyTorch already supports this with backend="gloo", but support for NCCL has not been completed yet. I tested it with the PyTorch nightly build, so if the stable version still fails, you can update to the nightly.
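Here is a minimal sketch of that fix. Relative to an NCCL setup, the only change is the backend passed to `torch.distributed.init_process_group`. The sketch runs as a single CPU process (`world_size=1`) with a file-based rendezvous so that it is self-contained. The model, data, `train_step` helper and store path are illustrative assumptions, not code from the thread.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


class MatrixFactorization(nn.Module):
    # Hypothetical MF model with sparse embedding tables
    def __init__(self, n_users: int, n_items: int, dim: int = 8):
        super().__init__()
        self.user_emb = nn.Embedding(n_users, dim, sparse=True)
        self.item_emb = nn.Embedding(n_items, dim, sparse=True)

    def forward(self, users: torch.Tensor, items: torch.Tensor) -> torch.Tensor:
        return (self.user_emb(users) * self.item_emb(items)).sum(dim=1)


def train_step(rank: int, world_size: int, init_method: str) -> float:
    # "gloo" instead of "nccl": gloo can all-reduce sparse gradients
    dist.init_process_group("gloo", init_method=init_method, rank=rank, world_size=world_size)
    try:
        torch.manual_seed(0)  # reproducible initialization
        model = DDP(MatrixFactorization(100, 50))
        # SGD accepts sparse gradients; Adam does not (SparseAdam exists for that case)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

        users = torch.tensor([0, 3, 7])
        items = torch.tensor([1, 1, 4])
        ratings = torch.tensor([4.0, 2.0, 5.0])

        optimizer.zero_grad()
        loss = nn.MSELoss()(model(users, items), ratings)
        loss.backward()  # DDP all-reduces the sparse embedding gradients here
        optimizer.step()
        return loss.item()
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    # Single-process demo; a real job starts one process per rank sharing init_method
    store = os.path.join(tempfile.mkdtemp(), "ddp_store")
    print(train_step(rank=0, world_size=1, init_method="file://" + store))
```

In a multi-process job, each rank would call `train_step` with its own `rank` and a shared `init_method`, for example via torch.multiprocessing.spawn as in the traceback. If staying on NCCL matters more than gradient sparsity, another option is `sparse=False` on the embeddings. That produces dense gradients NCCL can all-reduce, at the cost of communicating the full tables.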