I am trying to implement the following parallelization scheme, but I'm not sure whether it is possible.
For example: train with multiple processes (CPU cores), with each process handling its own independent batches. Instead of each process taking an optimization step independently on its batch, I want to gather the losses and gradients from all processes and take the optimization step only on process 0.
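Roughly what I have in mind is the sketch below. The model, data, and hyperparameters are just placeholders, and the choice of the gloo backend plus the dist.reduce / dist.broadcast calls are my guesses at the right APIs, not something I have verified:

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn

def run(rank, world_size, num_steps=5):
    # gloo backend, since the workers are plain CPU processes.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Placeholder model / optimizer / loss -- the real ones would go here.
    model = nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()

    for _ in range(num_steps):
        # Each process draws its own independent batch (random data as a stand-in).
        inputs, targets = torch.randn(32, 10), torch.randn(32, 1)

        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()

        # Collect the loss on rank 0, e.g. for logging.
        loss_copy = loss.detach().clone()
        dist.reduce(loss_copy, dst=0, op=dist.ReduceOp.SUM)

        # Sum every parameter's gradient onto rank 0.
        for p in model.parameters():
            dist.reduce(p.grad, dst=0, op=dist.ReduceOp.SUM)

        if rank == 0:
            # Average the summed gradients and step the optimizer only here.
            for p in model.parameters():
                p.grad /= world_size
            optimizer.step()

        # Push the updated weights back out so all processes stay in sync.
        for p in model.parameters():
            dist.broadcast(p.data, src=0)

    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = 4
    mp.spawn(run, args=(world_size,), nprocs=world_size)
```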
Is it possible to do this with the torch.distributed package?