After I call loss.backward(), PyTorch automatically averages the gradients across all model replicas (towers) when I use torch.nn.DataParallel(). Is there any way to get each individual replica's (tower's) gradients?
Example::
>>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
>>> output = net(input_var)
>>> loss = criterion(output, label_var)
>>> loss.backward()
"""
I ran into this problem as well. I thought it was Scatter::backward() that gathers the gradients, until this post reminded me that it is actually Broadcast::backward().
Thanks very much:)
I am interested in something similar. In particular, I want to access the replicated models' parameters. Could you provide a code snippet showing how you were able to access the replicated models' gradients?
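Not the original poster, but here is a minimal CPU sketch of the idea. It does not use DataParallel's internal replicas (those are created and discarded inside each forward call, and their gradients are accumulated back onto the original module by Broadcast::backward()); instead it replicates the model manually with deepcopy so each copy keeps its own .grad, which you can then inspect before averaging. The two-replica setup and the MSELoss/zero-target data are just placeholders for illustration:

```python
import copy
import torch

# Assumption: 2 "replicas" on one device, standing in for DataParallel's
# per-GPU towers. Each replica is an independent deep copy, so its .grad
# survives backward() and can be inspected individually.
model = torch.nn.Linear(4, 2)
replicas = [copy.deepcopy(model) for _ in range(2)]

inputs = torch.randn(6, 4)
chunks = inputs.chunk(len(replicas))  # scatter the batch across replicas
criterion = torch.nn.MSELoss()

for replica, chunk in zip(replicas, chunks):
    target = torch.zeros(chunk.size(0), 2)   # placeholder labels
    loss = criterion(replica(chunk), target)
    loss.backward()                          # grads stay on this replica

# Per-replica gradients are now directly accessible:
for i, replica in enumerate(replicas):
    print(f"replica {i} weight grad norm:", replica.weight.grad.norm().item())

# Combining them (sum or mean, depending on the loss reduction) mimics
# what Broadcast::backward() does inside DataParallel:
avg_grad = torch.stack([r.weight.grad for r in replicas]).mean(dim=0)
print("averaged grad shape:", tuple(avg_grad.shape))
```

If you need this inside DataParallel itself, a hook registered on the replica parameters during forward would be required, since the replicas are non-leaf tensors and do not retain .grad by default.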