Mask-RCNN + nn.DataParallel: possible?

Is it possible to use the MaskRCNN network on multiple GPUs?

import torch
import torchvision

device = 'cuda:0'

model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
model = torch.nn.DataParallel(model)
model.to(device)
model.eval()

x = [torch.rand(3, 300, 400).to(device), torch.rand(3, 500, 400).to(device)]
preds = model(x)

This code gives the following error:

Traceback (most recent call last):
  File "test_parallel.py", line 15, in <module>
    preds = model(imgs)
  File "/opt/anaconda3/envs/alp/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/anaconda3/envs/alp/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 152, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/opt/anaconda3/envs/alp/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 162, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/opt/anaconda3/envs/alp/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 85, in parallel_apply
    output.reraise()
  File "/opt/anaconda3/envs/alp/lib/python3.7/site-packages/torch/_utils.py", line 369, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/opt/anaconda3/envs/alp/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
    output = module(*input, **kwargs)
  File "/opt/anaconda3/envs/alp/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/anaconda3/envs/alp/lib/python3.7/site-packages/torchvision/models/detection/generalized_rcnn.py", line 47, in forward
    images, targets = self.transform(images, targets)
  File "/opt/anaconda3/envs/alp/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/anaconda3/envs/alp/lib/python3.7/site-packages/torchvision/models/detection/transform.py", line 40, in forward
    image = self.normalize(image)
  File "/opt/anaconda3/envs/alp/lib/python3.7/site-packages/torchvision/models/detection/transform.py", line 55, in normalize
    return (image - mean[:, None, None]) / std[:, None, None]
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 0

But if I comment out the model = torch.nn.DataParallel(model) line, it works.

There is a GitHub issue on this subject, and the problem seems to come from the way the DataParallel module works and how the MaskRCNN methods are defined (if I have understood correctly).
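As a rough illustration of the DataParallel behaviour mentioned above: its scatter step splits every tensor input along dimension 0, and for a list of tensors it does so tensor by tensor. For a CHW image, dimension 0 is the channel axis, so each replica ends up with a slice of the channels rather than with whole images. The sketch below mimics that split on the CPU with torch.chunk; the two-replica count is an assumption, and the snippet does not call DataParallel itself.

```python
import torch

# A list of two CHW images, as in the question (kept on the CPU here).
images = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]

# Assumption: two GPUs, hence two replicas.
num_replicas = 2

# Split each tensor into chunks along dim 0, then regroup so that
# per_replica[k] is the list of pieces that replica k would receive.
per_replica = list(zip(*(img.chunk(num_replicas, dim=0) for img in images)))

shapes = [[tuple(t.shape) for t in replica] for replica in per_replica]
print(shapes)
# [[(2, 300, 400), (2, 500, 400)], [(1, 300, 400), (1, 500, 400)]]
```

The first replica receives 2-channel tensors, which is consistent with the "size of tensor a (2) must match the size of tensor b (3)" message raised when the 3-channel mean is subtracted.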

But would there be a solution without making major changes?

torch 1.2.0
torchvision 0.4.0

Using multiple GPUs requires all images in a single batch to have the same shape, so you should preprocess your images first, for example by resizing them.
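The resizing step suggested here could look roughly like the sketch below. The helper name resize_and_stack and the 256x256 target size are assumptions made for illustration. It ignores aspect ratio, and for training the boxes and masks would need the same scaling. It only shows the preprocessing and says nothing about whether the model then runs under DataParallel.

```python
import torch
import torch.nn.functional as F


def resize_and_stack(images, size=(256, 256)):
    """Resize each CHW image to `size` (H, W) and stack them into one NCHW batch.

    Hypothetical helper for illustration; not part of torchvision.
    """
    resized = [
        # interpolate expects a batch dimension, so add one and drop it again.
        F.interpolate(
            img.unsqueeze(0), size=size, mode="bilinear", align_corners=False
        ).squeeze(0)
        for img in images
    ]
    return torch.stack(resized)


images = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
batch = resize_and_stack(images)
print(batch.shape)  # torch.Size([2, 3, 256, 256])
```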

OK! But it still does not work…

x = torch.rand(2, 3, 256, 256)
x = x.to(device)

preds = model(x)

It gives the error:

Traceback (most recent call last):
  File "test_parallel.py", line 34, in <module>
    preds = model(x)
  File "/opt/anaconda3/envs/alp/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/anaconda3/envs/alp/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 152, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/opt/anaconda3/envs/alp/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 162, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/opt/anaconda3/envs/alp/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 85, in parallel_apply
    output.reraise()
  File "/opt/anaconda3/envs/alp/lib/python3.7/site-packages/torch/_utils.py", line 369, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/opt/anaconda3/envs/alp/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
    output = module(*input, **kwargs)
  File "/opt/anaconda3/envs/alp/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/anaconda3/envs/alp/lib/python3.7/site-packages/torchvision/models/detection/generalized_rcnn.py", line 51, in forward
    proposals, proposal_losses = self.rpn(images, features, targets)
  File "/opt/anaconda3/envs/alp/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/anaconda3/envs/alp/lib/python3.7/site-packages/torchvision/models/detection/rpn.py", line 409, in forward
    proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors)
  File "/opt/anaconda3/envs/alp/lib/python3.7/site-packages/torchvision/models/detection/_utils.py", line 168, in decode
    rel_codes.reshape(sum(boxes_per_image), -1), concat_boxes
  File "/opt/anaconda3/envs/alp/lib/python3.7/site-packages/torchvision/models/detection/_utils.py", line 199, in decode_single
    pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
RuntimeError: expected device cuda:0 and dtype Float but got device cuda:1 and dtype Float

It is weird, because if I replace the Mask-RCNN with torchvision.models.squeezenet1_1(), it works perfectly.
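One possible workaround for inference, offered as a sketch and not as a confirmed fix for this issue, is to skip DataParallel altogether: keep an independent deep copy of the model on each device and split the image list by hand. The helpers make_copies and run_sharded and the AreaModel stand-in are hypothetical names introduced here. The stand-in only imitates the "list of images in, list of dicts out" interface so the snippet runs without GPUs or pretrained weights.

```python
import copy
from concurrent.futures import ThreadPoolExecutor

import torch
from torch import nn


def make_copies(model, devices):
    """Return one independent, eval-mode deep copy of `model` per device."""
    return [copy.deepcopy(model).to(d).eval() for d in devices]


def _infer(model, device, shard):
    if not shard:
        return []
    # Gradient mode is thread-local, so no_grad is entered inside the worker.
    with torch.no_grad():
        return model([img.to(device) for img in shard])


def run_sharded(copies, devices, images):
    """Run each copy on an interleaved share of `images`, keeping input order."""
    n = len(devices)
    shards = [images[k::n] for k in range(n)]
    with ThreadPoolExecutor(max_workers=n) as pool:
        results = list(pool.map(_infer, copies, devices, shards))
    outputs = [None] * len(images)
    for k, shard_out in enumerate(results):
        for j, out in enumerate(shard_out):
            outputs[k + j * n] = out  # undo the interleaving
    return outputs


class AreaModel(nn.Module):
    """Hypothetical stand-in for a detection model."""

    def forward(self, images):
        return [{"area": img.shape[-2] * img.shape[-1]} for img in images]


# Assumption: fall back to two CPU "devices" when fewer than two GPUs exist.
devices = ["cuda:0", "cuda:1"] if torch.cuda.device_count() >= 2 else ["cpu", "cpu"]
copies = make_copies(AreaModel(), devices)
images = [torch.rand(3, 300, 400), torch.rand(3, 500, 400), torch.rand(3, 256, 256)]
outputs = run_sharded(copies, devices, images)
print([o["area"] for o in outputs])  # [120000, 200000, 65536]
```

Because each copy owns its parameters and any internal state, nothing is shared across devices. The trade-off is one full copy of the weights per GPU, and the outputs stay on the device that produced them.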