Multiprocessing in PyTorch

Hi everyone,
I have the following problem. Imagine that, for whatever reason, you want to divide the batch used to perform a step into smaller ones; this could happen, for example, if you want to test your model with a batch size bigger than your memory capacity, so that you end up with a micro-batching approach.
In my case I am using micro-batching for a different reason: I want to separate the gradient contribution associated with each class.
To perform micro-batching I define a dataloader for each class and, at each step, I run a for loop that selects a batch from one of the dataloaders, computes the associated gradient, and copies it into a temp variable.
When every class has propagated one of its batches the for loop ends, I update the weights, and I move on to the next step, repeating the same procedure.

    for key in NetInstance.TrainDL:  # NetInstance is a class where TrainDL (a dict of dataloaders) is defined
        # for each class we select a single batch from that class's dataloader and repeat the procedure above
        try:
            img, lab = next(ClassesIterables[key])
        except StopIteration:
            # when we have exhausted the dataset we reshuffle and restart with the new sequence
            ClassesIterables[key] = iter(NetInstance.TrainDL[key])
            img, lab = next(ClassesIterables[key])
        img = img.double()

        # load data on device
        img = img.to(device)
        lab = lab.to(device)

        # ...
        # propagate the batch through the net, backprop to get the gradient,
        # store the gradient in a temp variable
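
For context, the elided part of the loop looks conceptually like this (a rough sketch: criterion, TempGrads and the model attribute name are placeholders, not the actual variables in my code):

    # sketch of the elided step: forward, backward, copy the per-class gradient
    NetInstance.model.zero_grad()   # 'model' is a placeholder attribute name
    out = NetInstance.model(img)
    loss = criterion(out, lab)      # 'criterion' is a placeholder loss function
    loss.backward()
    # detach + clone so the stored copy is not tied to the autograd graph
    TempGrads[key] = [p.grad.detach().clone() for p in NetInstance.model.parameters()]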
              

Since each iteration of the above for loop is independent of the others, I would like to parallelize the iterations using the multiprocessing module.
To do so I need some variables to be shared between processes, while others have to be private to each process, namely:

  • the temp variable where I store the gradients should be shared between the processes

  • on the other hand, I don’t want the gradient computed in one process to interact in any way with the one computed in a different process (after all, I performed the micro-batching exactly to separate the different gradient contributions).

I was then planning the following approach:

  • define different processes at each step, each one taking care of a single iteration of the above for loop; this should automatically create a per-process instance of the variables inside the block (so also of the network and the computed gradient)

  • use multiprocessing.shared_memory to store the variables shared by the different processes (e.g. the temp grad copy); a minimal sketch of this plan is shown below
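
To make the plan concrete, here is a minimal, self-contained sketch of what I have in mind. I sketch it with torch.multiprocessing and tensor.share_memory_() rather than raw multiprocessing.shared_memory, since the torch wrapper already knows how to pass tensors through shared memory; the toy nn.Linear model and the random per-class dataloaders are just stand-ins for my actual NetInstance.TrainDL setup:

    import torch
    import torch.multiprocessing as mp
    from torch.utils.data import DataLoader, TensorDataset

    def worker(model, loader, grad_buf):
        # each process handles the micro-batch of a single class, independently
        img, lab = next(iter(loader))
        loss = torch.nn.functional.cross_entropy(model(img), lab)
        # the gradients computed here live only in this process...
        grads = torch.autograd.grad(loss, model.parameters())
        # ...and are copied in place into the pre-allocated shared buffer
        for buf, g in zip(grad_buf, grads):
            buf.copy_(g)

    if __name__ == "__main__":
        mp.set_start_method("spawn")
        model = torch.nn.Linear(10, 3)   # toy stand-in for the real net
        model.share_memory()             # put the parameters in shared memory
        # one dataloader per class, playing the role of NetInstance.TrainDL
        loaders = [DataLoader(TensorDataset(torch.randn(32, 10),
                                            torch.full((32,), c, dtype=torch.long)),
                              batch_size=8)
                   for c in range(3)]
        # one shared gradient buffer per class, allocated BEFORE spawning
        grad_bufs = [[torch.zeros_like(p).share_memory_() for p in model.parameters()]
                     for _ in loaders]
        procs = [mp.Process(target=worker, args=(model, dl, gb))
                 for dl, gb in zip(loaders, grad_bufs)]
        for p in procs:
            p.start()
        for p in procs:
            p.join()
        # grad_bufs now holds one independent gradient per class in the parent,
        # ready for the weight update of this step

If I understand the docs correctly, the key point is that the shared buffers must be allocated and moved to shared memory before the processes are spawned.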

My questions are:

  • is this a good approach to reach my goal? Is there an easier way?
  • are there any important caveats I should keep in mind if I follow this route?
  • could branching the main process into many processes, each one processing a small batch, create a problem in terms of required memory (as each process works on its own copy of the model/gradients)?

[EDIT]
I found this question:
Gradient disconnected after Multiprocessing pool(starmap)

which partially resolved my doubts. In particular, from what I read there:

  • putting torch.set_grad_enabled(False) before branching the processes will automatically create an independent gradient for each of them; if instead a grad tensor is already associated with the net, it will be shared between the processes. Is that right?
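
To frame the question, here is a minimal experiment (the toy nn.Linear and all the names are mine, not from the linked thread) that seems to show that a .grad computed inside a child process is not visible in the parent when only the parameters, and not the grads, were placed in shared memory before spawning:

    import torch
    import torch.multiprocessing as mp

    def worker(model):
        # compute a gradient inside the child process
        model(torch.ones(1, 10)).sum().backward()
        print("child sees a grad:", model.weight.grad is not None)   # True

    if __name__ == "__main__":
        mp.set_start_method("spawn")
        model = torch.nn.Linear(10, 1)
        model.share_memory()   # parameters live in shared memory
        p = mp.Process(target=worker, args=(model,))
        p.start()
        p.join()
        # the .grad allocated in the child is NOT visible here:
        print("parent sees:", model.weight.grad)   # prints None

If that reading is correct, the gradient ends up shared only when the .grad tensors themselves are allocated and moved to shared memory before branching. Does that match what the linked answer means?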