Update parameters when using multiprocessing

I’m trying to implement the A3C paper myself and learn how to use the torch.multiprocessing package.
While digging through some examples, I saw that we can simply update the model by calling optim.step().
However, is this safe? Are all parameters updated simultaneously? If not, another process might run a forward pass on the model while only some of its parameters have been updated.
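The pattern described above is roughly the one in PyTorch's mnist_hogwild example: put the model in shared memory, then let every process build its own optimizer over the shared parameters and call step(). Below is a minimal sketch of that pattern, assuming a toy nn.Linear model and a made-up squared-output objective in place of the real A3C loss. The worker/main names are hypothetical.

```python
import torch
import torch.multiprocessing as mp
from torch import nn


def worker(model, steps):
    # Each process builds its own optimizer over the *shared* parameters.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    for _ in range(steps):
        x = torch.randn(8, 4)
        loss = model(x).pow(2).mean()  # toy objective (assumption), not the A3C loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()  # writes straight into shared memory; no lock is taken


def main(num_processes=2, steps=5):
    model = nn.Linear(4, 1)
    model.share_memory()  # move parameters into shared memory before starting workers
    before = model.weight.detach().clone()
    procs = [mp.Process(target=worker, args=(model, steps)) for _ in range(num_processes)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    return before, model, [p.exitcode for p in procs]


if __name__ == "__main__":
    before, model, exitcodes = main()
    print(exitcodes, torch.equal(before, model.weight.detach()))
```

Because the parent's model views the same shared storage, it sees the workers' updates after join(). Nothing here coordinates the workers, which is exactly the concern raised in the question.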

As long as you have followed the proper model structure:

  • your model inherits from the torch.nn.Module base class
  • your model overrides and implements the forward() method

you can simply call optim.step() to update all the parameters (i.e. every parameter passed to the optimizer, typically via model.parameters()).
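A minimal single-process sketch of what this answer means, using a hypothetical TinyModel: optimizer.step() updates, in place, every parameter the optimizer was given. Note that it does nothing to coordinate with other processes that may be reading those parameters at the same time.

```python
import torch
from torch import nn


class TinyModel(nn.Module):  # hypothetical model for illustration
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 1)  # registers fc.weight and fc.bias as parameters

    def forward(self, x):
        return self.fc(x)


model = TinyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

before = [p.detach().clone() for p in model.parameters()]
loss = model(torch.ones(1, 2)).sum()
optimizer.zero_grad()
loss.backward()   # d(loss)/d(weight) = [[1, 1]], d(loss)/d(bias) = [1]
optimizer.step()  # in-place update of every registered parameter
after = [p.detach().clone() for p in model.parameters()]

print([not torch.equal(b, a) for b, a in zip(before, after)])
```

With lr=0.1 and these gradients, both fc.weight and fc.bias end up exactly 0.1 lower than before.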

I don’t think so, because, as mentioned here, there’s no global semaphore, mutex, etc. to prevent other processes from running a forward pass on the model while we’re updating its parameters.

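import time

import torch as t
import torch.multiprocessing as mp
from torch import nn


class mymodel(nn.Module):
    def __init__(self):
        super(mymodel, self).__init__()
        # Register a and b as parameters so model.share_memory() shares them;
        # plain tensor attributes would stay private to each process.
        self.a = nn.Parameter(t.tensor(0.))
        self.b = nn.Parameter(t.tensor(0.))

    def forward(self, X):
        print('Start forwarding')
        print('Current param ', self.a.item(), self.b.item())
        y = self.a * X[0]
        print('Finish adding a')
        time.sleep(2)  # widen the window between reading a and reading b
        y = y + self.b * X[1]
        print('Finish adding b')
        print('final param ', self.a.item(), self.b.item())
        return y.sum()


def train(model, idx):
    # Construct data_loader, optimizer, etc.
    if idx == 1:
        # Stand-in for optim.step(): modify the shared parameters in place
        # while process 0 is in the middle of its forward pass.
        time.sleep(1)
        with t.no_grad():
            model.b.add_(1)
            print('finish changing b')
            model.a.add_(1)
            print('finish changing a')
    else:
        print('sum', model(t.tensor([1, 1])))


if __name__ == '__main__':
    model = mymodel()
    model.share_memory()  # move the parameters into shared memory
    num_processes = 2

    processes = []
    for idx in range(num_processes):
        p = mp.Process(target=train, args=(model, idx))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()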

The snippet above is a minimal reproduction of the problem I’m describing: one process can change part of the parameters while another process is still in the middle of a forward pass.
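The reproduction only means something if both processes really see the same a and b. A quick check, assuming CPU tensors and a hypothetical Demo module: nn.Module.share_memory() moves registered parameters and buffers into shared memory, but a tensor stored as a plain attribute (as in self.a = t.tensor(0)) is skipped.

```python
import torch
from torch import nn


class Demo(nn.Module):  # hypothetical module for illustration
    def __init__(self):
        super().__init__()
        self.param = nn.Parameter(torch.tensor(0.0))  # registered parameter
        self.plain = torch.tensor(0.0)                # plain attribute, not registered


m = Demo()
m.share_memory()  # only touches parameters and buffers
print(m.param.is_shared(), m.plain.is_shared())
```

This prints True for the parameter and False for the plain attribute, which is why the model's tensors should be nn.Parameter (or registered buffers) before calling share_memory().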

Can anyone suggest how to fix this?

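I think this snippet implies that the forward() method of an nn.Module is not thread-safe: nothing stops another process from modifying the shared parameters partway through a forward pass.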
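One way to rule out this kind of mixed read is to share an explicit lock between processes and hold it around both the forward pass and optimizer.step(). The sketch below assumes that serializing forward passes and updates is acceptable; lock-free Hogwild-style updates are the usual choice in A3C implementations, so this trades throughput for consistency. The reader/updater functions and the fixed gradient are hypothetical.

```python
import torch
import torch.multiprocessing as mp
from torch import nn


def reader(model, lock, out):
    x = torch.ones(1, 2)
    with lock:  # no update can interleave with this forward pass
        y = model(x).item()
    out.put(y)


def updater(model, lock):
    optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
    for p in model.parameters():
        p.grad = torch.full_like(p, -1.0)  # hypothetical fixed gradient
    with lock:  # the whole step() is applied before a reader can take the lock
        optimizer.step()  # p <- p - lr * grad, so every entry goes from 0 to 1


def main():
    model = nn.Linear(2, 1)
    with torch.no_grad():
        model.weight.zero_()
        model.bias.zero_()
    model.share_memory()
    lock = mp.Lock()
    out = mp.Queue()
    procs = [mp.Process(target=reader, args=(model, lock, out)),
             mp.Process(target=updater, args=(model, lock))]
    for p in procs:
        p.start()
    y = out.get()
    for p in procs:
        p.join()
    return y


if __name__ == "__main__":
    # 0.0 if the forward pass ran before the update, 3.0 if after; never a mix.
    print(main())
```

Without the lock, a forward pass that reads the weight before the update and the bias after it could return an intermediate value, which is the situation the snippet above reproduces.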

Can anyone help me, please?