When using DataParallel to wrap my module, do I need to do anything to also parallelize the loss functions?
For example, let’s say that I have a large batch size and large output tensors to compute MSE against a target. This operation would benefit from splitting the batch across multiple GPUs, but I’m not sure if the following code does that:
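import torch.nn as nn

model = MyModule()
model = nn.parallel.DataParallel(model, device_ids=range(args.number_gpus))
output = model(data)  # outputs of all replicas are gathered onto device_ids[0]
criterion = nn.MSELoss()
loss = criterion(output, target)  # computed on that single device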
This is a simplification based on the ImageNet example.
In my case, I have a much bigger custom loss module that includes some calls to a VGG network to estimate perceptual loss, and I’m not sure if I am maximizing performance. I tried computing the loss as part of the forward function in MyModule, but this led to recursion errors during the backward step.
You can wrap the loss function inside a DataParallel too, if you’d like.
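A minimal sketch of that suggestion, assuming a plain nn.MSELoss and random tensors as stand-ins for the real model output and target (the names parallel_criterion, output and target are illustrative only). On a machine without CUDA, nn.DataParallel simply calls the wrapped module, so the sketch also runs on CPU; with several GPUs each replica returns its own scalar loss and the results come back as a vector, hence the .mean().

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-ins for the model output and the target (assumption: batch of 8, 10 features).
output = torch.randn(8, 10, requires_grad=True)
target = torch.randn(8, 10)

# Wrap the loss module itself in DataParallel.
# With multiple GPUs the batch is scattered along dim 0, each replica computes
# its own mean-reduced MSE, and the per-replica scalars are gathered into a vector.
# Without CUDA, DataParallel just calls nn.MSELoss directly.
parallel_criterion = nn.DataParallel(nn.MSELoss())

per_replica_losses = parallel_criterion(output, target)
# Reduce to a scalar before calling backward().
loss = per_replica_losses.mean()
loss.backward()

print(loss.item())
```

Because each replica averages over its own chunk, the mean of the gathered values matches the full-batch MSE only when the chunks have equal size.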
Would that result in one unnecessary scatter/gather?
It would. If you’re worried about that, you can put a single DataParallel around your model + loss function.
But depending on how many parameters you have in your fully connected layer, it might not work out in terms of speed.
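One way to put a single DataParallel around model + loss is a small wrapper module whose forward() returns the loss. This is a sketch under assumptions: ModelWithLoss is a hypothetical name, nn.Linear stands in for MyModule, and nn.MSELoss stands in for the larger perceptual loss. Each replica returns only its own loss, so the large output tensors are never gathered onto one GPU.

```python
import torch
import torch.nn as nn


class ModelWithLoss(nn.Module):
    """Hypothetical wrapper: runs the model and the criterion in one forward()."""

    def __init__(self, model: nn.Module, criterion: nn.Module):
        super().__init__()
        self.model = model
        self.criterion = criterion

    def forward(self, data, target):
        output = self.model(data)
        # Each replica returns only its own scalar loss, so the large output
        # tensor never has to be gathered onto a single GPU.
        return self.criterion(output, target)


torch.manual_seed(0)
model = nn.Linear(16, 4)  # stand-in for MyModule (assumption)
device = "cuda" if torch.cuda.is_available() else "cpu"
wrapped = ModelWithLoss(model, nn.MSELoss()).to(device)
# device_ids defaults to all visible GPUs; without CUDA, DataParallel just calls the module.
parallel = nn.DataParallel(wrapped)

data = torch.randn(32, 16, device=device)
target = torch.randn(32, 4, device=device)

losses = parallel(data, target)  # one loss per GPU with several GPUs, a scalar otherwise
loss = losses.mean()
loss.backward()
print(loss.item(), model.weight.grad.shape)
```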
Hi, if I wrap the loss within forward and, let’s say, run on 2 GPUs, the forward function will return [loss1, loss2]. Should I sum them (loss1 + loss2) and then call backward, or call loss1.backward() and loss2.backward() separately?
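DataParallel will return output = torch.cat([loss1, loss2]), i.e. one loss per GPU, so you can backpropagate through both at once with a gradient of ones: output.backward(torch.ones_like(output))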
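To see what a gradient of ones does without needing two GPUs, here is a sketch that builds the two-element loss vector by hand as a stand-in for what DataParallel gathers from two replicas (the half-batch split and the names w, loss1, loss2 are assumptions for illustration). Backpropagating a gradient of ones through the stacked vector accumulates the same gradients as calling loss1.backward() and loss2.backward() one after another.

```python
import torch

torch.manual_seed(0)
x = torch.randn(8, 3)
y = torch.randn(8)
w = torch.randn(3, requires_grad=True)  # shared parameter, like replicated model weights


def per_gpu_losses():
    # Stand-in for DataParallel: each "replica" sees half the batch and returns a scalar.
    loss1 = ((x[:4] @ w - y[:4]) ** 2).mean()
    loss2 = ((x[4:] @ w - y[4:]) ** 2).mean()
    return loss1, loss2


# Option A: stack the per-replica losses and backpropagate a gradient of ones.
loss1, loss2 = per_gpu_losses()
output = torch.stack([loss1, loss2])  # shape [2], like the gathered DataParallel output
output.backward(torch.ones_like(output))
grad_ones = w.grad.clone()

# Option B: call backward on each loss separately; gradients accumulate in w.grad.
w.grad = None
loss1, loss2 = per_gpu_losses()
loss1.backward()
loss2.backward()
grad_separate = w.grad.clone()

print(torch.allclose(grad_ones, grad_separate))  # True
```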
If output is [loss1, loss2], can I get the final loss as loss = output.sum() and then call loss.backward()?
Why wouldn’t you just take the mean, like most loss functions do on regular batches?
loss = criterion(output, target).mean()
seems to work fine
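Taking the mean instead of the sum changes only the scale of the gradients: with n replicas, output.mean() gives 1/n of the gradients of output.sum(). A sketch with a hand-built two-element vector standing in for the gathered per-GPU losses (an assumption; no GPUs are needed):

```python
import torch

w = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)


def gathered_losses():
    # Stand-in for the gathered per-replica losses of two GPUs, shape [2].
    return torch.stack([(w ** 2).sum(), (2 * w).sum()])


gathered_losses().sum().backward()
grad_sum = w.grad.clone()

w.grad = None
gathered_losses().mean().backward()
grad_mean = w.grad.clone()

n = 2  # number of "replicas"
print(torch.allclose(grad_mean, grad_sum / n))  # True
```

When each replica already averages over an equal-sized chunk, the mean keeps the loss on the same scale as a single-GPU run, while the sum multiplies the gradients by the number of GPUs (for plain SGD, the same as scaling the learning rate by n).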
What does this do? Could you please elaborate?
We are passing a gradient of ones to backward(). Usually, if the loss is a scalar output and you call loss.backward(), it’s implied that it’s loss.backward(torch.ones_like(loss)), i.e. a gradient of 1. Because in this case the loss actually has two elements, output.backward() will raise an error asking for gradients.
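A small sketch of both cases, using a toy parameter w as an illustrative assumption: for a scalar loss, backward() and backward(torch.ones_like(loss)) produce identical gradients, while a two-element "loss" (one entry per GPU) cannot create its gradient implicitly and raises a RuntimeError.

```python
import torch

w = torch.tensor([1.0, -2.0], requires_grad=True)

# Scalar loss: backward() implicitly uses a gradient of 1.0 (same shape as the loss).
(w ** 2).sum().backward()
grad_implicit = w.grad.clone()

# The same call with the gradient of ones written out explicitly.
w.grad = None
loss = (w ** 2).sum()
loss.backward(torch.ones_like(loss))
grad_explicit = w.grad.clone()
print(torch.equal(grad_implicit, grad_explicit))  # True

# Two-element "loss" (one entry per GPU): no implicit gradient is created.
output = w ** 2
error_message = ""
try:
    output.backward()
except RuntimeError as err:
    error_message = str(err)
print(error_message)
```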
I have the same question as the people before:
Can I get the final loss with output.sum() and then call loss.backward()? (I saw some blog posts doing it that way.)
Is that different from what you suggested here?
Same question here: did you find any difference between using sum() and loss.backward(torch.ones(2))?
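The two should give identical gradients: the derivative of output.sum() with respect to each element of output is 1, so sum().backward() feeds exactly a gradient of ones into the rest of the graph. A sketch with a hand-built two-element vector standing in for the gathered per-GPU losses (the toy x and w are assumptions for illustration):

```python
import torch

torch.manual_seed(0)
w = torch.randn(3, requires_grad=True)
x = torch.randn(2, 3)


def gathered_losses():
    # Stand-in for DataParallel's gathered per-GPU losses, shape [2].
    return (x @ w) ** 2


# Variant 1: reduce with sum(), then call backward() on the scalar.
gathered_losses().sum().backward()
grad_sum = w.grad.clone()

# Variant 2: call backward() on the vector with an explicit gradient of ones.
w.grad = None
out = gathered_losses()
out.backward(torch.ones_like(out))
grad_ones = w.grad.clone()

print(torch.allclose(grad_sum, grad_ones))  # True
```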