Hi,
I am trying to implement a custom loss using nn.Module. I have to train the model using multiple GPUs. The forward method of the custom loss takes (labels, logits) as input. My question is: do I need to gather and concatenate the inputs, i.e. the labels as well as the logits, across processes inside the forward method before actually calculating the loss?
Below is the basic skeleton of my custom loss.
import torch
import torch.distributed as dist
import torch.nn as nn

def multigpu_gather(input, rank, world_size):
    """
    Gather the tensor from all processes and concatenate along dim 0.
    """
    if world_size == 1:
        return input
    gather_list = [torch.empty_like(input) for _ in range(world_size)]
    dist.all_gather(gather_list, input)
    # Put the local tensor back into its slot so gradients can still flow
    # for this rank's slice (the tensors returned by all_gather are detached).
    gather_list[rank] = input
    output = torch.cat(gather_list, dim=0).contiguous()
    return output
class CustomLoss(nn.Module):
    def __init__(self, ignore_index, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.ignore_index = ignore_index
        # Custom losses
        self.loss1 = Loss1()
        self.loss2 = Loss2()
        self.world_size = dist.get_world_size()

    def forward(self, logits, labels, dynamic_w2):
        # Question: do I need to gather here in a multi-GPU setting?
        rank = dist.get_rank()
        labels = multigpu_gather(labels, rank, self.world_size)
        logits = multigpu_gather(logits, rank, self.world_size)
        # Loss 1
        loss1 = self.loss1(logits, labels)  # another nn.Module class
        # Loss 2
        loss2 = self.loss2(logits, labels)  # another nn.Module class
        final_loss = loss1 + dynamic_w2 * loss2
        return final_loss.mean()
And, in the training loop, I am calling loss.backward().
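For context, this is roughly what the training loop looks like (a simplified sketch; model, train_loader, optimizer, and dynamic_w2 are placeholders, and the model is wrapped in DistributedDataParallel in the real code):

# Simplified sketch of how the loss is called per rank.
criterion = CustomLoss(ignore_index=-100)
for input_ids, labels in train_loader:
    logits = model(input_ids)                     # DDP-wrapped model
    loss = criterion(logits, labels, dynamic_w2)  # custom loss above
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()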
Could you please help me understand this? Thanks in advance.
Best,
Bhavana