DistributedDataParallel (DDP) and Multi-task Learning (MTL) where model outputs might not contribute to the loss

Hello. I am interested in multi-task learning (MTL) and ran into difficulties when using the distributed API, following the TorchVision examples: vision/references/detection at master · pytorch/vision · GitHub.
The short description is that the distributed forward call fails when one or more task-models (or heads) do not contribute to the total loss of the previous iteration:

Traceback (most recent call last):
  File "src/main.py", line 160, in <module>
  File "src/main.py", line 148, in main
    step = multi_model.train_one_epoch(dataloader_train, optimizer, epoch=epoch, step=step, write_freq=200, writer=writer)
  File "/workspace/src/models/multi_model.py", line 103, in train_one_epoch
    losses = self(images, targets)
  File "/workspace/src/models/multi_model.py", line 87, in __call__
    model_out = model(images, targets)
  File "/home/user/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/user/.local/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 804, in forward
    if grad_enabled and self.reducer._rebuild_buckets():
RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`, and by 
making sure all `forward` function outputs participate in calculating loss. 
If you already have done the above, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable).
Parameter indices which did not receive grad for rank 2: 16 17 18 19 20 21 22 23 24 25 26 27 28 29
 In addition, you can set the environment variable TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL to print out information about which particular parameters did not receive gradient on this rank as part of this error
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 2 (pid: 11861) of binary: /opt/conda/bin/python
ERROR:torch.distributed.elastic.agent.server.local_elastic_agent:[default] Worker group failed

Each head is trained only on a subset of the training data, which is filtered from the batch during training. If the batch does not include any input for one or more heads, those heads do not (and cannot) contribute to the loss. I tried to circumvent this by “faking” a forward call for the task, but then the loss might not be “good”.

The question is probably two-fold:

  1. is such an input-filtering approach correct when doing MTL and
  2. how would you still call the forward of unused heads?

One solution could be to force every batch to include input for each head, but with many heads the required batch size might become a problem given my limited resources. Another solution would be to calculate the loss in a different way and account for “false” examples correctly, but this probably requires a lot of custom code.
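The first idea above (every batch contains input for each head) can be sketched as a small index batcher. This is only an illustration under assumptions: `batches_covering_all_labels` is a hypothetical helper, `labels` stands for the per-sample "model_label" values, and indices are drawn independently for each batch, so a sample can appear in more than one batch per epoch.

```python
import random
from collections import defaultdict


def batches_covering_all_labels(labels, batch_size, seed=0):
    """Yield lists of dataset indices; each list holds at least one index per label.

    Hypothetical helper for illustration: `labels[i]` is the head label of sample i.
    """
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)
    if batch_size < len(by_label):
        raise ValueError("batch_size must be at least the number of distinct labels")

    n_batches = len(labels) // batch_size
    for _ in range(n_batches):
        # One guaranteed sample per head label.
        batch = [rng.choice(indices) for indices in by_label.values()]
        chosen = set(batch)
        # Fill the rest of the batch with other samples, without duplicates.
        rest = [i for i in range(len(labels)) if i not in chosen]
        batch += rng.sample(rest, batch_size - len(batch))
        yield batch


if __name__ == "__main__":
    demo_labels = ["cars"] * 10 + ["signs"] * 2
    for batch in batches_covering_all_labels(demo_labels, batch_size=4):
        print(batch, sorted({demo_labels[i] for i in batch}))
```

The smallest usable batch size equals the number of heads, which is exactly the resource concern raised above.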

A similar question has been asked before: Process got stuck when set find_unused_parameters=True in DDP - #3 by oliver_ss, and the solution seems similar to what I have done by faking a forward pass. However, it remains to be seen whether this approach is sound.

The longer description is that I have tried to use a single pretrained backbone (ResNet) and multiple TorchVision detection heads (Faster-RCNN) to detect different objects. Given that there is no joint dataset for these objects, one naive approach is to use multiple heads. I have managed to get the “combined model” to train without the distributed API but it would be desirable to get it to run on a machine with 4 GPUs that I have available.

Each head gives a FasterRcnn loss dict which we reduce with the approach given by: vision/references/detection at master · pytorch/vision · GitHub. But as said above in the short description: the next forward pass will encounter a lot of “hanging” parameters when the filtering “fails”.
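As a rough sketch of how the per-head loss dicts could be combined into one scalar for `backward()`: the helper name `combine_head_losses` and the head names in the usage example are made up for illustration, and the cross-process averaging done for logging in the TorchVision reference code is left out here.

```python
import torch


def combine_head_losses(per_head_losses):
    """Flatten {head: {loss_name: tensor}} and sum everything into one scalar.

    Hypothetical helper. Raises if no head produced a loss, which is the
    situation that causes trouble in the distributed setting.
    """
    flat = {
        f"{head}/{name}": value
        for head, losses in per_head_losses.items()
        for name, value in losses.items()
    }
    if not flat:
        raise ValueError("no head produced a loss for this batch")
    total = sum(flat.values())
    return total, flat


if __name__ == "__main__":
    losses = {
        "cars": {"loss_classifier": torch.tensor(1.0), "loss_box_reg": torch.tensor(2.0)},
        "signs": {"loss_classifier": torch.tensor(0.5)},
    }
    total, flat = combine_head_losses(losses)
    print(total, sorted(flat))
```

Prefixing each entry with the head name keeps the Faster R-CNN loss names from colliding across heads.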

Using multiple TorchVision detection heads might not be the best approach, and I would gladly take hints or examples of other approaches, since I cannot come up with good search terms… However, I feel that this naive approach is a good starting point.

For completeness I have included most of the relevant model code below. At this point it is not possible to include everything but it might give some hint of what I am trying to do.

Below is the code for setting up the combined model:

import torch

class MultiModel:
    """ An attempt at a modular approach where a single backbone supplies features for multiple
        task-specific networks. The FasterRcnn network calls the backbone network's forward function within its
        own forward. If multiple FasterRcnns are used the backbone features should in theory be reusable.
        Linear inference speed in the number of original networks might be reduced to
        constant backbone + linear task specific network inference speed.
    """

    def __init__(self, backbone, device, distributed_device_ids=None):
        self.backbone = backbone
        self.model_list = []
        self.device = device
        self.distributed_device_ids = distributed_device_ids
        self.model_modules = [] # necessary to keep track of "undistributed models"

    def add_frcnn_model(self, frcnn_class, num_classes, model_label, frcnn_kwargs):
        if frcnn_kwargs is None:
            frcnn_kwargs = {}
        # This call was cut off in the post; it is closed here with the remaining
        # constructor argument of FasterRCNNModel (an assumption).
        model = frcnn_class(self.backbone, num_classes, model_label, frcnn_kwargs)
        if self.distributed_device_ids:
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=self.distributed_device_ids, find_unused_parameters=False)


where each head is currently only given by a FasterRCNNModel:

from typing import List

import numpy as np
import torch
from torch import nn
from torchvision.models.detection.faster_rcnn import FasterRCNN
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, load_state_dict_from_url, model_urls, overwrite_eps
from torchvision.transforms.transforms import Normalize

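def custom_fasterrcnn_resnet50_fpn(backbone, pretrained=False, progress=True,
                            num_classes=3, **kwargs):
    """ Custom FasterRcnn model with provided backbone """
    # we cannot overwrite num_classes here, we need to load the weights before changing the head
    model = FasterRCNN(backbone, num_classes=91, **kwargs)
    if pretrained:
        # The rest of this call and the load below were cut off in the post; they are
        # completed here following torchvision's own fasterrcnn_resnet50_fpn (an assumption).
        state_dict = load_state_dict_from_url(model_urls['fasterrcnn_resnet50_fpn_coco'],
                                              progress=progress)
        model.load_state_dict(state_dict)
        overwrite_eps(model, 0.0)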

    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model

class FasterRCNNModel(nn.Module):
    """ Simplifies the construction of a Faster RCNN model from the torchvision detection model.
        Seamless handling of input filtering so that only applicable input is processed.
    """
    def __init__(self, backbone, num_classes, model_label, frcnn_kwargs=None):
        super(FasterRCNNModel, self).__init__()
        # avoid a shared mutable default argument
        frcnn_kwargs = frcnn_kwargs or {}
        self.backbone = backbone
        self.num_classes = num_classes
        self.model_label = model_label
        self.frcnn_kwargs = frcnn_kwargs

        self.model = custom_fasterrcnn_resnet50_fpn(backbone, pretrained=True, progress=True,
                                                    num_classes=num_classes, **frcnn_kwargs)

    def forward(self, images: List, targets: List):
        """ Returns losses or detections depending on train or eval mode.
            Filters out the usable image, target pairs based on "model_label" in target
        """
        if self.model.training:
            mask = self._get_mask(targets)
        else:
            mask = np.ones(len(images))
        n_inputs = len(images)
        masked_images = [images[i] for i in range(n_inputs) if mask[i]]
        masked_targets = [targets[i] for i in range(n_inputs) if mask[i]]

        n_masked_inputs = len(masked_images)
        if n_masked_inputs > 0:
            out = self.model(masked_images, masked_targets)
            return out
        else:
            # In a distributed setting, an empty loss dict results in a silent crash if we try to
            # reduce the dicts across processes. However, tensors that are not from a computation
            # graph lead to a crash at .backward() as there are unfinished reductions. How to solve?
            # For now, we introduce a dummy pass through our model which we then zero out.
            # All grad_fns should be there.
            device = torch.cuda.current_device()
            images = [torch.zeros((3, 256, 256), dtype=torch.float32, device=device)]
            boxes = torch.zeros((0, 4), dtype=torch.float32, device=device)
            labels = torch.zeros((0,), dtype=torch.int64, device=device)
            targets = [{"boxes": boxes, "labels": labels}]
            out = self.model(images, targets)
            # scale out of place so the autograd graph of each loss stays intact
            out = {k: v * 0 for k, v in out.items()}

            return out

Thank you for any discussion regarding this. I can also provide more code if needed but I think it will be difficult to provide a minimal working example. It is more of a question if this is the correct approach.

@zalador what do you mean if this approach is sound and good?

I meant to ask whether faking a forward pass is the correct approach and whether it would work. I am not 100 % sure that just setting the gradients to 0 leads to correct training.
I also wonder whether there are better approaches; this feels like a band-aid for DDP. Without DDP you do not need a faked forward pass.
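One way to look at the "is zeroing correct?" question is to check what a zero-scaled loss term does to an idle head. The sketch below uses tiny linear layers as stand-ins for the backbone and heads (all names are illustrative, not the code from this thread) and a zero-weighted sum over the idle head's parameters instead of a dummy forward pass. It then checks whether an optimizer step still moves the idle head.

```python
import torch
from torch import nn

torch.manual_seed(0)
backbone = nn.Linear(4, 8)
head_a = nn.Linear(8, 1)   # receives data in this batch
head_b = nn.Linear(8, 1)   # idle: no matching samples in this batch

feats = torch.relu(backbone(torch.randn(5, 4)))
loss = head_a(feats).pow(2).mean()

# Zero-weighted term that touches every parameter of the idle head, so each
# of them gets a gradient (all zeros) instead of no gradient at all.
loss = loss + 0.0 * sum(p.sum() for p in head_b.parameters())
loss.backward()

idle_grads_are_zero = all(
    p.grad is not None and torch.count_nonzero(p.grad).item() == 0
    for p in head_b.parameters()
)


def moves_idle_head(**sgd_kwargs):
    """Take one SGD step on head_b and report whether any parameter changed."""
    before = [p.detach().clone() for p in head_b.parameters()]
    torch.optim.SGD(head_b.parameters(), lr=0.1, **sgd_kwargs).step()
    return any(not torch.equal(b, p.detach()) for b, p in zip(before, head_b.parameters()))


moved_plain = moves_idle_head()                       # plain SGD
moved_with_decay = moves_idle_head(weight_decay=0.1)  # SGD with weight decay

if __name__ == "__main__":
    print(idle_grads_are_zero, moved_plain, moved_with_decay)
```

With plain SGD a zero gradient leaves the idle head untouched, but weight decay (and likewise momentum or Adam-style running statistics) still updates it, whereas parameters whose gradient is `None` are skipped by the optimizer. So zeroed losses are not always equivalent to the non-DDP run where the head is simply not called.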

If your model has unused parameters, you can set find_unused_parameters=True. If not all output tensors are used to calculate the loss, DDP in PT <= 1.9 does not support that case yet. We have added a fix that supports it, which will be released in PT 1.10; for now you can try this feature in the PT nightly build.
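For reference, a minimal single-process sketch of the `find_unused_parameters=True` setting. It assumes the gloo backend is available and runs on CPU with `world_size=1`, so it only shows the wiring, not real multi-GPU behaviour. `TwoHeadNet` and `run_two_steps` are made-up names, and every tensor returned by `forward` is used in the loss.

```python
import os
import tempfile

import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP


class TwoHeadNet(nn.Module):
    """Toy stand-in: shared backbone, head_b only runs when the batch has data for it."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 8)
        self.head_a = nn.Linear(8, 1)
        self.head_b = nn.Linear(8, 1)

    def forward(self, x, use_b):
        feats = torch.relu(self.backbone(x))
        out = {"a": self.head_a(feats)}
        if use_b:
            out["b"] = self.head_b(feats)
        return out


def run_two_steps():
    # File-based rendezvous so that no environment variables are needed.
    init_file = os.path.join(tempfile.mkdtemp(), "ddp_init")
    dist.init_process_group("gloo", init_method=f"file://{init_file}", rank=0, world_size=1)
    try:
        torch.manual_seed(0)
        model = DDP(TwoHeadNet(), find_unused_parameters=True)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        losses = []
        for use_b in (False, True):  # head_b is skipped in the first iteration
            optimizer.zero_grad()
            out = model(torch.randn(6, 4), use_b)
            loss = sum(v.pow(2).mean() for v in out.values())
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        return losses
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    print(run_two_steps())
```

The flag makes DDP traverse the autograd graph from the forward outputs on every iteration, which adds some overhead compared to the default.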


Thank you for your input. I tried with the latest nightly build but I get version mismatches between PT and TorchVision which I use for my models. I will definitely look out for this in the future.