Does PyTorch Support Backward From DataParallel to DataParallel?

For example:

gpus = [0, 1]
model_1 = DataParallel(model_1, device_ids=gpus)  # e.g. a ResNet backbone without the classifier
model_2 = DataParallel(model_2, device_ids=gpus)  # a classifier containing dropout and a linear layer
criterion = nn.CrossEntropyLoss()

opt.zero_grad()  # SGD optimizer
feat = model_1(x)
print("forward model_1")
prob = model_2(feat)
print("forward model_2")
loss = criterion(prob, y)
print("forward loss")
loss.backward()
print("backward")
opt.step()

This program works fine on my toy dataset, which has only 200+ classes.
But when I switch to my large dataset, which has 2 million classes, it seems to be stuck at the backward step. The printed output is:

forward model_1
forward model_2
forward loss

Additional information:
nvidia-smi:

  • GPU-Util: 100%
  • Memory-Usage: about 15%

top:

  • %CPU: 30%
  • %MEM: 40%

I don’t see why you wouldn’t just wrap model_1 and model_2 in a single object, i.e.,

class full_model(nn.Module):
    def __init__(self, model1, model2):
        super().__init__()
        self.model1 = model1
        self.model2 = model2
        ...
    def forward(self, x):
        return self.model2(self.model1(x))

which can then be wrapped in a single DataParallel object. Depending on how you instantiated your opt, your current approach might not be working as intended, since the optimizer needs to be initialised with the model's parameters. That is why you should write it as follows:

model = full_model(...)
opt = torch.optim.Adam(model.parameters(), lr=lr)
model_dp = DataParallel(model, device_ids=gpus)

This way your parameters will be correctly updated when calling opt.step(), since they likely weren't being updated before.
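
To make that concrete, here is a minimal sketch of what the training step from the question could look like with the wrapped model. This assumes model_1 and model_2 are the unwrapped sub-models, that full_model takes them as constructor arguments (as sketched above), and that x, y, and gpus are defined as in the question:

model = full_model(model_1, model_2)
opt = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
model_dp = DataParallel(model, device_ids=gpus)

opt.zero_grad()
prob = model_dp(x)                      # one forward pass through both sub-models
loss = nn.CrossEntropyLoss()(prob, y)   # same loss as in the question
loss.backward()
opt.step()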

@Jamie_Donnelly, thanks for your reply.

I am building a modular project. When someone uses my project, there may be other operators between model_1 and model_2.

Do you mean the DataParallel wrapping must come after creating the optimizer? I will try it, thank you again. :)

I think it will still work if you initialise the optimizer after wrapping the model in DataParallel, but personally I think it's better practice to initialise the optimizer on the base model itself, i.e., using model.parameters() in the optim constructor rather than DataParallel(model).parameters().
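
For what it's worth, both orderings end up optimising the same tensors, because DataParallel just holds a reference to the original module. A quick sketch, with a hypothetical tiny model standing in for the real one:

import torch
import torch.nn as nn
from torch.nn import DataParallel

model = nn.Linear(10, 5)                        # stand-in for the real full_model
opt = torch.optim.SGD(model.parameters(), lr=0.1)
model_dp = DataParallel(model, device_ids=[0, 1])

# DataParallel keeps the original module as model_dp.module, so its parameters()
# yields the exact same Parameter objects (only the names gain a "module." prefix).
assert set(map(id, model.parameters())) == set(map(id, model_dp.parameters()))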

Just make sure that your optim is initialised with the full set of model parameters, which is the point of wrapping model_1 and model_2 into a single model.
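
If you do need to keep model_1 and model_2 as separate DataParallel modules (e.g. because of extra operators in between, as mentioned above), one way to cover all parameters with a single optimizer is to chain the two parameter iterators. This is just a sketch, not something from the original thread:

import itertools
import torch

# model_1 and model_2 here are the two sub-models (wrapped or unwrapped:
# DataParallel(m).parameters() yields the same tensors as m.parameters()).
opt = torch.optim.SGD(
    itertools.chain(model_1.parameters(), model_2.parameters()),
    lr=0.01,
    momentum=0.9,
)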