CPU Parallelism for Cumulative Loops

Hi, I don’t have any practical experience with parallelization, but I would like to parallelize the training of my model over multiple CPU cores in a specific way. The part I want to parallelize is a for loop that accumulates a sum of the errors. I think this is the most effective place to parallelize because my model predicts several high-dimensional quantities from high-dimensional inputs, and both the inputs and the outputs have variable shapes. The code for the error-gradient calculation of a single optimizer step looks something like

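import torch

error = 0
for training_sample in training_samples:
    # Call the module directly rather than .forward() so that hooks still run.
    predicted_a, predicted_b, predicted_c = model(training_sample.input)
    # Parentheses let the sum continue across lines.
    error += (
        torch.mean((predicted_a - training_sample.expected_a).pow(2))
        + torch.mean((predicted_b - training_sample.expected_b).pow(2))
        # + ... (remaining terms)
    )
error.backward()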

For large sets of training samples, my impression is that this can be parallelized by having each CPU core calculate the errors for a fraction of the training samples, then adding the partial errors together to calculate the gradient. Is this possible and advisable as a way to parallelize my problem, and how should I start going about it? (I put this under the “distributed” category because it seemed related, but I couldn’t find a neat example that covers what I want to achieve.)

Thank you!

Aren’t you bottlenecked by batch size (memory-wise) anyway?

By my estimate, about 16 cores on a cluster I have access to are sufficient for the memory requirements. However, I’d rather not request 16 cores just for the memory; I might as well parallelize the training to make the most of them, hence the question. Otherwise, I guess I could just train on 2 cores, one mini-batch at a time without parallelism, to avoid wasting resources.
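For the mini-batch fallback, the point worth checking is that the total error is a plain sum over samples, so calling backward() once per mini-batch should accumulate the same gradient as a single backward() over everything, while only one mini-batch's graph is held in memory at a time. Here is a minimal sketch of that check. The nn.Linear model, the random samples, and the chunk_error helper are all hypothetical stand-ins, not the model from the question.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 2)  # hypothetical stand-in for the real model
# Hypothetical samples: (input, expected) pairs with a variable leading dimension.
samples = [(torch.randn(n, 4), torch.randn(n, 2)) for n in (3, 5, 2, 7)]


def chunk_error(chunk):
    # Sum of per-sample mean-squared errors over one chunk of samples.
    return sum(torch.mean((model(x) - y).pow(2)) for x, y in chunk)


# Reference: a single backward pass over the error of all samples.
model.zero_grad()
chunk_error(samples).backward()
full_grad = model.weight.grad.clone()

# Mini-batches: one backward per chunk; .grad accumulates across calls.
model.zero_grad()
for chunk in (samples[:2], samples[2:]):
    chunk_error(chunk).backward()
chunked_grad = model.weight.grad.clone()

grads_match = torch.allclose(full_grad, chunked_grad, atol=1e-6)
print("gradients match:", grads_match)
```

The optimizer step would then come after the loop over mini-batches, so it still sees the gradient of the full summed error.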

This use case doesn’t exactly need distributed training, but you can try torch.jit.fork, which is the main way to do task-level parallelism in PyTorch (see the torch.jit.fork page in the PyTorch 1.10.0 documentation).
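To show the call pattern, here is a rough sketch of the loop from the question rewritten with torch.jit.fork and torch.jit.wait. The nn.Linear model, the random samples, and the sample_error helper are hypothetical stand-ins. One caveat that should be verified: as far as I understand, outside TorchScript the forked function simply runs right away in the calling thread, so this eager version only illustrates the API shape, and actual parallel execution would need the forking code compiled with torch.jit.script.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 2)  # hypothetical stand-in for the real model
# Hypothetical samples: (input, expected) pairs with a variable leading dimension.
samples = [(torch.randn(n, 4), torch.randn(n, 2)) for n in (3, 5, 2, 7)]


def sample_error(x, y):
    # Error of a single sample; each call becomes one forked task.
    return torch.mean((model(x) - y).pow(2))


# Launch one task per sample; each fork call returns a future.
futures = [torch.jit.fork(sample_error, x, y) for x, y in samples]

# Wait for every task and add the results up, as the original loop does.
error = sum(torch.jit.wait(fut) for fut in futures)
error.backward()

# Sequential reference value, for comparison only.
reference = sum(sample_error(x, y) for x, y in samples)
print("forked and sequential errors agree:", torch.allclose(error, reference))
```

Forking one task per chunk of samples rather than per sample would keep the number of tasks close to the number of cores.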

Cool, thanks! Let me try that.