How to parallelize across GPUs when finetuning

In my finetuning model, I want to parallelize the model across multiple GPUs. My code is shown below:

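import torch.nn as nn

class FinetuneModel(nn.Module):
    # ngpu: number of GPUs to use (originally defaulted to opt.gpuids from the script's options)
    def __init__(self, pretrained_model, ngpu=0):
        super(FinetuneModel, self).__init__()  # must run before submodules are assigned
        self.ngpu = ngpu
        self.features = pretrained_model
        self.classifier = nn.Sequential(
            nn.Linear(512 * 4 * 4, 2048),
            # ... (remaining classifier layers not shown in the original post)
        )

    def forward(self, x):
        gpuids = None  # None lets data_parallel use all visible GPUs
        if self.ngpu:
            gpuids = list(range(self.ngpu))
        # self.features has already implemented data parallelism
        features = self.features(x)
        return nn.parallel.data_parallel(self.classifier, features, device_ids=gpuids)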

As far as I know, when doing

features = self.features(x)  # self.features.forward has already implemented data parallelism
score = nn.parallel.data_parallel(self.classifier, features, device_ids=gpuids)

PyTorch first scatters the batch across GPU0 and GPU1; after self.features runs, it gathers the result back onto GPU0. When self.classifier runs, PyTorch scatters the data across the GPUs again.
Is there an idiomatic PyTorch way to avoid this extra copy, something like

score = nn.parallel.data_parallel([self.features, self.classifier], features, device_ids=gpuids)

which would scatter the data only once?

Maybe just wrap the features with nn.DataParallel like this?
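A minimal sketch of that suggestion, assuming a toy backbone and head in place of the real pretrained model (the class name `FeaturesParallelModel` and all layer sizes are hypothetical): only the convolutional part is wrapped in `nn.DataParallel`, while the classifier stays on a single device.

```python
import torch
import torch.nn as nn


class FeaturesParallelModel(nn.Module):
    """Hypothetical model: conv features run data-parallel, FC classifier on one device."""

    def __init__(self, features: nn.Module, classifier: nn.Module):
        super().__init__()
        # nn.DataParallel scatters the batch, runs a replica of `features`
        # on every visible GPU, and gathers the outputs on device_ids[0].
        # With no GPU available it simply calls the wrapped module.
        self.features = nn.DataParallel(features)
        # The classifier is not wrapped, so it runs once, on the device
        # holding the gathered features and its own parameters.
        self.classifier = classifier

    def forward(self, x):
        f = self.features(x)
        f = torch.flatten(f, 1)  # (N, C, H, W) -> (N, C*H*W)
        return self.classifier(f)


if __name__ == "__main__":
    # Toy stand-ins for a pretrained backbone and a new head (assumed sizes).
    features = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(4))
    classifier = nn.Sequential(nn.Linear(8 * 4 * 4, 10))
    model = FeaturesParallelModel(features, classifier)
    if torch.cuda.is_available():
        model = model.cuda()  # DataParallel expects parameters on device_ids[0]
    x = torch.randn(6, 3, 32, 32)
    print(model(x).shape)
```

This is the same pattern the question asks about: the batch is split for the features, then gathered once so the classifier runs on a single GPU.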

Nice, thanks. I'll try it later.
Another question: why is only model.features parallelized here, and not the whole model?

Because AlexNet and VGG contain lots of parameters in their FC layers, replicating and synchronizing those parameters across GPUs has a large overhead. It's faster to compute the FC layers on a single GPU.
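To put a number on "lots of parameters", here is a back-of-the-envelope count for VGG-16, assuming the standard torchvision layout (13 3x3 conv layers, a 7x7x512 feature map for 224x224 input, and a 4096-4096-1000 FC head). The model in the question uses a different head (512 * 4 * 4 inputs), so its exact numbers would differ. Since `nn.DataParallel` copies every parameter of the wrapped module to each GPU on every forward pass, the FC layers would dominate that transfer.

```python
# Parameter counts for VGG-16 (configuration "D"), computed by hand.
# "M" marks a max-pool layer, which has no parameters.
VGG16_CONV = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
              512, 512, 512, "M", 512, 512, 512, "M"]


def conv_params(cfg, in_channels=3, k=3):
    """Weights + biases of the k x k conv layers described by cfg."""
    total = 0
    for v in cfg:
        if v == "M":
            continue  # pooling layers are parameter-free
        total += in_channels * v * k * k + v
        in_channels = v
    return total


def linear_params(sizes):
    """Weights + biases of a stack of fully connected layers."""
    return sum(i * o + o for i, o in zip(sizes, sizes[1:]))


conv = conv_params(VGG16_CONV)
fc = linear_params([512 * 7 * 7, 4096, 4096, 1000])
print(f"conv: {conv:,}  fc: {fc:,}  fc share: {fc / (conv + fc):.1%}")
```

The conv layers, which do most of the arithmetic, hold about 14.7M parameters, while the FC head holds about 123.6M, roughly 89% of the total.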


I see. Great, thanks.

A Pythonic way to apply data parallelism to a sequence of modules is to group them in a container and use data parallel on that container. You can simply remove the data-parallel code from features. If you really need to do it differently, torch.nn.parallel contains gather, scatter, replicate and parallel_apply. Just keep in mind that they're not documented right now and might still change.
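A sketch of both options. `manual_data_parallel` is a hypothetical helper that chains `scatter`, `replicate`, `parallel_apply` and `gather` from `torch.nn.parallel`, roughly mirroring what `nn.parallel.data_parallel` does for a single tensor input; as noted above, these primitives may change, so check the signatures against your PyTorch version. The toy `features`/`classifier` modules are assumptions standing in for the real ones. Because they are grouped into one `nn.Sequential`, the batch is scattered once and gathered once.

```python
import torch
import torch.nn as nn
from torch.nn.parallel import gather, parallel_apply, replicate, scatter


def manual_data_parallel(module, x, device_ids=None, output_device=None):
    """One data-parallel forward pass: scatter the batch once, gather once.

    Assumes `module` already lives on cuda:device_ids[0] when CUDA is used.
    """
    if not torch.cuda.is_available():
        return module(x)  # CPU fallback: nothing to parallelize
    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))
    if output_device is None:
        output_device = device_ids[0]
    inputs = scatter(x, device_ids)  # split the batch along dim 0
    device_ids = device_ids[: len(inputs)]  # a small batch may give fewer chunks
    replicas = replicate(module, device_ids)  # copy the parameters to each GPU
    outputs = parallel_apply(replicas, inputs, devices=device_ids)
    return gather(outputs, output_device)  # concatenate results on one device


def build_model():
    # Toy stand-ins (hypothetical sizes) for the pretrained features and
    # the new head, grouped into a single container.
    features = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(4)
    )
    classifier = nn.Sequential(nn.Flatten(), nn.Linear(8 * 4 * 4, 10))
    model = nn.Sequential(features, classifier)
    return model.cuda() if torch.cuda.is_available() else model


if __name__ == "__main__":
    model = build_model()
    x = torch.randn(6, 3, 32, 32)
    print(manual_data_parallel(model, x).shape)
    # The usual route: wrap the whole container once.
    print(nn.DataParallel(model)(x).shape)
```

In practice, wrapping the container with `nn.DataParallel` (the last line) is simpler; the manual version only pays off when one of the steps needs customizing, for example where the outputs are gathered.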