Different batch sizes in DistributedDataParallel


Is there any way to use different batch sizes across processes in DistributedDataParallel? If so, my understanding is that gradients from different processes should be weighted by their batch sizes. Does PyTorch handle this automatically?

Thanks in advance for any comments.

This comment suggests different approaches using hooks.
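
To illustrate why the weighting matters: DistributedDataParallel averages gradients across processes, so each rank counts equally no matter how many samples it saw. Separately from the hook-based approaches, a simple alternative is to rescale each rank's loss before `backward()`. Below is a rough sketch in plain Python that only simulates the arithmetic. The helper name `ddp_loss_scale` and the toy numbers are my own assumptions, not a PyTorch API.

```python
def ddp_loss_scale(local_batch_size, all_batch_sizes):
    """Hypothetical helper: factor to multiply the local mean loss by.

    After DDP averages gradients over ranks (divides by world size),
    each rank's contribution ends up weighted by local_bs / total_bs.
    """
    world_size = len(all_batch_sizes)
    total = sum(all_batch_sizes)
    return local_batch_size * world_size / total


# Toy per-sample "gradients": rank 0 has 3 samples, rank 1 has 1 sample.
per_sample_grads = [[1.0, 2.0, 3.0], [10.0]]
batch_sizes = [len(g) for g in per_sample_grads]

# Each rank computes the mean over its own batch (mean-reduced loss).
local_means = [sum(g) / len(g) for g in per_sample_grads]

# Plain averaging across ranks ignores the batch sizes.
naive = sum(local_means) / len(local_means)

# Scaling each rank's loss first recovers the mean over all samples.
weighted = sum(
    ddp_loss_scale(b, batch_sizes) * m
    for b, m in zip(batch_sizes, local_means)
) / len(local_means)

true_mean = sum(sum(g) for g in per_sample_grads) / sum(batch_sizes)

print(naive, weighted, true_mean)
```

In a real training loop, each rank would need to learn the total batch size first (for example by summing the local sizes with `torch.distributed.all_reduce`), then do something like `loss = loss * scale` before calling `loss.backward()`. This assumes the loss is mean-reduced over the local batch.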


Thank you so much, this is exactly what I was looking for.