Simultaneously change parameters on all the GPUs

While training a network on 8 GPUs in parallel, I want to manually change the network's parameters with the following code:

for param in model.parameters(): parameter)

Would this change the parameters on every GPU, or only on GPU:0?

It depends on what you’re using for parallelism.

If you use nn.DataParallel you should be able to do this, as the model is replicated to the other GPUs in every iteration. This means you only need to modify the parameters of the root module. This is also where you’d run the optimizer, for example.
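As a minimal sketch of what that looks like, assuming a hypothetical toy nn.Linear model and an arbitrary in-place edit (halving every weight), you can modify the root module through the wrapper's .module attribute. The layer sizes and the edit itself are illustrative, not taken from the question. The snippet also runs on a machine without GPUs, where nn.DataParallel simply calls the wrapped module directly.

```python
import torch
import torch.nn as nn

# Hypothetical toy model; the layer sizes are arbitrary.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(4, 2).to(device)
dp_model = nn.DataParallel(model)  # the root module lives on the first device

# Keep a copy of the original values so the edit can be checked afterwards.
before = [p.detach().clone() for p in model.parameters()]

# Edit the root module's parameters in place, outside of autograd.
with torch.no_grad():
    for param in dp_model.module.parameters():
        param.mul_(0.5)  # illustrative edit: halve every value

# The next forward pass replicates the root module again, so every
# replica computes with the edited values.
x = torch.randn(8, 4, device=device)
out = dp_model(x)
expected = x @ model.weight.t() + model.bias
```

Doing the edit under torch.no_grad() keeps it out of the autograd graph, which is why in-place methods such as mul_ are allowed on parameters that require gradients.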

Does "the model is replicated to the other GPUs in every iteration" mean that the state_dicts are copied to the other GPUs every iteration? So the mean and var in BN are also copied from GPU:0 to the other GPUs?

Is there any documentation that explains this process in detail? I am really curious about the parallelism mechanism used in PyTorch, since I always run experiments in a multi-GPU environment.

Yes, that’s correct. The documentation covers the replication part; see torch.nn.DataParallel. Note that this is not how the distributed version works. There, every process runs forward, backward, and the optimizer against its own single copy of the model, so the parameters are already equivalent across processes: not because the values are replicated, but because every process executes the exact same optimizer step.
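To illustrate the distributed case, here is a toy simulation in plain Python, with floats standing in for tensors. It is a sketch under stated assumptions, not PyTorch's implementation: the model is a hypothetical one-parameter y = w * x, and the helper names local_grad and all_reduce_mean are made up for this example. Each simulated process holds its own copy of w, the gradients are averaged across processes, and every process then applies the same update.

```python
def local_grad(w, xs, ys):
    """Gradient of the mean squared error of y = w * x on one process's data."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def all_reduce_mean(values):
    """Stand-in for an all-reduce: every process receives the same average."""
    mean = sum(values) / len(values)
    return [mean for _ in values]

# Two simulated processes, each with its own shard of the data (y = 2 * x).
shards = [([1.0, 2.0], [2.0, 4.0]), ([3.0, 4.0], [6.0, 8.0])]
weights = [0.5, 0.5]  # every process starts from the same initial value
lr = 0.01

for _ in range(10):
    # Each process computes a gradient on its own data with its own copy of w.
    grads = [local_grad(w, xs, ys) for w, (xs, ys) in zip(weights, shards)]
    # After averaging, every process holds the identical gradient.
    grads = all_reduce_mean(grads)
    # Every process applies the exact same optimizer step.
    weights = [w - lr * g for w, g in zip(weights, grads)]
```

No parameter values are ever copied between the simulated processes, yet the two copies of w stay identical because they start equal and receive identical updates.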

Thanks a lot! I see what you mean. So, in summary, the multi-GPU environment works as follows:

  1. Replicate the model (parameters and buffers) from GPU:0 to all the GPUs.
  2. Split the data and forward each chunk separately on a different GPU.
  3. Gather the outputs from all the GPUs onto GPU:0.
  4. Calculate the loss from the outputs and targets on GPU:0.
  5. Backpropagate the loss to the GPUs, each of which calculates its own gradients.
  6. Sum the gradients from all the GPUs onto GPU:0.
  7. Update the parameters on GPU:0.
  8. Go to step 1.
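The steps above can be sketched as a toy simulation in plain Python, with floats standing in for tensors. This is only an illustration under assumptions: the one-parameter model y = w * x, the mean squared error loss, and the function names are all hypothetical, and steps 3 to 5 are collapsed into a per-device gradient computation. It checks that the replicate, split, and sum loop gives the same update as a single device processing the whole batch.

```python
def grad_sum(w, xs, ys):
    """Sum over samples of d/dw (w * x - y)^2."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys))

def data_parallel_step(w, xs, ys, num_devices, lr):
    # Step 1: replicate the parameter from "device 0" to every device.
    replicas = [w for _ in range(num_devices)]
    # Step 2: split the batch into one contiguous chunk per device.
    chunk = (len(xs) + num_devices - 1) // num_devices
    shards = [(xs[i:i + chunk], ys[i:i + chunk]) for i in range(0, len(xs), chunk)]
    # Steps 3-5 (collapsed): each device computes the gradient of its own chunk.
    partial = [grad_sum(r, sx, sy) for r, (sx, sy) in zip(replicas, shards)]
    # Step 6: sum the per-device gradients on "device 0"; dividing by the
    # batch size matches a loss averaged over the full batch.
    grad = sum(partial) / len(xs)
    # Step 7: update the single master copy on "device 0".
    return w - lr * grad

def single_device_step(w, xs, ys, lr):
    return w - lr * grad_sum(w, xs, ys) / len(xs)

xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0 * x for x in xs]  # the true weight is 2.0
w_dp = w_single = 0.0
for _ in range(20):  # step 8: repeat
    w_dp = data_parallel_step(w_dp, xs, ys, num_devices=4, lr=0.01)
    w_single = single_device_step(w_single, xs, ys, lr=0.01)
```

The two results agree up to floating-point rounding, since the only difference is the order in which the per-sample gradients are summed.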

So the only thing that is not fully synchronized is the mean and var of BN, because they are not gathered to GPU:0 during backward. All the other parameters are fully synchronized because of the gather-scatter mechanism.