Custom Loss when using DDP

Hi,

I am trying to implement a custom loss as an nn.Module, and I have to train the model on multiple GPUs. The forward method of my custom loss takes the logits and labels as input. My question is: do I need to gather and concatenate the inputs, i.e. both the labels and the logits, across GPUs inside the forward method before actually calculating the loss?

Below is the basic skeleton of my custom loss.

```python
import torch
import torch.distributed as dist
import torch.nn as nn


def multigpu_gather(input, rank, world_size):
    """
    Gather all the tensors from all the processes and concatenate.
    """
    if world_size == 1:
        return input
    gather_list = [torch.empty_like(input) for _ in range(world_size)]
    dist.all_gather(gather_list, input)
    # dist.all_gather does not backpropagate to the other ranks' tensors, so
    # put the local tensor back to keep this rank's slice in the autograd graph
    gather_list[rank] = input
    output = torch.cat(gather_list, dim=0).contiguous()
    return output


class CustomLoss(nn.Module):
    def __init__(self, ignore_index, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.ignore_index = ignore_index

        # Custom losses (Loss1 and Loss2 are my own nn.Module classes)
        self.loss1 = Loss1()
        self.loss2 = Loss2()

        self.world_size = dist.get_world_size()

    def forward(self, logits, labels, dynamic_w2):
        # Question: do I need to gather in a multi-GPU setting?
        rank = dist.get_rank()
        labels = multigpu_gather(labels, rank, self.world_size)
        logits = multigpu_gather(logits, rank, self.world_size)

        # Loss 1
        loss1 = self.loss1(logits, labels)  # another nn.Module class

        # Loss 2
        loss2 = self.loss2(logits, labels)  # another nn.Module class

        final_loss = loss1 + dynamic_w2 * loss2

        return final_loss.mean()
```

In the training loop, I then call loss.backward().

Could you please kindly help me understand this? Thanks in advance.

Best,
Bhavana

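If you’re using DDP (DistributedDataParallel), it takes care of gradient synchronization for you.
The following training step works unchanged for any number of GPUs; each process runs it on its own shard of the data.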

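```python
def train_one_step():
    opt.zero_grad()

    x, y = get_next_batch()  # this process's local shard of the batch
    y1 = model(x)  # model is wrapped in DistributedDataParallel

    loss = TheLossFunction(y1, y)  # computed on the local batch only
    loss.backward()  # DDP all-reduces (averages) the gradients here

    opt.step()
```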

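You do not need to gather or concatenate the labels or logits. Each GPU computes the loss on its own local batch; during loss.backward(), DDP all-reduces the gradients so that every replica ends up with the same averaged gradients before opt.step(). For a per-sample loss with mean reduction and equal per-GPU batch sizes, this average equals the gradient of the loss over the full global batch.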
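To check the claim that gathering isn't needed, here is a minimal single-process sketch that mimics what DDP does, without GPUs or a process group. It assumes a per-sample loss with mean reduction (nn.CrossEntropyLoss stands in for TheLossFunction, nn.Linear for the model) and equal-sized shards per rank; `simulated_ddp_grads` and `full_batch_grads` are hypothetical helpers written only for this illustration. Each simulated rank computes its loss on its local shard, the per-rank gradients are averaged, and the result is compared with the gradient of the loss over the full batch.

```python
import copy

import torch
import torch.nn as nn


def simulated_ddp_grads(model, x, y, loss_fn, world_size):
    """Mimic DDP in one process: each 'rank' gets an equal shard of the batch,
    computes a purely local loss, and the gradients are averaged across ranks
    (the effect of DDP's gradient all-reduce)."""
    grads = [torch.zeros_like(p) for p in model.parameters()]
    for xs, ys in zip(x.chunk(world_size), y.chunk(world_size)):
        replica = copy.deepcopy(model)  # every rank holds an identical copy
        loss = loss_fn(replica(xs), ys)  # local loss, no gathering
        loss.backward()
        for g, p in zip(grads, replica.parameters()):
            g += p.grad / world_size  # accumulate the average in place
    return grads


def full_batch_grads(model, x, y, loss_fn):
    """Reference: gradient of the loss over the whole (global) batch."""
    replica = copy.deepcopy(model)
    loss_fn(replica(x), y).backward()
    return [p.grad for p in replica.parameters()]


if __name__ == "__main__":
    torch.manual_seed(0)
    model = nn.Linear(4, 3)
    x, y = torch.randn(8, 4), torch.randint(0, 3, (8,))
    loss_fn = nn.CrossEntropyLoss()  # mean reduction over each local batch
    ddp = simulated_ddp_grads(model, x, y, loss_fn, world_size=4)
    ref = full_batch_grads(model, x, y, loss_fn)
    print(all(torch.allclose(a, b, atol=1e-6) for a, b in zip(ddp, ref)))  # True
```

If the per-rank batch sizes differ (e.g. an uneven last batch), or the loss couples samples across the whole batch, the two results no longer coincide exactly.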

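You only need to gather values across GPUs if you want to log or print the loss or other metrics for the whole global batch, since each process only sees its own local value. An all_reduce collects them from all GPUs and leaves the result on every rank: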

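```python
if ddp:  # average a detached copy; ReduceOp.AVG requires the NCCL backend
    loss_to_log = loss.detach().clone()
    dist.all_reduce(loss_to_log, op=dist.ReduceOp.AVG)
```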
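A minimal sketch of a logging helper along those lines, assuming you only want the averaged value for printing; `reduce_loss_for_logging` is a hypothetical name. It reduces a detached copy, so the in-place all_reduce never touches the tensor you call backward() on, and it uses ReduceOp.SUM followed by division by the world size, which also works on backends such as gloo. The `__main__` block starts a single-process gloo group through a file:// init method purely so the example runs on a CPU-only machine.

```python
import os
import tempfile

import torch
import torch.distributed as dist


def reduce_loss_for_logging(loss: torch.Tensor) -> float:
    """Average a scalar loss across all ranks, for logging only."""
    value = loss.detach().clone()  # keep the autograd graph untouched
    if dist.is_available() and dist.is_initialized():
        # SUM + divide works with gloo as well as NCCL
        dist.all_reduce(value, op=dist.ReduceOp.SUM)
        value /= dist.get_world_size()
    return value.item()


if __name__ == "__main__":
    # Single-process "cluster" on CPU, only to make the sketch runnable
    store_path = os.path.join(tempfile.mkdtemp(), "store")
    dist.init_process_group(
        "gloo", init_method=f"file://{store_path}", rank=0, world_size=1
    )
    loss = torch.tensor(0.75, requires_grad=True) * 2
    print(reduce_loss_for_logging(loss))  # 1.5 with a single process
    dist.destroy_process_group()
```

In a real DDP run, call it on every rank (collectives must be called by all processes) and print only on rank 0.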

I find frameworks like PyTorch Lightning helpful here.

Thanks for the clarification.