How to Get Each Model Replica's Gradient


After I call loss.backward(), PyTorch automatically computes the average of the gradients across all model replicas (towers) when I use torch.nn.DataParallel(). Is there any way to get each model replica's (tower's) gradient?

>>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
>>> output = net(input_var)
>>> loss = criterion(output, label_var)
>>> loss.backward()
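DataParallel only exposes the already-averaged gradient on the original module, but what it computes can be reproduced on CPU by splitting the batch by hand. This is a minimal sketch (the model, shapes, and data below are made up for illustration), assuming equal-sized shards and a mean-reduced loss, under which the averaged per-replica gradients match the full-batch gradient:

```python
import torch

# Toy model and data (hypothetical, just for illustration).
torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
inputs = torch.randn(8, 4)
labels = torch.randn(8, 1)
criterion = torch.nn.MSELoss()

# Simulate two replicas by running each half of the batch separately
# and recording the gradient each "replica" would contribute.
per_replica_grads = []
for shard_in, shard_lbl in zip(inputs.chunk(2), labels.chunk(2)):
    model.zero_grad()
    loss = criterion(model(shard_in), shard_lbl)
    loss.backward()
    per_replica_grads.append(model.weight.grad.clone())

# The reduced gradient DataParallel leaves on the module is the average
# of the per-replica ones (with equal shards and a mean-reduced loss).
avg_grad = torch.stack(per_replica_grads).mean(dim=0)

# Compare against a single full-batch backward pass.
model.zero_grad()
criterion(model(inputs), labels).backward()
full_grad = model.weight.grad
```

Inspecting `per_replica_grads` before the averaging step is exactly the information DataParallel discards.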


Thank you all. I got it by specializing the Gather function.

Btw, I wonder why the replicate function is called on every forward of the DataParallel module, rather than once at init?

Sorry, I made a mistake. I actually solved it by specializing the Broadcast's backward function (which is a reduce operation).
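Why hooking Broadcast's backward works: broadcasting one parameter to several replicas means its backward must reduce (sum) the incoming per-replica gradients back into one tensor. The same behaviour can be seen on CPU with plain autograd, where `clone()` stands in for the per-replica copy that Broadcast makes; this is a sketch of the mechanism, not DataParallel's actual code path:

```python
import torch

# x stands in for a parameter that Broadcast copies to every replica.
x = torch.ones(3, requires_grad=True)

# clone() plays the role of the per-replica copy; each "replica"
# then computes its own loss on its copy.
replica_a = x.clone()
replica_b = x.clone()

# Tensor hooks capture each replica's gradient before the reduce.
per_replica_grads = []
replica_a.register_hook(lambda g: per_replica_grads.append(g.clone()))
replica_b.register_hook(lambda g: per_replica_grads.append(g.clone()))

loss = (2.0 * replica_a).sum() + (3.0 * replica_b).sum()
loss.backward()

# The backward of the copy is a reduce: x.grad is the *sum* of the
# per-replica gradients the hooks captured.
```

In real DataParallel code the same idea applies: registering hooks on the replicated parameters (or specializing Broadcast's backward, as above) lets you see each tower's gradient before the reduction collapses them.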


I also ran into this problem. I thought it was Scatter::backward() that gathers the grads, until this post reminded me that it is actually Broadcast::backward().
Thanks very much :)