nn.DataParallel degrades the prediction performance

nn.DataParallel degrades the prediction performance when trying to localize text in images.
The image below shows the output when I comment out nn.DataParallel; it gives me the text files just fine.
Good Detection

However, when I parallelize the model, the performance degrades and it only predicts tiny boxes below the image.
Bad Detection

Commenting out the third line in the snippet below (the parallelize call on self.model) is what makes the difference.

        self.model = BasicModel(args)
        # for loading models
        self.model = parallelize(self.model, distributed, local_rank)
        self.criterion = SegDetectorLossBuilder(
            args['loss_class'], *args.get('loss_args', []), **args.get('loss_kwargs', {})).build()
        self.criterion = parallelize(self.criterion, distributed, local_rank)
        self.device = device
        self.to(self.device)

The parallelize function looks like the following:

import torch.nn as nn

def parallelize(model, distributed, local_rank):
    if distributed:
        return nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            # output_device expects an int or torch.device, not a list
            output_device=local_rank,
            find_unused_parameters=True)
    else:
        return nn.DataParallel(model)

Something similar also happens when I train my model on a GPU and try to use it on a CPU, but I am not sure why.
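
(For the GPU-to-CPU case, the usual pattern is to remap the checkpoint's tensors when loading; a minimal sketch, assuming a hypothetical checkpoint path and a stand-in model:)

import torch
import torch.nn as nn

model = nn.Linear(10, 2)  # hypothetical stand-in for the trained network

# map_location remaps CUDA tensors in the checkpoint onto the CPU, so a
# GPU-trained checkpoint can be loaded on a machine without a GPU.
state_dict = torch.load('checkpoint.pth', map_location='cpu')
model.load_state_dict(state_dict)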

Are you using model.eval() before executing the prediction step?
If not, note that the prediction results could depend on the batch size, which might differ between the single-GPU and data-parallel setups, e.g. due to batchnorm layers.
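
As a minimal sketch of why this matters (the model here is hypothetical): batchnorm in training mode normalizes with the statistics of the current batch, while eval mode uses the running statistics, which makes predictions independent of batch composition.

import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))  # hypothetical model

model.eval()  # batchnorm now uses running stats instead of per-batch stats
with torch.no_grad():
    out = model(torch.randn(1, 3, 32, 32))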

Yes, I use model.eval().

I solved the issue. The problem was that the model had been trained with DataParallel, so when I tried to load the state dictionary there was a mismatch between its keys and those of the network, even though I had set strict=False.

Once I removed the module. prefix from the state_dict keys, the issue was solved.
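
For reference, a minimal sketch of that fix, assuming a checkpoint saved from a DataParallel-wrapped model (the path and network are placeholders):

import torch
import torch.nn as nn

model = nn.Linear(10, 2)  # placeholder for the unwrapped network

# A checkpoint saved from a DataParallel-wrapped model has keys like
# 'module.weight', 'module.bias', ...
state_dict = torch.load('checkpoint.pth', map_location='cpu')

# Strip the 'module.' prefix so the keys match the unwrapped model.
state_dict = {
    (k[len('module.'):] if k.startswith('module.') else k): v
    for k, v in state_dict.items()
}
model.load_state_dict(state_dict)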

Good to hear you’ve figured it out!

I would not recommend using strict=False if you are running into unexpected mismatches while loading the state_dict, as this will silently skip the mismatched keys.
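
To diagnose such mismatches explicitly rather than silence them, the value returned by load_state_dict can be inspected; a minimal sketch with a placeholder model and checkpoint (strict=False is used here only to surface the report, not to paper over it):

import torch
import torch.nn as nn

model = nn.Linear(10, 2)  # placeholder model
state_dict = torch.load('checkpoint.pth', map_location='cpu')

# load_state_dict returns the keys that failed to match, which makes the
# mismatch visible instead of silently skipping those parameters.
result = model.load_state_dict(state_dict, strict=False)
print('missing keys:', result.missing_keys)
print('unexpected keys:', result.unexpected_keys)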
