How do I set up multi-process training in PyTorch without back-propagation, using no_grad()?

I am new to PyTorch and am trying to understand how to write multi-process code that runs under no_grad(). The requirement is described below.

I am porting a C++ program that was used in a paper to convert graph nodes to vectors, so that they can be used for ML tasks. I initialize a weight matrix of size V x d (V = number of nodes, d = embedding size) with values between -0.5 and 0.5. Then we sample nodes u, v, and neg_v, where u and v share an edge and u and neg_v do not.

The calculation selects the weights (embeddings) of the sampled nodes and updates them using the logic below.
if we sample u, v:
    score = -bias
    score += weight[u] * weight[v]
    score = (1 - fast_Sigmoid(score)) * lr
    weight[u] += weight[v] * score
    weight[v] += weight[u] * score

if we sample u, neg_v:
    score = -bias
    score += weight[u] * weight[neg_v]
    score = - fast_Sigmoid(score) * lr
    weight[u] += weight[neg_v] * score
    weight[neg_v] += weight[u] * score
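
For reference, the two cases above can be folded into one PyTorch function. This is only a minimal sketch under some assumptions: weight[u] * weight[v] is read as a dot product, torch.sigmoid stands in for fast_Sigmoid, and update_pair, label, and the default lr and bias values are hypothetical names and numbers chosen for illustration.

```python
import torch


def update_pair(weight, u, v, label, lr=0.025, bias=0.0):
    """Update the rows of `weight` for one sampled pair, in place.

    label is 1.0 when u and v share an edge and 0.0 for a negative sample,
    so (label - sigmoid) covers both branches of the pseudocode.
    """
    with torch.no_grad():
        score = torch.dot(weight[u], weight[v]) - bias
        # torch.sigmoid is an assumption; the original uses fast_Sigmoid.
        g = (label - torch.sigmoid(score)) * lr
        # Same order as the pseudocode: v is updated with the new row of u.
        weight[u] += g * weight[v]
        weight[v] += g * weight[u]
    return weight


num_nodes, dim = 10, 4                       # hypothetical V and d
weight = torch.rand(num_nodes, dim) - 0.5    # uniform values in [-0.5, 0.5)
update_pair(weight, 0, 1, label=1.0)         # u, v: edge exists
update_pair(weight, 0, 2, label=0.0)         # u, neg_v: no edge
```

Only the two indexed rows are modified by each call, which is what makes it practical to let several workers update the same matrix.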

Since all of this is performed in a single pass and we want faster calculation, we came across torch.no_grad() ("Disabling gradient calculation is useful for inference when you are sure that you will not call backward(). It will reduce memory consumption") and are now using it. How can we run this training across multiple processes with shared memory? Any details would be appreciated.
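
One possible shape for this, sketched under assumptions rather than taken from the original C++ program: keep the weight matrix as a CPU tensor, move it into shared memory with share_memory_(), and let each torch.multiprocessing worker update it in place without locks. The names train_worker and main, the sizes, and the sample triples are hypothetical, and torch.sigmoid stands in for fast_Sigmoid.

```python
import torch
import torch.multiprocessing as mp


def train_worker(weight, pairs, lr=0.025, bias=0.0):
    """Apply the update rule to each (u, v, label) triple, in place."""
    with torch.no_grad():
        for u, v, label in pairs:
            score = torch.dot(weight[u], weight[v]) - bias
            g = (label - torch.sigmoid(score)) * lr
            weight[u] += g * weight[v]
            weight[v] += g * weight[u]


def main(num_workers=2):
    num_nodes, dim = 100, 8                      # hypothetical sizes
    weight = torch.rand(num_nodes, dim) - 0.5
    weight.share_memory_()                       # all workers see one storage

    # Hypothetical samples: label 1.0 for an edge, 0.0 for a negative pair.
    pairs = [(0, 1, 1.0), (0, 2, 0.0), (3, 4, 1.0), (3, 5, 0.0)]
    chunks = [pairs[i::num_workers] for i in range(num_workers)]

    procs = [mp.Process(target=train_worker, args=(weight, chunk))
             for chunk in chunks]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    return weight


if __name__ == "__main__":
    print(main()[:2])
```

Because the workers write without synchronization, two of them can touch the same row at the same time; whether that is acceptable depends on how sparse the updates are. Sharing CUDA tensors between processes has extra requirements (for example the spawn start method), so this sketch stays on the CPU.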


Is there a specific reason you want to go with multiprocessing rather than multithreading?

If you really need multiprocessing, you don’t need to do anything special here to avoid gradients, since nothing requires_grad anyway. So no graph will be created.
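
A quick way to check this yourself is shown below. It is a small sketch with hypothetical tensor names: a tensor created with torch.rand does not require gradients, so operations on it record no graph, while no_grad() only makes a difference once a tensor with requires_grad=True is involved.

```python
import torch

# Plain tensor: requires_grad defaults to False, so nothing is recorded.
weight = torch.rand(5, 3) - 0.5
score = torch.dot(weight[0], weight[1])
print(weight.requires_grad, score.requires_grad, score.grad_fn)

# For contrast: a tensor that does require gradients.
param = torch.rand(5, 3, requires_grad=True)
tracked = torch.dot(param[0], param[1])          # graph is recorded
with torch.no_grad():
    untracked = torch.dot(param[0], param[1])    # graph is not recorded
print(tracked.grad_fn is not None, untracked.requires_grad)
```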


Thank you for the reply. I am trying to write code for multiple GPUs or a GPU/CPU mix. I am sharing memory across multiple cores, so I am looking for best practices. The requirement is really to use both multiple processes and multiple threads. For instance, if the graph is big, I want to run the calculation on multiple GPUs or a GPU/CPU mix, with multiple threads, so that it finishes faster. I saw an example implementation
I am trying to follow the same idea, but I just want to make sure I do it right!