DataParallel for customized loss

Hi, I would like to parallelize the loss function computation. I customized the loss function by wrapping it in a module. In the last line of the module, I compute the mean of the loss over the whole batch:

loss = torch.mean(loss_per_sample)

It seems that this line makes data parallelism difficult. Before that line, the loss for each sample can be computed in parallel on different GPUs.

How can I make the overall loss computation run in parallel? Thanks!
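For context, here is a minimal sketch of the setup described, under stated assumptions: the `ModelWithLoss` wrapper and the `nn.Linear` backbone are hypothetical names, not from any library. The idea is that the module returns per-sample losses (so `DataParallel` can split the batch across GPUs), and the mean reduction happens outside the wrapped module:

```python
import torch
import torch.nn as nn

class ModelWithLoss(nn.Module):
    """Hypothetical wrapper: forward pass and per-sample loss live inside
    one module, so nn.DataParallel can scatter the batch across GPUs."""

    def __init__(self, model):
        super().__init__()
        self.model = model
        # reduction='none' keeps one loss value per sample
        self.criterion = nn.CrossEntropyLoss(reduction='none')

    def forward(self, x, target):
        out = self.model(x)
        # shape: (batch,) -- DataParallel gathers these from all replicas
        return self.criterion(out, target)

model = nn.Linear(10, 3)          # hypothetical backbone
wrapped = ModelWithLoss(model)
# wrapped = nn.DataParallel(wrapped)  # enable when multiple GPUs are available

x = torch.randn(4, 10)
target = torch.randint(0, 3, (4,))
loss_per_sample = wrapped(x, target)   # parallelizable part
loss = loss_per_sample.mean()          # final reduction on the default device
```

Taking the mean after gathering keeps the reduction out of the parallel section, which matches the situation in the question.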