Using DataParallel when the input to the model is a dict

Hello,
I am trying to train a multi-modal model on multiple GPUs using torch.nn.DataParallel.
However, I have multiple modalities so the input to the model is a dictionary.
Is there any way to make this work on multiple GPUs? As far as I understand, DataParallel only works if the input to the model is a tensor.

My input to the model looks like the following:

{
'modality_1': torch.Tensor(bs, *img_size),
'modality_2': torch.Tensor(bs, *img_size),
'modality_3': torch.Tensor(bs, *img_size),
}

Sol 1: define your own class for your inputs and implement the to() function
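
A minimal sketch of what Sol 1 could look like (the MultiModalBatch name and its fields are just illustrative, not part of any library):

import torch

class MultiModalBatch:
    """Illustrative wrapper around a dict of modality tensors."""

    def __init__(self, data):
        # data: {'modality_1': tensor, 'modality_2': tensor, ...}
        self.data = data

    def to(self, device):
        # Mirror torch.Tensor.to: move every modality to the device and return self
        self.data = {k: v.to(device) for k, v in self.data.items()}
        return self

    def __getitem__(self, key):
        return self.data[key]

    def items(self):
        return self.data.items()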

Sol 2:

self.data = {k: v.to(device) for k, v in self.data.items()}


Hi @klory, thanks for your answer!
I’m not quite sure if I understand correctly: should I send each value in the dict to another GPU for parallel computing?

You only need to wrap your model in nn.DataParallel; the data will be distributed to all GPUs automatically, as long as you have called to() on it.
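
Something like the following sketch (the model and dataloader names are just placeholders):

device = torch.device('cuda:0')
model = MyMultiModalModel().to(device)      # hypothetical multi-modal model
model = torch.nn.DataParallel(model)        # replicates the model on all visible GPUs

for batch in dataloader:                    # batch is a dict of modality tensors
    batch = {k: v.to(device) for k, v in batch.items()}
    results = model(batch)                  # DataParallel scatters the tensors in the
                                            # dict along the batch dimension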

Oh okay, I see, that looks easy 🙂
I'm getting the following error though on the line results = model(batch):

  File "miniconda3/envs/mimic/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "miniconda3/envs/mimic/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 156, in forward
    return self.gather(outputs, self.output_device)
  File "miniconda3/envs/mimic/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 168, in gather
    return gather(outputs, output_device, dim=self.dim)
  File "miniconda3/envs/mimic/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py", line 68, in gather
    res = gather_map(outputs)
  File "miniconda3/envs/mimic/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py", line 61, in gather_map
    return type(out)(((k, gather_map([d[k] for d in outputs]))
  File "miniconda3/envs/mimic/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py", line 61, in <genexpr>
    return type(out)(((k, gather_map([d[k] for d in outputs]))
  File "miniconda3/envs/mimic/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py", line 61, in gather_map
    return type(out)(((k, gather_map([d[k] for d in outputs]))
  File "miniconda3/envs/mimic/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py", line 61, in <genexpr>
    return type(out)(((k, gather_map([d[k] for d in outputs]))
  File "miniconda3/envs/mimic/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py", line 63, in gather_map
    return type(out)(map(gather_map, zip(*outputs)))
TypeError: 'Laplace' object is not iterable

Is there anything I could be doing wrong?
I have applied the following line to the batch:
batch = {k: v.to(torch.device('cuda')) for k, v in batch.items()}

You did not show the line that triggers the error. Can you run it successfully on a single GPU, or at least on the CPU?

When trying to run this with torch.device('cpu') I get the error:

  File "/src/MIMIC/mimic/run_epochs.py", line 94, in basic_routine_epoch
    results = model(batch)
  File "/miniconda3/envs/mimic/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/miniconda3/envs/mimic/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 147, in forward
    raise RuntimeError("module must have its parameters and buffers "
RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found one of them on device: cpu
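
I assume this happens because nn.DataParallel expects the module's parameters on cuda:0 (device_ids[0]), so a pure CPU check would probably have to bypass the wrapper, e.g. (MyMultiModalModel is a placeholder for my model):

cpu_device = torch.device('cpu')
cpu_model = MyMultiModalModel().to(cpu_device)   # plain module, no DataParallel
cpu_batch = {k: v.to(cpu_device) for k, v in batch.items()}
results = cpu_model(cpu_batch)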

How about on the GPU without nn.DataParallel?

@klory the code runs without problems on the GPU without nn.DataParallel. The problem seems to be the torch.distributions.Laplace distribution that I create in the forward pass. Is there any reason this should be a problem?

Also my model consists of different nn.Modules: an encoder and a decoder for each modality. Is this compatible with nn.DataParallel?

I don't know too much about Laplace, but multiple nn.Modules inside the model are OK.
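
For example, a rough sketch (all names and sizes are illustrative) of a model with per-modality encoders and decoders that is still a single nn.Module and can therefore be wrapped in nn.DataParallel:

import torch
import torch.nn as nn

class MultiModalModel(nn.Module):
    """Illustrative model: one encoder and one decoder per modality."""

    def __init__(self, modalities, in_dim, latent_dim):
        super().__init__()
        self.encoders = nn.ModuleDict({m: nn.Linear(in_dim, latent_dim) for m in modalities})
        self.decoders = nn.ModuleDict({m: nn.Linear(latent_dim, in_dim) for m in modalities})

    def forward(self, batch):
        # batch: {modality_name: tensor of shape (bs, in_dim)}
        out = {}
        for m, x in batch.items():
            out[m] = self.decoders[m](self.encoders[m](x))
        return out

model = MultiModalModel(['modality_1', 'modality_2'], in_dim=64, latent_dim=16)
if torch.cuda.is_available():
    model = nn.DataParallel(model.cuda())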

I have created a gist that reproduces the error. The problem is indeed the Laplace distribution from torch.distributions: DataParallel's gather step cannot combine the Laplace objects returned from the forward pass.
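
A possible workaround, sketched below with an illustrative toy model (not my actual code), is to return plain tensors such as the Laplace parameters from forward(), so that gather only has to concatenate tensors, and to rebuild the distribution on the gathered outputs:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Laplace

class ToyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.loc_head = nn.Linear(8, 1)
        self.scale_head = nn.Linear(8, 1)

    def forward(self, x):
        # Return tensors instead of a Laplace object: DataParallel.gather can
        # concatenate tensors from the replicas, but not distribution objects.
        loc = self.loc_head(x)
        scale = F.softplus(self.scale_head(x)) + 1e-6   # keep scale positive
        return loc, scale

model = ToyNet()
x = torch.randn(16, 8)
if torch.cuda.is_available():
    model = nn.DataParallel(model.cuda())
    x = x.cuda()

loc, scale = model(x)
dist = Laplace(loc, scale)   # rebuild the distribution after the gather step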