Is there a general rule to determine batch size per process?
Yes. If possible, the data should be evenly distributed across processes; otherwise, the processes with lighter workloads will frequently wait for the stragglers, causing unnecessary slowdown.
More importantly, if you are using DistributedDataParallel
(DDP), all processes must run the same number of forward/backward iterations; otherwise, the collective communications in the DDP backward pass will hang.
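A minimal sketch of the usual way to satisfy both requirements, using `DistributedSampler`. This assumes the script is launched with `torchrun` (which sets the rank/world-size environment variables); the dataset and batch size here are toy placeholders. By default the sampler pads each rank's shard to equal length, and with `drop_last=True` it trims instead, so either way every process runs the same number of iterations per epoch:

```python
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Assumes launch via torchrun, which sets RANK/WORLD_SIZE for env:// init.
dist.init_process_group(backend="gloo")

# Toy dataset: 1000 samples of 10 features each (placeholder).
dataset = TensorDataset(torch.randn(1000, 10), torch.randint(0, 2, (1000,)))

# DistributedSampler gives each rank a disjoint shard of the data.
# drop_last=True trims the shards to equal length instead of padding them,
# so all ranks see the same number of batches per epoch.
sampler = DistributedSampler(dataset, shuffle=True, drop_last=True)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)

for epoch in range(2):
    sampler.set_epoch(epoch)  # reshuffle each epoch, consistently across ranks
    for inputs, targets in loader:
        pass  # forward/backward with the DDP-wrapped model would go here

dist.destroy_process_group()
```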
How could one determine the local batch size?
See discussion here: Should we split batch_size according to ngpu_per_node when DistributedDataparallel
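In short, the `batch_size` you pass to each process's `DataLoader` under DDP is per-process. One common convention from that discussion is to fix a global (effective) batch size and give each process an equal share; a sketch, assuming the process group is already initialized and `global_batch_size` is a hypothetical target:

```python
import torch.distributed as dist

global_batch_size = 256  # hypothetical target effective batch size
world_size = dist.get_world_size()

# Split evenly so the effective batch size matches single-process training.
assert global_batch_size % world_size == 0, "pick a size divisible by world_size"
local_batch_size = global_batch_size // world_size
```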