Hey @annisat, that looks good to me. One way to speed it up a bit is to launch `dist.all_reduce` at the beginning of the loop with `async_op=True`, and only call `wait()` on the returned handle once you actually need the result. That way, the communication can overlap with the forward/backward/`opt.step()` computation. Please see the code in the following thread:
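Here is a minimal sketch of that pattern. It makes a few assumptions. The value being reduced is something produced by the previous iteration (here a hypothetical per-rank loss). A toy `nn.Linear` model stands in for the real one. A single-process `gloo` group lets it run standalone. Under `torchrun`, the launcher supplies `RANK`, `WORLD_SIZE`, `MASTER_ADDR` and `MASTER_PORT` instead.

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn


def train(num_steps: int = 5) -> float:
    # Hypothetical toy model and data; the real model from the thread is not shown.
    model = nn.Linear(4, 1)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    x, y = torch.randn(8, 4), torch.randn(8, 1)

    # Hypothetical value from the previous iteration that all ranks must agree on.
    prev_loss = torch.zeros(1)
    reduced = 0.0
    for _ in range(num_steps):
        # Launch the collective first; async_op=True returns a Work handle immediately.
        # Reduce a clone so nothing else touches the buffer while the op is in flight.
        buf = prev_loss.clone()
        work = dist.all_reduce(buf, op=dist.ReduceOp.SUM, async_op=True)

        # Forward/backward/step can run while the all_reduce is still in progress.
        opt.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        opt.step()

        # Block only at the point where the reduced value is actually needed.
        work.wait()
        reduced = buf.item() / dist.get_world_size()  # average across ranks
        prev_loss = loss.detach().reshape(1)
    return reduced


if __name__ == "__main__":
    # Defaults for a standalone single-process run; torchrun overrides these.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("WORLD_SIZE", "1")
    dist.init_process_group("gloo")
    print(train())
    dist.destroy_process_group()
```

The buffer passed to an async `all_reduce` should not be read or modified until `wait()` returns, which is why the sketch reduces a fresh clone each step. How much time the overlap actually saves depends on the backend and on how long the collective takes relative to the compute.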