Hi, I’m trying to train a model in a particular way, and it requires a sufficiently large batch.
Specifically, I need to compute a mean vector that converges weakly at rate O(1/sqrt(n)) (cf. the CLT).
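A quick way to see the O(1/sqrt(n)) rate numerically: the sketch below is plain Python and assumes i.i.d. standard normal data purely for illustration. It estimates the standard deviation of the sample mean for n = 10 and n = 80. The estimates should be close to 1/sqrt(10) and 1/sqrt(80), and their ratio close to sqrt(80/10) = sqrt(8) ≈ 2.83. The function name `std_of_sample_mean` and the trial count are arbitrary choices for this sketch.

```python
import math
import random
import statistics


def std_of_sample_mean(n: int, trials: int = 4000, seed: int = 0) -> float:
    """Empirical std of the mean of n i.i.d. N(0, 1) draws (illustrative data only)."""
    rng = random.Random(seed)
    means = [
        statistics.fmean(rng.gauss(0.0, 1.0) for _ in range(n))
        for _ in range(trials)
    ]
    return statistics.stdev(means)


s10 = std_of_sample_mean(10)
s80 = std_of_sample_mean(80)
# Compare each estimate with the CLT prediction sigma / sqrt(n), where sigma = 1 here.
print(s10, 1 / math.sqrt(10))
print(s80, 1 / math.sqrt(80))
print(s10 / s80, math.sqrt(8))
```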

I have a powerful machine available: 8 RTX A6000 GPUs.

When I run my model with a batch of size n, it splits the batch into sub-batches of n/8 samples, one per GPU, and computes 8 different vectors. I’d like to compute a single vector over the whole batch, so that I get the full 1/sqrt(n) weak-convergence rate rather than the 1/sqrt(n/8) rate of each sub-batch. Is there a trick to do that?

@DidierDeschamps thanks for posting. Sorry, I don’t quite understand your problem: is this an issue with the PyTorch distributed framework, or an algorithm question?

@wanchaol PyTorch automatically splits my batch across the 8 GPUs and computes 8 losses with 8 backpropagations. Is there a way to prevent this and compute only one loss with one backpropagation?
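One possible pattern, sketched under assumptions: if the split comes from `nn.DataParallel`, the replicas' outputs are gathered back onto one device before `forward` returns. So if the model returns per-sample features and the batch mean and loss are computed outside the wrapped module, you get one mean vector, one scalar loss and a single `backward()` call over all n samples. If the mean is computed inside the model's `forward`, each replica only averages its own n/8 slice. The `Encoder` model, the squared-norm loss and the feature sizes below are hypothetical placeholders. With `DistributedDataParallel`, each process computes its own loss, so this pattern does not apply directly there.

```python
import torch
import torch.nn as nn


class Encoder(nn.Module):
    """Hypothetical model: maps each sample to a feature vector.

    It returns per-sample features and does NOT average over the batch,
    because under DataParallel each replica would only see its n/8 slice.
    """

    def __init__(self, d_in: int = 16, d_out: int = 4):
        super().__init__()
        self.net = nn.Linear(d_in, d_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


def full_batch_step(model, x, optimizer):
    # DataParallel scatters x across the GPUs, runs the replicas, and gathers
    # the per-sample outputs back onto the output device (GPU 0 by default).
    feats = model(x)                # shape (n, d_out): all n samples
    mean_vec = feats.mean(dim=0)    # one mean vector over the full batch
    loss = mean_vec.pow(2).sum()    # hypothetical loss, nonlinear in the mean
    optimizer.zero_grad()
    loss.backward()                 # one scalar loss, one backward call
    optimizer.step()
    return mean_vec.detach(), loss.item()


device = "cuda" if torch.cuda.is_available() else "cpu"
# Without CUDA, DataParallel just calls the wrapped module, so this also runs on CPU.
model = nn.DataParallel(Encoder()).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

torch.manual_seed(0)
x = torch.randn(80, 16, device=device)  # n = 80 samples
mean_vec, loss = full_batch_step(model, x, optimizer)
print(mean_vec, loss)
```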

E.g., if my batch size is n=80, PyTorch splits it into sub-batches of n/8=10 samples on each GPU and computes 8 losses. Unfortunately, a loss computed on 10 data points is less consistent than one computed on 80 data points, which is why I’d like to avoid this.
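Whether the split actually changes the result depends on how the loss uses the batch. The plain-Python sketch below uses made-up scalar data and two hypothetical losses, splitting 80 samples into 8 chunks of 10. For a loss that is a mean of per-sample terms, the average of the 8 sub-batch losses equals the full-batch loss exactly. Gradients are linear too, so averaging per-GPU gradients over equal-sized shards (which is what DistributedDataParallel does) also reproduces the full-batch gradient in that case. For a loss that is a nonlinear function of the batch mean, such as the squared mean, the 8 sub-batch values average to something different (larger, by Jensen’s inequality) than the single full-batch value.

```python
import math
import random
import statistics


def split(xs, k):
    """Split xs into k equal contiguous chunks (assumes len(xs) % k == 0)."""
    m = len(xs) // k
    return [xs[i * m:(i + 1) * m] for i in range(k)]


def per_sample_loss(batch):
    # Hypothetical loss: mean of per-sample terms (linear in the per-sample losses).
    return statistics.fmean(v * v for v in batch)


def mean_based_loss(batch):
    # Hypothetical loss: square of the batch mean (nonlinear in the mean).
    return statistics.fmean(batch) ** 2


random.seed(0)
xs = [random.gauss(1.0, 1.0) for _ in range(80)]  # made-up data, n = 80
chunks = split(xs, 8)                              # 8 sub-batches of 10

full_linear = per_sample_loss(xs)
split_linear = statistics.fmean(per_sample_loss(c) for c in chunks)
full_nonlinear = mean_based_loss(xs)
split_nonlinear = statistics.fmean(mean_based_loss(c) for c in chunks)

# First value: whether the linear loss is unchanged by splitting.
# Second value: how much the nonlinear loss shifts when averaged over sub-batches.
print(math.isclose(full_linear, split_linear), split_nonlinear - full_nonlinear)
```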