DDP hangs when some iterations have no GT for loss.backward()

When I train with DDP, some iterations have no ground truth to compute the loss from. Following the method proposed here (DDP hanging up), I skip those samples in model.forward(x), use loss = torch.zeros(1, requires_grad=True), and call loss.backward(). I also set find_unused_parameters=True.
In that setting, only part of the network was left un-updated by loss.backward(). In my setting, forward() runs a preprocessing step on the backbone output to check whether it contains any useful features. If there are none, I return early with the code below and skip the rest of the network. This means that in some iterations of some epochs, no parameters are updated at all.

def forward(self, x):
    feat = self.backbone(x)
    pairs, nopair = preprocess(feat, x)
    output_dict, loss_dict = {}, {}
    if nopair:  # return early: nothing to supervise in this iteration
        loss_dict.update({'total_loss': torch.zeros(1, requires_grad=True)})
        return output_dict, loss_dict
    out = self.main_network(feat, x, pairs)
    output_dict.update({'outputs': out})
    loss_dict.update({'total_loss': self.compute_loss(out, ground_truth)})
    return output_dict, loss_dict

out, loss = model(x)
loss = loss['total_loss']
loss.backward()
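
To make the problem concrete, here is a minimal single-process sketch (no DDP involved) of what the early-return branch does to the gradients. TinyModel and connected_zero_loss are hypothetical names made up for this illustration, and the layer sizes are arbitrary. The first case shows that the placeholder loss is a fresh leaf tensor, so backward() never reaches any parameter. The second case shows one possible alternative, assumed rather than confirmed as a fix: a zero-valued loss that is still connected to every parameter, so each one receives an all-zero gradient.

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for the backbone + main network."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 8)
        self.main_network = nn.Linear(8, 2)

    def forward(self, x, nopair):
        feat = self.backbone(x)
        if nopair:
            # Same placeholder as in the post: a fresh leaf tensor with no
            # connection to any model parameter.
            return torch.zeros(1, requires_grad=True)
        return self.main_network(feat).sum()


def connected_zero_loss(model):
    # Assumption: a zero-valued loss that still touches every parameter,
    # so backward() yields an all-zero gradient for each of them.
    return sum(p.sum() for p in model.parameters()) * 0.0


model = TinyModel()
x = torch.randn(3, 4)

# Case 1: placeholder loss -> no parameter receives a gradient.
model(x, nopair=True).backward()
grads_after_placeholder = [p.grad for p in model.parameters()]

# Case 2: connected zero loss -> every parameter gets a zero gradient.
model.zero_grad(set_to_none=True)
connected_zero_loss(model).backward()
grads_after_connected = [p.grad for p in model.parameters()]
```

DDP all-reduces gradients from autograd hooks registered on the parameters, so a rank whose backward() touches no parameter at all may never join the collective that the other ranks are waiting on. Whether the connected zero loss is enough to avoid the hang in the setup above has not been tested here.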

I’m looking forward to a solution! Thank you all.

I ran into this issue too. Have you found a solution?