Getting number of samples in distributed data loader

Is there an easy way to get the number of samples returned by a specific process's dataloader (which is distributed for multi-GPU training)?

I am training a model with torch.distributed on multiple GPUs and need the number of examples per rank in a dataloader whose data is split with torch.utils.data.distributed.DistributedSampler. I can't use len(dataloader) because that gives the number of batches per rank, and len(dataloader.dataset) gives the size of the whole dataset (i.e., the total number of examples across all ranks).
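
To make the problem concrete, here is a stripped-down sketch of what I mean (the toy dataset, batch size, and prints are placeholders, not my real training code):

```python
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# assumes the process group has already been initialized, e.g. with
# dist.init_process_group(...) in a script launched by torchrun
dataset = TensorDataset(torch.arange(1000))  # toy dataset with 1000 examples
sampler = DistributedSampler(dataset)
dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)

print(len(dataloader))          # number of batches this rank will iterate over
print(len(dataloader.dataset))  # 1000: size of the whole dataset, all ranks combined
```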

Can I do something like the following?

```python
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

sampler = DistributedSampler(dataset)
num_samples = len(sampler) // dist.get_world_size()
```
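
Or would I have to work it out by hand from the dataset length? The rough sketch below is just my guess at how DistributedSampler splits the data with its default settings (drop_last=False, where the sampler pads so every rank gets the same count):

```python
import math
import torch.distributed as dist

# my guess: with drop_last=False (the default), DistributedSampler pads the
# dataset so that every rank receives the same number of samples
world_size = dist.get_world_size()
num_samples_per_rank = math.ceil(len(dataset) / world_size)
```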