Hello. I am interested in multi-task learning (MTL) and encountered some difficulties when trying the distributed API with inspiration from the TorchVision examples: vision/references/detection at master · pytorch/vision · GitHub.
The short description is that the distributed forward call fails when one or more task-models (or heads) do not contribute to the total loss of the previous iteration:
Traceback (most recent call last):
  File "src/main.py", line 160, in <module>
    main(args)
  File "src/main.py", line 148, in main
    step = multi_model.train_one_epoch(dataloader_train, optimizer, epoch=epoch, step=step, write_freq=200, writer=writer)
  File "/workspace/src/models/multi_model.py", line 103, in train_one_epoch
    losses = self(images, targets)
  File "/workspace/src/models/multi_model.py", line 87, in __call__
    model_out = model(images, targets)
  File "/home/user/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/user/.local/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 804, in forward
    if grad_enabled and self.reducer._rebuild_buckets():
RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`, and by
making sure all `forward` function outputs participate in calculating loss.
If you already have done the above, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable).
Parameter indices which did not receive grad for rank 2: 16 17 18 19 20 21 22 23 24 25 26 27 28 29
In addition, you can set the environment variable TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL to print out information about which particular parameters did not receive gradient on this rank as part of this error
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 2 (pid: 11861) of binary: /opt/conda/bin/python
ERROR:torch.distributed.elastic.agent.server.local_elastic_agent:[default] Worker group failed
Each head is trained only on a subset of the training data, which is filtered from the batch during training. If the batch does not include any input for one or more heads, those heads do not (and cannot) contribute to the loss. I tried to circumvent this by “faking” a forward call for such heads, but then the loss will potentially not be “good”.
The question is probably two-fold:
- is such an input-filtering approach correct when doing MTL and
- how would you still call the forward of unused heads?
One solution could be to force every batch to include input for each head, but with multiple heads the required batch size might become problematic given the limited resources I have. Another solution would be to calculate the loss in a different way and account for “false” examples correctly, but this probably requires a lot of custom code.
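The first option (every batch contains input for each head) can be sketched without any framework code. This is only a minimal illustration: `task_balanced_batches` and its arguments are hypothetical names, and `sample_tasks[i]` is assumed to be the task label of dataset item `i` (for example the "model_label" stored in its target).

```python
import random
from collections import defaultdict
from typing import Dict, Iterator, List, Sequence


def task_balanced_batches(
    sample_tasks: Sequence[str], batch_size: int, seed: int = 0
) -> Iterator[List[int]]:
    """Yield index batches that contain at least one sample per task."""
    rng = random.Random(seed)

    # Group dataset indices by their task label.
    by_task: Dict[str, List[int]] = defaultdict(list)
    for idx, task in enumerate(sample_tasks):
        by_task[task].append(idx)
    tasks = sorted(by_task)
    if batch_size < len(tasks):
        raise ValueError("batch_size must be >= number of tasks")

    order = list(range(len(sample_tasks)))
    rng.shuffle(order)
    n_free = batch_size - len(tasks)  # slots left after one sample per task
    num_batches = len(order) // batch_size

    for b in range(num_batches):
        # One guaranteed sample per task (drawn with replacement across batches).
        batch = [rng.choice(by_task[t]) for t in tasks]
        # Fill the remaining slots from the shuffled pool; an index may
        # occasionally appear twice in the same batch.
        batch += order[b * n_free:(b + 1) * n_free]
        yield batch


labels = ["a"] * 10 + ["b"] * 2
batches = list(task_balanced_batches(labels, batch_size=4))
```

The resulting list of index batches could presumably be handed to a `torch.utils.data.DataLoader` through its `batch_sampler` argument. The trade-off mentioned above remains: the minimum batch size grows with the number of heads, and rare tasks are oversampled.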
A similar question has been asked before: Process got stuck when set find_unused_parameters=True in DDP - #3 by oliver_ss, and the solution there seems similar to what I have done by faking a forward pass. However, it remains to be seen whether this approach is sound.
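A lighter-weight variant of the faked forward pass is to add a zero-valued term that depends on every trainable parameter, so that autograd produces a (zero) gradient for heads that received no input. The sketch below only demonstrates the autograd behaviour on a single process. `touch_all_parameters` is a hypothetical helper, the `nn.Linear` layers are stand-ins for the backbone and the detection heads, and whether this is sufficient under `DistributedDataParallel` in the setup described here is an assumption that would need to be verified.

```python
import torch
from torch import nn


def touch_all_parameters(loss: torch.Tensor, module: nn.Module) -> torch.Tensor:
    """Add a zero-valued term that depends on every trainable parameter.

    Caveat: if a parameter already contains inf/NaN, 0 * inf is NaN.
    """
    zero = sum(p.sum() for p in module.parameters() if p.requires_grad) * 0.0
    return loss + zero


torch.manual_seed(0)
backbone = nn.Linear(4, 8)  # stand-in for the shared backbone
heads = nn.ModuleDict({"cars": nn.Linear(8, 2), "signs": nn.Linear(8, 3)})
model = nn.ModuleDict({"backbone": backbone, "heads": heads})

x = torch.randn(5, 4)
# Only the "cars" head sees data in this batch.
loss = heads["cars"](backbone(x)).pow(2).mean()
loss = touch_all_parameters(loss, model)
loss.backward()

# The unused "signs" head still receives gradients, and they are all zero.
unused_grads = [p.grad for p in heads["signs"].parameters()]
```

Compared with a dummy forward pass on a blank image, this avoids running the detector on fake data, but the unused heads still take part in the gradient reduction.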
The longer description is that I have tried to use a single pretrained backbone (ResNet) and multiple TorchVision detection heads (Faster-RCNN) to detect different objects. Given that there is no joint dataset for these objects, one naive approach is to use multiple heads. I have managed to get the “combined model” to train without the distributed API but it would be desirable to get it to run on a machine with 4 GPUs that I have available.
Each head gives a FasterRcnn loss dict which we reduce with the approach given by: vision/references/detection at master · pytorch/vision · GitHub. But as said in the short description above, the next forward pass encounters a lot of “hanging” parameters whenever the filtering leaves a head without any input.
This approach with multiple TorchVision detection heads might not be the best one, and I would gladly take hints or examples of other approaches as well, since I cannot come up with good search terms… However, I feel that this naive approach is a good starting point.
For completeness I have included most of the relevant model code below. At this point it is not possible to include everything but it might give some hint of what I am trying to do.
Below is the code for setting up the combined model:
import torch


class MultiModel:
    """An attempt at a modular approach where a single backbone supplies features for multiple
    task-specific networks. The FasterRcnn network calls the backbone network's forward function within its
    own forward. If multiple FasterRcnns are used, the backbone features should in theory be reusable.
    Linear inference speed in the number of original networks might be reduced to
    constant backbone + linear task-specific network inference speed.
    """

    def __init__(self, backbone, device, distributed_device_ids=None):
        self.backbone = backbone
        self.model_list = []
        self.device = device
        self.distributed_device_ids = distributed_device_ids
        self.model_modules = []  # necessary to keep track of "undistributed models"

    def add_frcnn_model(self, frcnn_class, num_classes, model_label, frcnn_kwargs):
        if frcnn_kwargs is None:
            frcnn_kwargs = {}
        model = frcnn_class(self.backbone, num_classes, model_label,
                            frcnn_kwargs=frcnn_kwargs)
        model.to(self.device)
        if self.distributed_device_ids:
            print(self.distributed_device_ids)
            model = torch.nn.parallel.DistributedDataParallel(
                model, self.distributed_device_ids, find_unused_parameters=False)
            # keep a handle on the unwrapped module (.module exists only on the DDP wrapper)
            self.model_modules.append(model.module)
        self.model_list.append(model)
where each head is currently given by a FasterRCNNModel:
from typing import List

import numpy as np
import torch
from torch import nn
from torchvision.models.detection.faster_rcnn import FasterRCNN
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, load_state_dict_from_url, model_urls, overwrite_eps
from torchvision.transforms.transforms import Normalize


def custom_fasterrcnn_resnet50_fpn(backbone, pretrained=False, progress=True,
                                   num_classes=3, **kwargs):
    """Custom FasterRcnn model with provided backbone."""
    # we cannot overwrite num_classes here, we need to load the weights before changing the head
    model = FasterRCNN(backbone, num_classes=91, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls['fasterrcnn_resnet50_fpn_coco'],
                                              progress=progress)
        model.load_state_dict(state_dict)
        overwrite_eps(model, 0.0)
    # replace the COCO box predictor with one sized for our number of classes
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model
class FasterRCNNModel(nn.Module):
    """Simplifies the construction of a Faster RCNN model from the torchvision detection model
    module.
    Seamless handling of input filtering so that only applicable input is processed.
    """

    def __init__(self, backbone, num_classes, model_label, frcnn_kwargs=None):
        super(FasterRCNNModel, self).__init__()
        # avoid a shared mutable default argument
        frcnn_kwargs = {} if frcnn_kwargs is None else frcnn_kwargs
        self.backbone = backbone
        self.num_classes = num_classes
        self.model_label = model_label
        self.frcnn_kwargs = frcnn_kwargs
        self.model = custom_fasterrcnn_resnet50_fpn(backbone, pretrained=True, progress=True,
                                                    num_classes=num_classes, **frcnn_kwargs)
    def forward(self, images: List, targets: List):
        """Returns losses or detections depending on train or eval mode.
        Filters out the usable (image, target) pairs based on "model_label" in target.
        """
        if self.model.training:
            mask = self._get_mask(targets)
        else:
            mask = np.ones(len(images))

        n_inputs = len(images)
        masked_images = [images[i] for i in range(n_inputs) if mask[i]]
        masked_targets = [targets[i] for i in range(n_inputs) if mask[i]]
        n_masked_inputs = len(masked_images)

        if n_masked_inputs > 0:
            out = self.model(masked_images, masked_targets)
            return out
        else:
            # In a distributed setting, an empty loss dict results in a silent crash if we try to
            # reduce the dicts across processes. However, tensors that are not from a computation
            # graph lead to a crash at .backward() as there are unfinished reductions. How to solve?
            # For now, we introduce a dummy pass through our model which we then zero out.
            # All grad_fns should be there.
            device = torch.cuda.current_device()
            images = [torch.zeros((3, 256, 256), dtype=torch.float32, device=device)]
            boxes = torch.zeros((0, 4), dtype=torch.float32, device=device)
            labels = torch.zeros((0,), dtype=torch.int64, device=device)
            targets = [{"boxes": boxes, "labels": labels}]
            out = self.model(images, targets)
            # zero the losses out-of-place, avoiding in-place edits of graph tensors
            return {k: v * 0 for k, v in out.items()}
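As a possible alternative to wrapping every head in its own `DistributedDataParallel` instance (each of which also contains the shared backbone), the backbone and all heads could live in one container module that is wrapped once. The following is only a structural sketch: `SharedBackboneMultiHead` is a hypothetical class, the `nn.Linear` layers stand in for the real backbone and Faster R-CNN heads, and the commented-out DDP call with `find_unused_parameters=True` is an untested assumption about how it would be used here.

```python
from typing import Dict, Iterable

import torch
from torch import nn


class SharedBackboneMultiHead(nn.Module):
    """One container holding the shared backbone and all task heads."""

    def __init__(self, backbone: nn.Module, heads: Dict[str, nn.Module]):
        super().__init__()
        self.backbone = backbone
        self.heads = nn.ModuleDict(heads)  # registers every head's parameters

    def forward(self, x: torch.Tensor, active: Iterable[str]) -> Dict[str, torch.Tensor]:
        # The backbone runs once; only heads with input in this batch are evaluated.
        features = self.backbone(x)
        return {name: self.heads[name](features) for name in active}


model = SharedBackboneMultiHead(
    nn.Linear(4, 8), {"cars": nn.Linear(8, 2), "signs": nn.Linear(8, 3)}
)
# In a distributed run the whole container would be wrapped once, e.g.:
# model = torch.nn.parallel.DistributedDataParallel(
#     model, device_ids=[local_rank], find_unused_parameters=True)

out = model(torch.randn(5, 4), active=["cars"])
n_unique_params = len(list(model.parameters()))
```

With a single wrapper the backbone parameters are registered with only one reducer, and heads that are skipped in an iteration are the case `find_unused_parameters=True` is meant to handle, at the cost of an extra graph traversal per iteration.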
Thank you for any discussion regarding this. I can also provide more code if needed, but I think it will be difficult to provide a minimal working example. It is more a question of whether this is the correct approach.