Multi-GPU model with multiple loss functions

Hi everybody

I’m getting familiar with training multi-GPU models in PyTorch. I found this official tutorial on best practices for multi-GPU training and adapted its code so that the model returns two predictions/outputs, each of which is then used in its own loss. The model is split across two GPUs, so the two predictions live on different devices.
The code looks as follows:

import torch
import torch.nn as nn
import torch.optim as optim


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net0 = torch.nn.Linear(10, 10).to('cuda:0')
        self.relu = torch.nn.ReLU()
        self.net1 = torch.nn.Linear(10, 5).to('cuda:1')

    def forward(self, x):
        x_out0 = self.relu(self.net0(x.to('cuda:0')))
        x_out1 = self.net1(x_out0.to('cuda:1'))
        return x_out1, x_out0


if __name__ == '__main__':
    
    model = ToyModel()
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs1, outputs0 = model(torch.randn(20, 10))
    labels0 = torch.randn(20, 10).to('cuda:0')
    labels1 = torch.randn(20, 5).to('cuda:1')
    loss = loss_fn(outputs1, labels1) + loss_fn(outputs0, labels0) # DOES NOT WORK
    #loss = loss_fn(outputs0, labels0) # DOES WORK
    #loss = loss_fn(outputs1, labels1) # DOES WORK
    loss.backward()
    optimizer.step()

I’m getting the following error message in the backward pass of the loss function (loss.backward()):

RuntimeError: Function AddBackward0 returned an invalid gradient at index 1 - expected device cuda:0 but got cuda:1

Could somebody help me figure out how to solve this issue? Note that labels0 and outputs0 are on GPU 'cuda:0', whilst labels1 and outputs1 are on GPU 'cuda:1'. Is this the right way to accomplish this task at all, or are there better practices? I have a big model spread over 2+ GPUs and would also like to compute intermediate losses.

You would have to push the calculated losses to the same device before summing them.
This should work:

loss = loss_fn(outputs1, labels1) + loss_fn(outputs0, labels0).to('cuda:1')
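Putting the pieces together, a minimal sketch of the corrected training step might look like the following. It assumes two GPUs when `torch.cuda.device_count() >= 2` and otherwise falls back to the CPU for both halves so it still runs anywhere. `sum_losses` is a hypothetical helper (not a PyTorch API) that moves every loss to one target device before summing, which also covers the case of several intermediate losses spread over 2+ GPUs.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Assumption: use two GPUs when available, otherwise put both halves on the
# CPU so the sketch still runs on a machine without two GPUs.
if torch.cuda.device_count() >= 2:
    dev0, dev1 = torch.device('cuda:0'), torch.device('cuda:1')
else:
    dev0 = dev1 = torch.device('cpu')


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net0 = nn.Linear(10, 10).to(dev0)  # first stage on dev0
        self.relu = nn.ReLU()
        self.net1 = nn.Linear(10, 5).to(dev1)   # second stage on dev1

    def forward(self, x):
        x_out0 = self.relu(self.net0(x.to(dev0)))
        x_out1 = self.net1(x_out0.to(dev1))
        return x_out1, x_out0


def sum_losses(losses, device):
    """Hypothetical helper: move each scalar loss to `device`, then add them.

    `.to()` is differentiable, so backward() still routes each gradient back
    to the device on which that loss was originally computed.
    """
    return sum(loss.to(device) for loss in losses)


torch.manual_seed(0)
model = ToyModel()
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)

optimizer.zero_grad()
outputs1, outputs0 = model(torch.randn(20, 10))
labels0 = torch.randn(20, 10, device=dev0)
labels1 = torch.randn(20, 5, device=dev1)

# Final loss (on dev1) plus intermediate loss (on dev0), summed on dev1.
loss = sum_losses([loss_fn(outputs1, labels1), loss_fn(outputs0, labels0)], dev1)
loss.backward()
optimizer.step()
print(loss.device, loss.item())
```

With more stages, the same helper takes the full list of intermediate losses; the target device is an arbitrary choice.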

Great, thank you, it worked! Which GPU device I use for backprop does not affect the results or the backprop itself, right? I ask because the whole composed model is cnn1(cnn0(x)), i.e. the outermost cnn1 is on device cuda:1 and the innermost cnn0 is on device cuda:0. So, for a proper backward flow, should I put all the calculated losses on the device associated with the outermost CNN (cuda:1)? The following worked as well:

loss = loss_fn(outputs1, labels1).to('cuda:0') + loss_fn(outputs0, labels0)

The to() operation is differentiable, and Autograd will backpropagate “through” this operation to the corresponding devices, i.e. the gradient calculation for the part of the model on GPU0 will be performed on GPU0, and the same applies to GPU1.
Your approach would thus also work, and I wouldn’t expect any difference in the results or the performance.
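One way to check this is to compute the gradients twice, once summing the losses on the first device and once on the second, and compare them. The sketch below assumes the same two-layer setup as in the question, with a CPU fallback when fewer than two GPUs are present; the helper `grads_when_summed_on` is hypothetical. The gradients should match, and each one stays on the device of its own parameter.

```python
import torch
import torch.nn as nn

# Assumption: two GPUs if available, else the CPU stands in for both devices.
if torch.cuda.device_count() >= 2:
    dev0, dev1 = torch.device('cuda:0'), torch.device('cuda:1')
else:
    dev0 = dev1 = torch.device('cpu')

torch.manual_seed(0)
net0 = nn.Linear(10, 10).to(dev0)
net1 = nn.Linear(10, 5).to(dev1)
x = torch.randn(20, 10, device=dev0)
labels0 = torch.randn(20, 10, device=dev0)
labels1 = torch.randn(20, 5, device=dev1)
loss_fn = nn.MSELoss()


def grads_when_summed_on(device):
    """Hypothetical helper: sum both losses on `device`, backprop, return grads."""
    net0.zero_grad()
    net1.zero_grad()
    out0 = torch.relu(net0(x))
    out1 = net1(out0.to(dev1))
    loss = loss_fn(out1, labels1).to(device) + loss_fn(out0, labels0).to(device)
    loss.backward()
    # Each gradient is computed and stored on its parameter's own device.
    return [p.grad.clone() for p in (*net0.parameters(), *net1.parameters())]


g_on_dev0 = grads_when_summed_on(dev0)
g_on_dev1 = grads_when_summed_on(dev1)
for a, b in zip(g_on_dev0, g_on_dev1):
    print(a.device, torch.allclose(a, b))
```

Since the backward of `.to()` simply copies the incoming gradient back to the source device, the choice of summation device only changes where the two scalar losses are added, not the gradients themselves.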
