Statistics tracking with distributed data parallel

Is there a canonical way to track statistics when using distributed data parallel? For example, say I wanted to track the losses across all of the distributed instances; how can I gather the losses in one place given that they are all running in separate processes?

The only way I can think to do it would be to write them all to a file along with information about the epoch/iteration, etc., and then combine them later.

Is there a better way to do this?

You can use the gather primitive from ProcessGroup to do this; see Distributed communication package - torch.distributed — PyTorch 1.13 documentation.

Thanks, this is exactly what I was looking for.