After I call loss.backward(), PyTorch automatically reduces (averages) the gradients across all model replicas (towers) when I use torch.nn.DataParallel(). Is there any way to get each individual replica's (tower's) gradient?
>>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
>>> output = net(input_var)
>>> loss = criterion(output, label_var)
Thank you all, I got it by specializing the Gather function.
Btw, I wonder why the replicate function is called on every forward of the DataParallel module, rather than just once at init? https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/data_parallel.py#L67
Sorry, I made a mistake. I actually solved it by specializing the Broadcast function's backward (which is a reduce operation).
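For intuition, here is a toy, framework-free sketch of why Broadcast's backward is a reduce. All names here are my own invention, not PyTorch APIs, and no GPUs are involved: broadcasting one parameter to N replicas in the forward means each replica contributes its own gradient, so the backward must sum them, and stashing the per-replica gradients just before that sum is the hook point for inspecting each tower's gradient.

```python
def broadcast_forward(param, n_replicas):
    # Forward of a broadcast: copy the same parameter to every replica.
    return [param for _ in range(n_replicas)]

def broadcast_backward(replica_grads, capture=None):
    # Backward of a broadcast is a reduce: sum the per-replica gradients.
    # Optionally stash the individual gradients before reducing -- this is
    # the "specialization" point discussed above (hypothetical argument).
    if capture is not None:
        capture.extend(replica_grads)
    return sum(replica_grads)

capture = []
replicas = broadcast_forward(5.0, 3)          # [5.0, 5.0, 5.0]
total = broadcast_backward([1.0, 2.0, 3.0], capture)
print(total)    # 6.0 -- the reduced gradient the original parameter sees
print(capture)  # [1.0, 2.0, 3.0] -- each tower's gradient, preserved
```

In real DataParallel the reduce happens on the source device, but the shape of the computation is the same.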
I also ran into this problem. I thought it was Scatter::backward() that gathers the grads, until this post reminded me that it is actually Broadcast::backward().
Thanks very much:)